Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(cluster): refactor flight service actions #15419

Merged
merged 12 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions src/query/service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ derive-visitor = { workspace = true }
ethnum = { workspace = true }
futures = { workspace = true }
futures-util = { workspace = true }
h2 = "0.3.26"
headers = "0.4.0"
highway = "1.1"
http = { workspace = true }
Expand All @@ -151,7 +150,7 @@ opentelemetry_sdk = { workspace = true }
parking_lot = { workspace = true }
parquet = { workspace = true }
paste = "1.0.9"
petgraph = "0.6.2"
petgraph = { version = "0.6.2", features = ["serde-1"] }
pin-project-lite = "0.2.9"
poem = { workspace = true }
prost = { workspace = true }
Expand Down
117 changes: 21 additions & 96 deletions src/query/service/src/schedulers/fragments/query_fragment_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::DataSchemaRef;
use databend_common_meta_types::NodeInfo;
use itertools::Itertools;

use crate::clusters::ClusterHelper;
use crate::servers::flight::v1::exchange::DataExchange;
use crate::servers::flight::v1::packets::ConnectionInfo;
use crate::servers::flight::v1::packets::DataflowDiagramBuilder;
use crate::servers::flight::v1::packets::ExecutePartialQueryPacket;
use crate::servers::flight::v1::packets::FragmentPlanPacket;
use crate::servers::flight::v1::packets::InitNodesChannelPacket;
use crate::servers::flight::v1::packets::QueryEnv;
use crate::servers::flight::v1::packets::QueryFragmentsPlanPacket;
use crate::sessions::QueryContext;
use crate::sessions::TableContext;
Expand Down Expand Up @@ -205,56 +204,20 @@ impl QueryFragmentsActions {
))
}

pub fn get_init_nodes_channel_packets(&self) -> Result<Vec<InitNodesChannelPacket>> {
let nodes_info = Self::nodes_info(&self.ctx);
let local_id = self.ctx.get_cluster().local_id.clone();
let connections_info = self.fragments_connections();
let statistics_connections = self.statistics_connections();

let mut init_nodes_channel_packets = Vec::with_capacity(connections_info.len());

for (executor, fragments_connections) in &connections_info {
if !nodes_info.contains_key(executor) {
return Err(ErrorCode::NotFoundClusterNode(format!(
"Not found node {} in cluster. cluster nodes: {:?}",
executor,
nodes_info.keys().cloned().collect::<Vec<_>>()
)));
}

let executor_node_info = &nodes_info[executor];
let mut connections_info = Vec::with_capacity(fragments_connections.len());

for (source, fragments) in fragments_connections {
if !nodes_info.contains_key(source) {
return Err(ErrorCode::NotFoundClusterNode(format!(
"Not found node {} in cluster. cluster nodes: {:?}",
source,
nodes_info.keys().cloned().collect::<Vec<_>>()
)));
}
pub fn get_query_env(&self) -> Result<QueryEnv> {
let mut builder = DataflowDiagramBuilder::create(self.ctx.get_cluster().nodes.clone());

connections_info.push(ConnectionInfo::create(
nodes_info[source].clone(),
fragments.iter().cloned().unique().collect::<Vec<_>>(),
));
}
self.fragments_connections(&mut builder)?;
self.statistics_connections(&mut builder)?;

init_nodes_channel_packets.push(InitNodesChannelPacket::create(
self.ctx.get_id(),
executor_node_info.clone(),
connections_info,
match executor_node_info.id == local_id {
true => statistics_connections.clone(),
false => vec![],
},
self.ctx
.get_settings()
.get_create_query_flight_client_with_current_rt()?,
));
}

Ok(init_nodes_channel_packets)
Ok(QueryEnv {
query_id: self.ctx.get_id(),
dataflow_diagram: Arc::new(builder.build()),
create_rpc_clint_with_current_rt: self
.ctx
.get_settings()
.get_create_query_flight_client_with_current_rt()?,
})
}

pub fn get_execute_partial_query_packets(&self) -> Result<Vec<ExecutePartialQueryPacket>> {
Expand All @@ -273,9 +236,7 @@ impl QueryFragmentsActions {
}

/// unique map(target, map(source, vec(fragment_id)))
fn fragments_connections(&self) -> HashMap<String, HashMap<String, Vec<usize>>> {
let mut target_source_fragments = HashMap::<String, HashMap<String, Vec<usize>>>::new();

fn fragments_connections(&self, builder: &mut DataflowDiagramBuilder) -> Result<()> {
for fragment_actions in &self.fragments_actions {
if let Some(exchange) = &fragment_actions.data_exchange {
let fragment_id = fragment_actions.fragment_id;
Expand All @@ -285,59 +246,23 @@ impl QueryFragmentsActions {
let source = fragment_action.executor.to_string();

for destination in &destinations {
if &source == destination {
continue;
}

if target_source_fragments.contains_key(destination) {
let source_fragments = target_source_fragments
.get_mut(destination)
.expect("Target fragments expect source");

if source_fragments.contains_key(&source) {
source_fragments
.get_mut(&source)
.expect("Source target fragments expect destination")
.push(fragment_id);

continue;
}
}

if target_source_fragments.contains_key(destination) {
let source_fragments = target_source_fragments
.get_mut(destination)
.expect("Target fragments expect source");

source_fragments.insert(source.clone(), vec![fragment_id]);
continue;
}

let mut target_fragments = HashMap::new();
target_fragments.insert(source.clone(), vec![fragment_id]);
target_source_fragments.insert(destination.clone(), target_fragments);
builder.add_data_edge(&source, destination, fragment_id)?;
}
}
}
}

target_source_fragments
Ok(())
}

fn statistics_connections(&self) -> Vec<ConnectionInfo> {
fn statistics_connections(&self, builder: &mut DataflowDiagramBuilder) -> Result<()> {
let local_id = self.ctx.get_cluster().local_id.clone();
let nodes_info = Self::nodes_info(&self.ctx);
let mut target_source_connections = Vec::with_capacity(nodes_info.len());

for (id, node_info) in nodes_info {
if local_id == id {
continue;
}

target_source_connections.push(ConnectionInfo::create(node_info, vec![]));
for (_id, node_info) in Self::nodes_info(&self.ctx) {
builder.add_statistics_edge(&node_info.id, &local_id)?;
}

target_source_connections
Ok(())
}

fn nodes_info(ctx: &Arc<QueryContext>) -> HashMap<String, Arc<NodeInfo>> {
Expand Down
89 changes: 35 additions & 54 deletions src/query/service/src/servers/flight/flight_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::convert::TryInto;
use std::error::Error;
use std::sync::Arc;

use async_channel::Receiver;
Expand All @@ -31,14 +29,15 @@ use futures_util::future::Either;
use minitrace::full_name;
use minitrace::future::FutureExt;
use minitrace::Span;
use serde::Deserialize;
use serde::Serialize;
use tonic::transport::channel::Channel;
use tonic::Request;
use tonic::Status;
use tonic::Streaming;

use crate::pipelines::executor::WatchNotify;
use crate::servers::flight::request_builder::RequestBuilder;
use crate::servers::flight::v1::actions::FlightAction;
use crate::servers::flight::v1::packets::DataPacket;

pub struct FlightClient {
Expand All @@ -55,12 +54,40 @@ impl FlightClient {
}

#[async_backtrace::framed]
pub async fn execute_action(&mut self, action: FlightAction, timeout: u64) -> Result<()> {
if let Err(cause) = self.do_action(action, timeout).await {
return Err(cause.add_message_back("(while in query flight)"));
}
#[minitrace::trace]
pub async fn do_action<T, Res>(&mut self, path: &str, message: T, timeout: u64) -> Result<Res>
where
T: Serialize,
Res: for<'a> Deserialize<'a>,
{
let mut request =
databend_common_tracing::inject_span_to_tonic_request(Request::new(Action {
r#type: path.to_string(),
body: serde_json::to_vec(&message).map_err(|cause| {
ErrorCode::BadArguments(format!(
"Request payload serialize error while in {:?}, cause: {}",
path, cause
))
})?,
}));

drop(message);
request.set_timeout(Duration::from_secs(timeout));

Ok(())
let response = self.inner.do_action(request).await?;

match response.into_inner().message().await? {
Some(response) => serde_json::from_slice::<Res>(&response.body).map_err(|cause| {
ErrorCode::BadBytes(format!(
"Response payload deserialize error while in {:?}, cause: {}",
path, cause
))
}),
None => Err(ErrorCode::EmptyDataFromServer(format!(
"Can not receive data from flight server, action: {:?}",
path
))),
}
}

#[async_backtrace::framed]
Expand Down Expand Up @@ -158,27 +185,6 @@ impl FlightClient {
Err(status) => Err(ErrorCode::from(status).add_message_back("(while in query flight)")),
}
}

// Sends one flight action and hands back the raw body of the reply.
//
// `action` is converted into an `Action`, wrapped in a tonic request that carries
// the current tracing span, and sent with a timeout of `timeout` seconds.
// One message is read from the response stream: its `body` is returned, or an
// `EmptyDataFromServer` error is produced when the stream yields nothing.
#[async_backtrace::framed]
#[minitrace::trace]
async fn do_action(&mut self, action: FlightAction, timeout: u64) -> Result<Vec<u8>> {
    let flight_action: Action = action.try_into()?;
    // Keep a copy of the type: the action itself is moved into the request below,
    // but the type is still needed for the error message.
    let kind = flight_action.r#type.clone();

    let mut request =
        databend_common_tracing::inject_span_to_tonic_request(Request::new(flight_action));
    request.set_timeout(Duration::from_secs(timeout));

    let mut replies = self.inner.do_action(request).await?.into_inner();

    let Some(reply) = replies.message().await? else {
        return Err(ErrorCode::EmptyDataFromServer(format!(
            "Can not receive data from flight server, action: {:?}",
            kind
        )));
    };

    Ok(reply.body)
}
}

pub struct FlightReceiver {
Expand Down Expand Up @@ -284,28 +290,3 @@ impl FlightExchange {
}
}
}

/// Searches the error chain of `err_status` for an underlying `std::io::Error`.
///
/// Starting at the status itself, each error in the `source()` chain is checked:
/// either it is an `std::io::Error`, or it is an `h2::Error` that carries one.
/// Returns the first match, or `None` once the chain is exhausted.
#[allow(dead_code)]
fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
    use h2::Error as h2Error;

    let mut current: Option<&(dyn Error + 'static)> = Some(err_status);

    while let Some(err) = current {
        if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
            return Some(io_err);
        }

        // h2::Error does not expose std::io::Error through `source()`,
        // so it has to be asked for the I/O error directly.
        // https://github.com/hyperium/h2/pull/462
        let from_h2 = err
            .downcast_ref::<h2Error>()
            .and_then(|h2_err| h2_err.get_io());
        if let Some(io_err) = from_h2 {
            return Some(io_err);
        }

        current = err.source();
    }

    None
}