Skip to content

Commit

Permalink
Improve consumer Stream implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
paolobarbolini committed Jul 22, 2023
1 parent 0534719 commit 3670156
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 279 deletions.
249 changes: 112 additions & 137 deletions async-nats/src/jetstream/consumer/pull.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,57 +444,47 @@ impl<'a> futures::Stream for Sequence<'a> {
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
match self.next.as_mut() {
None => {
let context = self.context.clone();
let subject = self.subject.clone();
let request = self.request.clone();
let pending_messages = self.pending_messages;

let next = self.next.insert(Box::pin(async move {
let inbox = context.client.new_inbox();
let subscriber = context
.client
.subscribe(inbox.clone())
.await
.map_err(|err| MessagesError::with_source(MessagesErrorKind::Pull, err))?;
let this = self.as_mut().get_mut();

let next = this.next.get_or_insert_with(|| {
let context = this.context.clone();
let subject = this.subject.clone();
let request = this.request.clone();
let pending_messages = this.pending_messages;
let inbox = context.client.new_inbox();

Box::pin(async move {
let subscriber = context
.client
.subscribe(inbox.clone())
.await
.map_err(|err| MessagesError::with_source(MessagesErrorKind::Pull, err))?;

context
.client
.publish_with_reply(subject, inbox, request)
.await
.map_err(|err| MessagesError::with_source(MessagesErrorKind::Pull, err))?;

// TODO(tp): Add timeout config and defaults.
Ok(Batch {
pending_messages,
subscriber,
context,
terminated: false,
timeout: Some(Box::pin(tokio::time::sleep(Duration::from_secs(60)))),
})
})
});

context
.client
.publish_with_reply(subject, inbox, request)
.await
.map_err(|err| MessagesError::with_source(MessagesErrorKind::Pull, err))?;

// TODO(tp): Add timeout config and defaults.
Ok(Batch {
pending_messages,
subscriber,
context,
terminated: false,
timeout: Some(Box::pin(tokio::time::sleep(Duration::from_secs(60)))),
})
}));

match next.as_mut().poll(cx) {
Poll::Ready(result) => {
self.next = None;
Poll::Ready(Some(result.map_err(|err| {
MessagesError::with_source(MessagesErrorKind::Pull, err)
})))
}
Poll::Pending => Poll::Pending,
}
match next.poll_unpin(cx) {
Poll::Ready(result) => {
this.next = None;
Poll::Ready(Some(result.map_err(|err| {
MessagesError::with_source(MessagesErrorKind::Pull, err)
})))
}

Some(next) => match next.as_mut().poll(cx) {
Poll::Ready(result) => {
self.next = None;
Poll::Ready(Some(result.map_err(|err| {
MessagesError::with_source(MessagesErrorKind::Pull, err)
})))
}
Poll::Pending => Poll::Pending,
},
Poll::Pending => Poll::Pending,
}
}
}
Expand Down Expand Up @@ -751,34 +741,32 @@ impl<'a> futures::Stream for Ordered<'a> {
// Poll messages
if let Some(stream) = self.stream.as_mut() {
match stream.poll_next_unpin(cx) {
Poll::Ready(message) => match message {
Some(message) => {
// Do we bail out on all errors?
// Or we want to handle some? (like consumer deleted?)
let message = message?;
let info = message.info().map_err(|err| {
OrderedError::with_source(OrderedErrorKind::Other, err)
})?;
trace!("consumer sequence: {:?}, stream sequence {:?}, consumer sequence in message: {:?} stream sequence in message: {:?}",
Poll::Ready(Some(message)) => {
// Do we bail out on all errors?
// Or we want to handle some? (like consumer deleted?)
let message = message?;
let info = message
.info()
.map_err(|err| OrderedError::with_source(OrderedErrorKind::Other, err))?;
trace!("consumer sequence: {:?}, stream sequence {:?}, consumer sequence in message: {:?} stream sequence in message: {:?}",
self.consumer_sequence,
self.stream_sequence,
info.consumer_sequence,
info.stream_sequence);
if info.consumer_sequence != self.consumer_sequence + 1 {
debug!(
"ordered consumer mismatch. current {}, info: {}",
self.consumer_sequence, info.consumer_sequence
);
recreate = true;
self.consumer_sequence = 0;
} else {
self.stream_sequence = info.stream_sequence;
self.consumer_sequence = info.consumer_sequence;
return Poll::Ready(Some(Ok(message)));
}
if info.consumer_sequence != self.consumer_sequence + 1 {
debug!(
"ordered consumer mismatch. current {}, info: {}",
self.consumer_sequence, info.consumer_sequence
);
recreate = true;
self.consumer_sequence = 0;
} else {
self.stream_sequence = info.stream_sequence;
self.consumer_sequence = info.consumer_sequence;
return Poll::Ready(Some(Ok(message)));
}
None => return Poll::Ready(None),
},
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => (),
}
}
Expand All @@ -792,27 +780,25 @@ impl<'a> futures::Stream for Ordered<'a> {
let consumer_name = self.consumer_name.clone();
let sequence = self.consumer_sequence;
async move {
recreate_consumer_stream(context, config, stream_name, consumer_name, sequence)
recreate_consumer_stream(context, config, stream_name, &consumer_name, sequence)
.await
}
}))
}
// check for recreation future
if let Some(result) = self.create_stream.as_mut() {
match result.poll_unpin(cx) {
Poll::Ready(result) => match result {
Ok(stream) => {
self.create_stream = None;
self.stream = Some(stream);
return self.poll_next(cx);
}
Err(err) => {
return Poll::Ready(Some(Err(OrderedError::with_source(
OrderedErrorKind::Recreate,
err,
))))
}
},
Poll::Ready(Ok(stream)) => {
self.create_stream = None;
self.stream = Some(stream);
return self.poll_next(cx);
}
Poll::Ready(Err(err)) => {
return Poll::Ready(Some(Err(OrderedError::with_source(
OrderedErrorKind::Recreate,
err,
))))
}
Poll::Pending => (),
}
}
Expand Down Expand Up @@ -909,7 +895,7 @@ impl Stream {
debug!("expired pull request")},
}

let request = serde_json::to_vec(&batch).map(Bytes::from).unwrap();
let request = Bytes::from(serde_json::to_vec(&batch).unwrap());
let result = context
.client
.publish_with_reply(subject.clone(), inbox.clone(), request.clone())
Expand Down Expand Up @@ -1048,20 +1034,17 @@ impl futures::Stream for Stream {
if !self.batch_config.idle_heartbeat.is_zero() {
trace!("setting hearbeats");
let timeout = self.batch_config.idle_heartbeat.saturating_mul(2);
self.heartbeat_timeout
let heartbeat_timeout = self
.heartbeat_timeout
.get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout)));

trace!("checking idle hearbeats");
if let Some(hearbeat) = self.heartbeat_timeout.as_mut() {
match hearbeat.poll_unpin(cx) {
Poll::Ready(_) => {
self.heartbeat_timeout = None;
return Poll::Ready(Some(Err(MessagesError::new(
MessagesErrorKind::MissingHeartbeat,
))));
}
Poll::Pending => (),
}
if heartbeat_timeout.poll_unpin(cx).is_ready() {
self.heartbeat_timeout = None;

return Poll::Ready(Some(Err(MessagesError::new(
MessagesErrorKind::MissingHeartbeat,
))));
}
}

Expand All @@ -1078,30 +1061,26 @@ impl futures::Stream for Stream {
}

match self.request_result_rx.poll_recv(cx) {
Poll::Ready(resp) => match resp {
Some(resp) => match resp {
Ok(reset) => {
trace!("request response: {:?}", reset);
debug!("request sent, setting pending messages");
if reset {
self.pending_messages = self.batch_config.batch;
self.pending_bytes = self.batch_config.max_bytes;
} else {
self.pending_messages += self.batch_config.batch;
self.pending_bytes += self.batch_config.max_bytes;
}
self.pending_request = false;
continue;
}
Err(err) => {
return Poll::Ready(Some(Err(MessagesError::with_source(
MessagesErrorKind::Pull,
err,
))))
}
},
None => return Poll::Ready(None),
},
Poll::Ready(Some(Ok(reset))) => {
trace!("request response: {:?}", reset);
debug!("request sent, setting pending messages");
if reset {
self.pending_messages = self.batch_config.batch;
self.pending_bytes = self.batch_config.max_bytes;
} else {
self.pending_messages += self.batch_config.batch;
self.pending_bytes += self.batch_config.max_bytes;
}
self.pending_request = false;
continue;
}
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Some(Err(MessagesError::with_source(
MessagesErrorKind::Pull,
err,
))))
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {
trace!("pending result");
}
Expand All @@ -1111,6 +1090,7 @@ impl futures::Stream for Stream {
match self.subscriber.receiver.poll_recv(cx) {
Poll::Ready(maybe_message) => {
self.heartbeat_timeout = None;

match maybe_message {
Some(message) => match message.status.unwrap_or(StatusCode::OK) {
StatusCode::TIMEOUT | StatusCode::REQUEST_TERMINATED => {
Expand Down Expand Up @@ -2260,7 +2240,7 @@ async fn recreate_consumer_stream(
context: Context,
config: Config,
stream_name: String,
consumer_name: String,
consumer_name: &str,
sequence: u64,
) -> Result<Stream, ConsumerRecreateError> {
// TODO(jarema): retry whole operation few times?
Expand All @@ -2270,20 +2250,15 @@ async fn recreate_consumer_stream(
.map_err(|err| {
ConsumerRecreateError::with_source(ConsumerRecreateErrorKind::GetStream, err)
})?;
stream
.delete_consumer(&consumer_name)
.await
.map_err(|err| {
ConsumerRecreateError::with_source(ConsumerRecreateErrorKind::Recreate, err)
})?;

let deliver_policy = {
if sequence == 0 {
DeliverPolicy::All
} else {
DeliverPolicy::ByStartSequence {
start_sequence: sequence + 1,
}
stream.delete_consumer(consumer_name).await.map_err(|err| {
ConsumerRecreateError::with_source(ConsumerRecreateErrorKind::Recreate, err)
})?;

let deliver_policy = if sequence == 0 {
DeliverPolicy::All
} else {
DeliverPolicy::ByStartSequence {
start_sequence: sequence + 1,
}
};
tokio::time::timeout(
Expand Down

0 comments on commit 3670156

Please sign in to comment.