Question: How to get server to send ping messages on subscriptions? #1481

Open
hzargar2 opened this issue Mar 7, 2024 · 2 comments
Labels
question Further information is requested

hzargar2 commented Mar 7, 2024

How can I make the server send ping messages with axum when a client is using subscriptions? I know the client can send the pings, but then the client chooses the keep-alive interval and could set it very short, which increases data transmission and server load. I would like the pings to come from the server, with the client responding with a pong, rather than the other way around.

I have tried many options on the axum server with no luck. Is this the right place to put it, or do I have to yield ping messages in the actual subscription resolver itself?

axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
      .http1_keepalive(true)
      .tcp_keepalive_retries(Some(1))
      .tcp_keepalive_interval(Some(core::time::Duration::from_secs(30)))
      .tcp_keepalive(Some(core::time::Duration::from_secs(30)))
      .http2_keep_alive_timeout(core::time::Duration::from_secs(60))
      .http2_keep_alive_interval(Some(core::time::Duration::from_secs(30)))
      .serve(app.into_make_service())
      .await
      .expect("Starting server failed.");
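
For context on why these options may not help: `tcp_keepalive*` configures TCP-level probes and `http2_keep_alive_*` configures HTTP/2 PING frames, whereas a WebSocket ping is a separate control frame defined by RFC 6455 (opcode 0x9). The sketch below only illustrates what such a frame looks like on the wire when sent by a server. The function name `encode_server_ping` is hypothetical, and in practice the WebSocket library builds this frame for you.

```rust
/// Opcode of a WebSocket Ping control frame (RFC 6455, section 5.5.2).
const OPCODE_PING: u8 = 0x9;

/// Builds an unfragmented, unmasked Ping frame as a server would send it.
/// Returns `None` if the payload exceeds the 125-byte limit for control frames.
/// (Hypothetical helper for illustration only.)
fn encode_server_ping(payload: &[u8]) -> Option<Vec<u8>> {
    if payload.len() > 125 {
        return None;
    }
    let mut frame = Vec::with_capacity(2 + payload.len());
    // First byte: FIN bit set (0x80) plus the Ping opcode.
    frame.push(0x80 | OPCODE_PING);
    // Second byte: MASK bit clear (server-to-client frames are not masked) plus the length.
    frame.push(payload.len() as u8);
    frame.extend_from_slice(payload);
    Some(frame)
}

fn main() {
    let frame = encode_server_ping(b"hi").expect("payload fits in a control frame");
    println!("{:02x?}", frame);
}
```

A client that receives this frame is expected to answer with a Pong (opcode 0xA) carrying the same payload.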

Here is a rough example of the stream my subscription resolver returns. If I need to put the code in here, how do I do so without actually returning a ping object?

return Ok(async_stream::stream! {

    // some code to get our output object
    yield Ok(my_output_object);
});
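
One way to avoid a ping object in the schema is to keep pings out of the resolver entirely and let the loop that writes to the socket interleave them with resolver output. The following is a minimal standard-library sketch of that shape, not the async-graphql or axum API: `Outgoing`, `pump`, and `max_frames` are hypothetical names, and a blocking channel with `recv_timeout` stands in for what an async server would more likely express as a select between the stream and a timer.

```rust
use std::sync::mpsc::{self, Receiver, RecvTimeoutError};
use std::time::Duration;

/// What the transport loop writes to the socket. `Ping` never passes through the resolver.
#[derive(Debug, PartialEq)]
enum Outgoing {
    Text(String),
    Ping,
}

/// Forwards resolver output and emits a `Ping` whenever the channel has been idle
/// for `ping_every`. Stops when the resolver side hangs up, or after `max_frames`
/// frames (the cap exists only to keep this sketch finite).
fn pump(rx: Receiver<String>, ping_every: Duration, max_frames: usize) -> Vec<Outgoing> {
    let mut sent = Vec::new();
    while sent.len() < max_frames {
        match rx.recv_timeout(ping_every) {
            Ok(payload) => sent.push(Outgoing::Text(payload)),
            Err(RecvTimeoutError::Timeout) => sent.push(Outgoing::Ping),
            Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    sent
}

fn main() {
    let (tx, rx) = mpsc::channel();
    tx.send("my_output_object".to_string()).unwrap();
    // `tx` stays alive during the call, so after the one payload the loop only sees idle timeouts.
    let frames = pump(rx, Duration::from_millis(20), 3);
    println!("{:?}", frames);
    drop(tx);
}
```

Note that this variant pings only after a period of silence, because every forwarded payload restarts the wait. A fixed-interval ping would need a separate timer.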

This is my GraphQL subscription service struct. It creates a Credentials struct from the api-key received in on_connection_init() and adds it to the context data. Do I have to add the ping in here?

use std::{borrow::Cow, convert::Infallible, future::Future, str::FromStr};
use std::borrow::Borrow;

use async_graphql::{
    Data,
    Executor,
    futures_util,
    futures_util::task::{Context, Poll}, http::{ALL_WEBSOCKET_PROTOCOLS, WebSocketProtocols, WsMessage}, Result,
};
use axum::{
    body::{BoxBody, boxed, HttpBody},
    Error,
    extract::{
        FromRequestParts,
        WebSocketUpgrade, ws::{CloseFrame, Message},
    },
    http::{self, Request, request::Parts, Response, StatusCode},
    RequestExt, response::IntoResponse,
};
use axum::extract::FromRequest;
use futures_util::{
    future,
    future::{BoxFuture, Ready},
    Sink,
    SinkExt, stream::{SplitSink, SplitStream}, Stream, StreamExt,
};
use tower::Service;

use crate::data::credentials::Credentials;

/// A GraphQL protocol extractor.
///
/// It extracts the GraphQL protocol from the `SEC_WEBSOCKET_PROTOCOL` header.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct GraphQLProtocol(WebSocketProtocols);

#[async_trait::async_trait]
impl<S> FromRequestParts<S> for GraphQLProtocol
    where
        S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        parts
            .headers
            .get(http::header::SEC_WEBSOCKET_PROTOCOL)
            .and_then(|value| value.to_str().ok())
            .and_then(|protocols| {
                protocols
                    .split(',')
                    .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
            })
            .map(Self)
            .ok_or(StatusCode::BAD_REQUEST)
    }
}

/// A GraphQL subscription service.
pub struct CustomGraphQLSubscription<E> {
    executor: E,
}

impl<E> Clone for CustomGraphQLSubscription<E>
    where
        E: Executor,
{
    fn clone(&self) -> Self {
        Self {
            executor: self.executor.clone(),
        }
    }
}

impl<E> CustomGraphQLSubscription<E>
    where
        E: Executor,
{
    /// Create a GraphQL subscription service.
    pub fn new(executor: E) -> Self {
        Self { executor }
    }
}

impl<B, E> Service<Request<B>> for CustomGraphQLSubscription<E>
    where
        B: HttpBody + Send + 'static,
        E: Executor,
{
    type Response = Response<BoxBody>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        let executor = self.executor.clone();

        Box::pin(async move {
            let (mut parts, mut body) = req.into_parts();

            let protocol = match GraphQLProtocol::from_request_parts(&mut parts, &()).await {
                Ok(protocol) => protocol,
                Err(err) => return Ok(err.into_response().map(boxed)),
            };
            let upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await {
                Ok(protocol) => protocol,
                Err(err) => return Ok(err.into_response().map(boxed)),
            };

            let executor = executor.clone();

            let resp = upgrade
                .protocols(ALL_WEBSOCKET_PROTOCOLS)
                .on_upgrade(move |stream| {
                    GraphQLWebSocket::new(stream, executor, protocol)
                        .on_connection_init(|x: serde_json::Value| {
                            // On connection init, the client sends its headers as the payload of the WebSocket message.
                            // Extract the api-key value, build a Credentials struct, and pass it to the resolvers. Data
                            // returned here is merged with any data set via with_data() on GraphQLWebSocket.

                            let mut data = Data::default();
                            if let Some(api_key) = x.get("api-key").and_then(|x| x.as_str()).clone()
                            {
                                data.insert(Credentials {
                                    api_key: api_key.to_string(),
                                });
                                futures_util::future::ready(Ok(data))
                            } else {
                                futures_util::future::ready(Err(async_graphql::Error::new(
                                    "Missing `api-key` header.",
                                )))
                            }
                        })
                        .serve()
                });

            Ok(resp.into_response().map(boxed))
        })
    }
}

type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<async_graphql::Result<Data>>;

fn default_on_connection_init(x: serde_json::Value) -> Ready<async_graphql::Result<Data>> {
    futures_util::future::ready(Ok(Data::default()))
}

/// A Websocket connection for GraphQL subscription.
pub struct GraphQLWebSocket<Sink, Stream, E, OnConnInit> {
    sink: Sink,
    stream: Stream,
    executor: E,
    data: Data,
    on_connection_init: OnConnInit,
    protocol: GraphQLProtocol,
}

impl<S, E> GraphQLWebSocket<SplitSink<S, Message>, SplitStream<S>, E, DefaultOnConnInitType>
    where
        S: Stream<Item = Result<Message, Error>> + Sink<Message>,
        E: Executor,
{
    /// Create a [`GraphQLWebSocket`] object.
    pub fn new(stream: S, executor: E, protocol: GraphQLProtocol) -> Self {
        let (sink, stream) = stream.split();
        GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
    }
}

impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType>
    where
        Sink: futures_util::sink::Sink<Message>,
        Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
        E: Executor,
{
    /// Create a [`GraphQLWebSocket`] object with sink and stream objects.
    pub fn new_with_pair(
        sink: Sink,
        stream: Stream,
        executor: E,
        protocol: GraphQLProtocol,
    ) -> Self {
        GraphQLWebSocket {
            sink,
            stream,
            executor,
            data: Data::default(),
            on_connection_init: default_on_connection_init,
            protocol,
        }
    }
}

impl<Sink, Stream, E, OnConnInit, OnConnInitFut> GraphQLWebSocket<Sink, Stream, E, OnConnInit>
    where
        Sink: futures_util::sink::Sink<Message>,
        Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
        E: Executor,
        OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
        OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{
    /// Specify the initial subscription context data, usually you can get
    /// something from the incoming request to create it.
    #[must_use]
    pub fn with_data(self, data: Data) -> Self {
        Self { data, ..self }
    }

    /// Specify a callback function to be called when the connection is
    /// initialized.
    ///
    /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
    /// The data returned by this callback function will be merged with the data
    /// specified by [`with_data`].
    pub fn on_connection_init<OnConnInit2, Fut>(
        self,
        callback: OnConnInit2,
    ) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit2>
        where
            OnConnInit2: FnOnce(serde_json::Value) -> Fut + Send + 'static,
            Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
    {
        GraphQLWebSocket {
            sink: self.sink,
            stream: self.stream,
            executor: self.executor,
            data: self.data,
            on_connection_init: callback,
            protocol: self.protocol,
        }
    }

    /// Processing subscription requests.
    pub async fn serve(self) {
        let input = self
            .stream
            .take_while(|res| future::ready(res.is_ok()))
            .map(Result::unwrap)
            .filter_map(|msg| {
                if let Message::Text(_) | Message::Binary(_) = msg {
                    future::ready(Some(msg))
                } else {
                    future::ready(None)
                }
            })
            .map(Message::into_data);

        let stream =
            async_graphql::http::WebSocket::new(self.executor.clone(), input, self.protocol.0)
                .connection_data(self.data)
                .on_connection_init(self.on_connection_init)
                .map(|msg| match msg {
                    WsMessage::Text(text) => Message::Text(text),
                    WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
                        code,
                        reason: Cow::from(status),
                    })),
                });

        let sink = self.sink;
        futures_util::pin_mut!(stream, sink);

        while let Some(item) = stream.next().await {
            let _ = sink.send(item).await;
        }
    }
}
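
If pings were added to a loop like `serve` above, the server would also need to decide when a missing pong means the connection is dead. Note that the `filter_map` in `serve` discards every incoming message that is not `Text` or `Binary`, so a client pong would have to be observed before that filter. The sketch below shows only the bookkeeping, using the standard library. `KeepAlive`, `Action`, and the two durations are hypothetical names chosen for illustration and are not part of async-graphql or axum.

```rust
use std::time::{Duration, Instant};

/// What the connection loop should do on a timer tick.
#[derive(Debug, PartialEq)]
enum Action {
    SendPing,
    Close,
    Wait,
}

/// Server-side keep-alive bookkeeping. Both durations are chosen by the server.
struct KeepAlive {
    ping_every: Duration,
    pong_timeout: Duration,
    last_ping: Instant,
    awaiting_pong: bool,
}

impl KeepAlive {
    fn new(now: Instant, ping_every: Duration, pong_timeout: Duration) -> Self {
        Self { ping_every, pong_timeout, last_ping: now, awaiting_pong: false }
    }

    /// Call when a Pong frame arrives from the client.
    fn on_pong(&mut self) {
        self.awaiting_pong = false;
    }

    /// Call on every timer tick with the current time.
    fn on_tick(&mut self, now: Instant) -> Action {
        let elapsed = now.duration_since(self.last_ping);
        if self.awaiting_pong {
            // A ping is outstanding: give up once the pong deadline has passed.
            if elapsed >= self.pong_timeout { Action::Close } else { Action::Wait }
        } else if elapsed >= self.ping_every {
            // No ping outstanding and the interval has elapsed: send the next one.
            self.last_ping = now;
            self.awaiting_pong = true;
            Action::SendPing
        } else {
            Action::Wait
        }
    }
}

fn main() {
    let start = Instant::now();
    let mut ka = KeepAlive::new(start, Duration::from_secs(30), Duration::from_secs(10));
    println!("{:?}", ka.on_tick(start + Duration::from_secs(30)));
    ka.on_pong();
    println!("{:?}", ka.on_tick(start + Duration::from_secs(45)));
}
```

Passing `now` in explicitly keeps the logic independent of any particular timer or runtime.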

hzargar2 added the question label Mar 7, 2024
sunli829 (Collaborator) commented Mar 9, 2024

Not currently, I will consider adding this.

hzargar2 (Author) commented Mar 9, 2024

> Not currently, I will consider adding this.

Oh okay, great! Not to be pushy since you're probably busy, but do you know how long feature requests like this usually take? :) thanks again.
