Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(subscription): add configurable heartbeat for websocket protocol #4802

Merged
merged 11 commits into from Mar 20, 2024
17 changes: 17 additions & 0 deletions .changesets/feat_subscription_websocket_heartbeat.md
@@ -0,0 +1,17 @@
### feat(subscription): add configurable heartbeat for websocket protocol ([Issue #4621](https://github.com/apollographql/router/issues/4621))

Add the ability to enable a heartbeat on WebSocket connections to subgraphs, for cases where the subgraph drops idle connections.
For example, subgraphs built with the Netflix DGS framework: https://netflix.github.io/dgs/

Example of configuration:
abernix marked this conversation as resolved.
Show resolved Hide resolved

```yaml
subscription:
  mode:
    passthrough:
      all:
        path: /graphql
        heartbeat_interval: enabled # Optional
IvanGoncharov marked this conversation as resolved.
Show resolved Hide resolved
```

By [@IvanGoncharov](https://github.com/IvanGoncharov) in https://github.com/apollographql/router/pull/4802
Expand Up @@ -2404,14 +2404,21 @@ expression: "&schema"
"properties": {
"heartbeat_interval": {
"description": "Heartbeat interval for callback mode (default: 5secs)",
"default": "5s",
"default": "enabled",
"anyOf": [
{
"type": "string",
"enum": [
"disabled"
]
},
{
"description": "enable with default interval of 5s",
"type": "string",
"enum": [
"enabled"
]
},
{
"type": "string"
}
Expand Down Expand Up @@ -2464,6 +2471,28 @@ expression: "&schema"
"default": null,
"type": "object",
"properties": {
"heartbeat_interval": {
"description": "Heartbeat interval for graphql-ws protocol (default: disabled)",
"default": "disabled",
"anyOf": [
{
"type": "string",
"enum": [
"disabled"
]
},
{
"description": "enable with default interval of 5s",
"type": "string",
"enum": [
"enabled"
]
},
{
"type": "string"
}
]
},
"path": {
"description": "Path on which WebSockets are listening",
"default": null,
Expand Down Expand Up @@ -2491,6 +2520,28 @@ expression: "&schema"
"description": "WebSocket configuration for a specific subgraph",
"type": "object",
"properties": {
"heartbeat_interval": {
"description": "Heartbeat interval for graphql-ws protocol (default: disabled)",
"default": "disabled",
"anyOf": [
{
"type": "string",
"enum": [
"disabled"
]
},
{
"description": "enable with default interval of 5s",
"type": "string",
"enum": [
"enabled"
]
},
{
"type": "string"
}
]
},
"path": {
"description": "Path on which WebSockets are listening",
"default": null,
Expand Down
69 changes: 44 additions & 25 deletions apollo-router/src/plugins/subscription.rs
Expand Up @@ -112,7 +112,7 @@ impl SubscriptionModeConfig {
if callback_cfg.subgraphs.contains(service_name) || callback_cfg.subgraphs.is_empty() {
let callback_cfg = CallbackMode {
public_url: callback_cfg.public_url.clone(),
heartbeat_interval: callback_cfg.heartbeat_interval.clone(),
heartbeat_interval: callback_cfg.heartbeat_interval,
listen: callback_cfg.listen.clone(),
path: callback_cfg.path.clone(),
subgraphs: HashSet::new(), // We don't need it
Expand Down Expand Up @@ -151,7 +151,7 @@ pub(crate) struct CallbackMode {
pub(crate) public_url: url::Url,

/// Heartbeat interval for callback mode (default: 5secs)
#[serde(default = "HeartbeatInterval::default")]
#[serde(default = "HeartbeatInterval::new_enabled")]
pub(crate) heartbeat_interval: HeartbeatInterval,
// `skip_serializing` We don't need it in the context
/// Listen address on which the callback must listen (default: 127.0.0.1:4000)
Expand All @@ -168,25 +168,43 @@ pub(crate) struct CallbackMode {
pub(crate) subgraphs: HashSet<String>,
}

#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case", untagged)]
pub(crate) enum HeartbeatInterval {
Disabled(Disabled),
IvanGoncharov marked this conversation as resolved.
Show resolved Hide resolved
/// enable with default interval of 5s
Enabled(Enabled),
#[serde(with = "humantime_serde")]
#[schemars(with = "String")]
Duration(Duration),
IvanGoncharov marked this conversation as resolved.
Show resolved Hide resolved
}

#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
impl HeartbeatInterval {
pub(crate) fn new_enabled() -> Self {
Self::Enabled(Enabled::Enabled)
}
pub(crate) fn new_disabled() -> Self {
Self::Disabled(Disabled::Disabled)
}
pub(crate) fn into_option(self) -> Option<Duration> {
match self {
Self::Disabled(_) => None,
Self::Enabled(_) => Some(Duration::from_secs(5)),
Self::Duration(duration) => Some(duration),
}
}
}

// Single-variant marker enum wrapped by `HeartbeatInterval::Disabled`.
// With `rename_all = "snake_case"` its only variant is spelled `disabled`.
// NOTE(review): a plain `//` comment is used on purpose; a `///` doc comment would
// presumably be picked up by the `JsonSchema` derive as a schema description. Confirm
// before converting.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub(crate) enum Disabled {
Disabled,
}

impl Default for HeartbeatInterval {
fn default() -> Self {
Self::Duration(Duration::from_secs(5))
}
// Single-variant marker enum wrapped by `HeartbeatInterval::Enabled`.
// With `rename_all = "snake_case"` its only variant is spelled `enabled`.
// NOTE(review): a plain `//` comment is used on purpose; a `///` doc comment would
// presumably be picked up by the `JsonSchema` derive as a schema description. Confirm
// before converting.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub(crate) enum Enabled {
Enabled,
}

/// Using websocket to directly connect to subgraph
Expand All @@ -197,14 +215,19 @@ pub(crate) struct PassthroughMode {
subgraph: SubgraphPassthroughMode,
}

#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields, default)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// WebSocket configuration for a specific subgraph
pub(crate) struct WebSocketConfiguration {
/// Path on which WebSockets are listening
#[serde(default)]
pub(crate) path: Option<String>,
/// Which WebSocket GraphQL protocol to use for this subgraph possible values are: 'graphql_ws' | 'graphql_transport_ws' (default: graphql_ws)
#[serde(default)]
pub(crate) protocol: WebSocketProtocol,
/// Heartbeat interval for graphql-ws protocol (default: disabled)
#[serde(default = "HeartbeatInterval::new_disabled")]
pub(crate) heartbeat_interval: HeartbeatInterval,
}

fn default_path() -> String {
Expand All @@ -228,21 +251,17 @@ impl Plugin for Subscription {
.clone(),
);
#[cfg(not(test))]
match init
.config
.mode
.callback
.as_ref()
.expect("we checked in the condition the callback conf")
.heartbeat_interval
{
HeartbeatInterval::Duration(duration) => {
init.notify.set_ttl(Some(duration)).await?;
}
HeartbeatInterval::Disabled(_) => {
init.notify.set_ttl(None).await?;
}
}
init.notify
.set_ttl(
init.config
.mode
.callback
.as_ref()
.expect("we checked in the condition the callback conf")
.heartbeat_interval
.into_option(),
)
.await?;
}

Ok(Subscription {
Expand Down
48 changes: 44 additions & 4 deletions apollo-router/src/protocols/websocket.rs
Expand Up @@ -219,6 +219,7 @@ pub(crate) struct GraphqlWebSocket<S> {
stream: S,
id: String,
protocol: WebSocketProtocol,
heartbeat_interval: Option<tokio::time::Interval>,
// Booleans for state machine when closing the stream
completed: bool,
terminated: bool,
Expand All @@ -234,6 +235,7 @@ where
id: String,
protocol: WebSocketProtocol,
connection_params: Option<Value>,
heartbeat_interval: Option<tokio::time::Duration>,
) -> Result<Self, graphql::Error> {
let connection_init_msg = match connection_params {
Some(connection_params) => ClientMessage::ConnectionInit {
Expand Down Expand Up @@ -285,10 +287,22 @@ where
.build());
}

let heartbeat_interval = if protocol == WebSocketProtocol::GraphqlWs {
heartbeat_interval.map(|duration| {
let mut interval =
tokio::time::interval_at(tokio::time::Instant::now() + duration, duration);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
interval
})
} else {
None
};

Ok(Self {
stream,
id,
protocol,
heartbeat_interval,
completed: false,
terminated: false,
})
Expand Down Expand Up @@ -372,11 +386,15 @@ where
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut().project();
let mut stream = Pin::new(&mut this.stream);

match Pin::new(&mut this.stream).poll_next(cx) {
match stream.as_mut().poll_next(cx) {
Poll::Ready(message) => match message {
Some(server_message) => match server_message {
Ok(server_message) => {
if let Some(heartbeat_interval) = this.heartbeat_interval {
heartbeat_interval.reset();
}
if let Some(id) = &server_message.id() {
if this.id != id {
tracing::error!("we should not receive data from other subscriptions, closing the stream");
Expand All @@ -386,8 +404,7 @@ where
if let ServerMessage::Ping { .. } = server_message {
// Send pong asynchronously
let _ = Pin::new(
&mut Pin::new(&mut this.stream)
.send(ClientMessage::Pong { payload: None }),
&mut stream.as_mut().send(ClientMessage::Pong { payload: None }),
)
.poll(cx);
}
Expand All @@ -414,11 +431,32 @@ where
},
None => Poll::Ready(None),
},
Poll::Pending => Poll::Pending,
Poll::Pending => {
if let Some(heartbeat_interval) = this.heartbeat_interval {
match heartbeat_interval.poll_tick(cx) {
Poll::Ready(_) => send_heartbeat(this.stream, cx),
Poll::Pending => (),
};
}
Poll::Pending
}
}
}
}

fn send_heartbeat(mut stream: Pin<&mut impl Sink<ClientMessage>>, cx: &mut std::task::Context<'_>) {
if stream.as_mut().poll_flush(cx).map(|result| result.is_ok()) == Poll::Ready(true)
&& stream.as_mut().poll_ready(cx).map(|result| result.is_ok()) == Poll::Ready(true)
IvanGoncharov marked this conversation as resolved.
Show resolved Hide resolved
{
// Ignore error
let _ = stream.start_send(ClientMessage::Ping {
payload: Some(serde_json_bytes::Value::String(
"APOLLO_ROUTER_HEARTBEAT".into(),
bnjjj marked this conversation as resolved.
Show resolved Hide resolved
)),
});
}
}

impl<S> Sink<graphql::Request> for GraphqlWebSocket<S>
where
S: Stream<Item = serde_json::Result<ServerMessage>> + Sink<ClientMessage>,
Expand Down Expand Up @@ -800,6 +838,7 @@ mod tests {
Some(serde_json_bytes::json!({
"token": "XXX"
})),
None,
)
.await
.unwrap();
Expand Down Expand Up @@ -865,6 +904,7 @@ mod tests {
sub_uuid.to_string(),
WebSocketProtocol::SubscriptionsTransportWs,
None,
None,
)
.await
.unwrap();
Expand Down
17 changes: 8 additions & 9 deletions apollo-router/src/services/subgraph_service.rs
Expand Up @@ -47,7 +47,6 @@ use crate::plugins::authentication::subgraph::SigningParamsConfig;
use crate::plugins::file_uploads;
use crate::plugins::subscription::create_verifier;
use crate::plugins::subscription::CallbackMode;
use crate::plugins::subscription::HeartbeatInterval;
use crate::plugins::subscription::SubscriptionConfig;
use crate::plugins::subscription::SubscriptionMode;
use crate::plugins::subscription::WebSocketConfiguration;
Expand Down Expand Up @@ -330,12 +329,10 @@ impl tower::Service<SubgraphRequest> for SubgraphService {
subscription_id,
callback_url,
verifier,
heartbeat_interval_ms: match heartbeat_interval {
HeartbeatInterval::Disabled(_) => 0,
HeartbeatInterval::Duration(duration) => {
duration.as_millis() as u64
}
},
heartbeat_interval_ms: heartbeat_interval
.into_option()
.map(|duration| duration.as_millis() as u64)
.unwrap_or(0),
};
body.extensions.insert(
"subscription",
Expand Down Expand Up @@ -579,6 +576,7 @@ async fn call_websocket(
subscription_hash,
subgraph_cfg.protocol,
connection_params,
subgraph_cfg.heartbeat_interval.into_option(),
)
.await
.map_err(|_| FetchError::SubrequestWsError {
Expand Down Expand Up @@ -1054,7 +1052,7 @@ mod tests {
use crate::graphql::Error;
use crate::graphql::Request;
use crate::graphql::Response;
use crate::plugins::subscription::Disabled;
use crate::plugins::subscription::HeartbeatInterval;
use crate::plugins::subscription::SubgraphPassthroughMode;
use crate::plugins::subscription::SubscriptionModeConfig;
use crate::plugins::subscription::SUBSCRIPTION_CALLBACK_HMAC_KEY;
Expand Down Expand Up @@ -1639,7 +1637,7 @@ mod tests {
listen: None,
path: Some("/testcallback".to_string()),
subgraphs: vec![String::from("testbis")].into_iter().collect(),
heartbeat_interval: HeartbeatInterval::Disabled(Disabled::Disabled),
heartbeat_interval: HeartbeatInterval::new_disabled(),
}),
passthrough: Some(SubgraphPassthroughMode {
all: None,
Expand All @@ -1648,6 +1646,7 @@ mod tests {
WebSocketConfiguration {
path: Some(String::from("/ws")),
protocol: WebSocketProtocol::default(),
heartbeat_interval: HeartbeatInterval::new_disabled(),
},
)]
.into(),
Expand Down