Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use context for Publish methods #96

Merged
merged 10 commits into from
Jul 13, 2022
60 changes: 58 additions & 2 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package amqp091

import (
"context"
"errors"
"reflect"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -1359,9 +1361,47 @@ confirmations start at 1. Exit when all publishings are confirmed.
When Publish does not return an error and the channel is in confirm mode, the
internal counter for DeliveryTags with the first confirmation starts at 1.

Deprecated: Use PublishWithContext instead.
*/
func (ch *Channel) Publish(exchange, key string, mandatory, immediate bool, msg Publishing) error {
_, err := ch.PublishWithDeferredConfirm(exchange, key, mandatory, immediate, msg)
_, err := ch.PublishWithDeferredConfirmWithContext(context.Background(), exchange, key, mandatory, immediate, msg)
return err
}

/*
PublishWithContext sends a Publishing from the client to an exchange on the server.

When you want a single message to be delivered to a single queue, you can
publish to the default exchange with the routingKey of the queue name. This is
because every declared queue gets an implicit route to the default exchange.

Since publishings are asynchronous, any undeliverable message will get returned
by the server. Add a listener with Channel.NotifyReturn to handle any
undeliverable message when calling publish with either the mandatory or
immediate parameters as true.

Publishings can be undeliverable when the mandatory flag is true and no queue is
bound that matches the routing key, or when the immediate flag is true and no
consumer on the matched queue is ready to accept the delivery.

This can return an error when the channel, connection or socket is closed. The
error or lack of an error does not indicate whether the server has received this
publishing.

It is possible for publishing to not reach the broker if the underlying socket
is shut down without pending publishing packets being flushed from the kernel
buffers. The easy way of making it probable that all publishings reach the
server is to always call Connection.Close before terminating your publishing
application. The way to ensure that all publishings reach the server is to add
a listener to Channel.NotifyPublish and put the channel in confirm mode with
Channel.Confirm. Publishing delivery tags and their corresponding
confirmations start at 1. Exit when all publishings are confirmed.

When Publish does not return an error and the channel is in confirm mode, the
internal counter for DeliveryTags with the first confirmation starts at 1.
*/
func (ch *Channel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) error {
_, err := ch.PublishWithDeferredConfirmWithContext(ctx, exchange, key, mandatory, immediate, msg)
return err
}

Expand All @@ -1370,8 +1410,24 @@ PublishWithDeferredConfirm behaves identically to Publish but additionally retur
DeferredConfirmation, allowing the caller to wait on the publisher confirmation
for this message. If the channel has not been put into confirm mode,
the DeferredConfirmation will be nil.

Deprecated: Use PublishWithDeferredConfirmWithContext instead.
*/
func (ch *Channel) PublishWithDeferredConfirm(exchange, key string, mandatory, immediate bool, msg Publishing) (*DeferredConfirmation, error) {
return ch.PublishWithDeferredConfirmWithContext(context.Background(), exchange, key, mandatory, immediate, msg)
}

/*
PublishWithDeferredConfirmWithContext behaves identically to Publish but additionally returns a
DeferredConfirmation, allowing the caller to wait on the publisher confirmation
for this message. If the channel has not been put into confirm mode,
the DeferredConfirmation will be nil.
*/
func (ch *Channel) PublishWithDeferredConfirmWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) (*DeferredConfirmation, error) {
if ctx == nil {
return nil, errors.New("amqp091-go: nil Context")
}
lukebakken marked this conversation as resolved.
Show resolved Hide resolved

if err := msg.Headers.Validate(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -1405,7 +1461,7 @@ func (ch *Channel) PublishWithDeferredConfirm(exchange, key string, mandatory, i
}

if ch.confirming {
return ch.confirms.Publish(), nil
return ch.confirms.Publish(ctx), nil
}

return nil, nil
Expand Down
17 changes: 9 additions & 8 deletions confirms.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package amqp091

import (
"context"
"sync"
)

Expand Down Expand Up @@ -38,12 +39,12 @@ func (c *confirms) Listen(l chan Confirmation) {
}

// Publish increments the publishing counter
func (c *confirms) Publish() *DeferredConfirmation {
func (c *confirms) Publish(ctx context.Context) *DeferredConfirmation {
c.publishedMut.Lock()
defer c.publishedMut.Unlock()

c.published++
return c.deferredConfirmations.Add(c.published)
return c.deferredConfirmations.Add(ctx, c.published)
}

// confirm confirms one publishing, increments the expecting delivery tag, and
Expand Down Expand Up @@ -124,12 +125,12 @@ func newDeferredConfirmations() *deferredConfirmations {
}
}

func (d *deferredConfirmations) Add(tag uint64) *DeferredConfirmation {
func (d *deferredConfirmations) Add(ctx context.Context, tag uint64) *DeferredConfirmation {
d.m.Lock()
defer d.m.Unlock()

dc := &DeferredConfirmation{DeliveryTag: tag}
dc.wg.Add(1)
dc.ctx, dc.cancel = context.WithCancel(ctx)
d.confirmations[tag] = dc
return dc
}
Expand All @@ -144,7 +145,7 @@ func (d *deferredConfirmations) Confirm(confirmation Confirmation) {
return
}
dc.confirmation = confirmation
dc.wg.Done()
dc.cancel()
delete(d.confirmations, confirmation.DeliveryTag)
}

Expand All @@ -155,7 +156,7 @@ func (d *deferredConfirmations) ConfirmMultiple(confirmation Confirmation) {
for k, v := range d.confirmations {
if k <= confirmation.DeliveryTag {
v.confirmation = Confirmation{DeliveryTag: k, Ack: confirmation.Ack}
v.wg.Done()
v.cancel()
delete(d.confirmations, k)
}
}
Expand All @@ -168,13 +169,13 @@ func (d *deferredConfirmations) Close() {

for k, v := range d.confirmations {
v.confirmation = Confirmation{DeliveryTag: k, Ack: false}
v.wg.Done()
v.cancel()
delete(d.confirmations, k)
}
}

// Waits for publisher confirmation. Returns true if server successfully received the publishing.
func (d *DeferredConfirmation) Wait() bool {
d.wg.Wait()
<-d.ctx.Done()
lukebakken marked this conversation as resolved.
Show resolved Hide resolved
return d.confirmation.Ack
}
92 changes: 79 additions & 13 deletions confirms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
package amqp091

import (
"context"
"sync"
"testing"
"time"
)

func TestConfirmOneResequences(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

var (
fixtures = []Confirmation{
{1, true},
Expand All @@ -25,7 +30,7 @@ func TestConfirmOneResequences(t *testing.T) {
c.Listen(l)

for i := range fixtures {
if want, got := uint64(i+1), c.Publish(); want != got.DeliveryTag {
if want, got := uint64(i+1), c.Publish(ctx); want != got.DeliveryTag {
t.Fatalf("expected publish to return the 1 based delivery tag published, want: %d, got: %d", want, got.DeliveryTag)
}
}
Expand All @@ -49,6 +54,10 @@ func TestConfirmOneResequences(t *testing.T) {
}

func TestConfirmAndPublishDoNotDeadlock(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

var (
c = newConfirms()
l = make(chan Confirmation)
Expand All @@ -63,12 +72,16 @@ func TestConfirmAndPublishDoNotDeadlock(t *testing.T) {
}()

for i := 0; i < iterations; i++ {
c.Publish()
c.Publish(ctx)
<-l
}
}

func TestConfirmMixedResequences(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

var (
fixtures = []Confirmation{
{1, true},
Expand All @@ -81,7 +94,7 @@ func TestConfirmMixedResequences(t *testing.T) {
c.Listen(l)

for range fixtures {
c.Publish()
c.Publish(ctx)
}

c.One(fixtures[0])
Expand All @@ -103,6 +116,10 @@ func TestConfirmMixedResequences(t *testing.T) {
}

func TestConfirmMultipleResequences(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

var (
fixtures = []Confirmation{
{1, true},
Expand All @@ -116,7 +133,7 @@ func TestConfirmMultipleResequences(t *testing.T) {
c.Listen(l)

for range fixtures {
c.Publish()
c.Publish(ctx)
}

c.Multiple(fixtures[len(fixtures)-1])
Expand All @@ -129,6 +146,9 @@ func TestConfirmMultipleResequences(t *testing.T) {
}

func BenchmarkSequentialBufferedConfirms(t *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var (
c = newConfirms()
l = make(chan Confirmation, 10)
Expand All @@ -140,11 +160,15 @@ func BenchmarkSequentialBufferedConfirms(t *testing.B) {
if i > cap(l)-1 {
<-l
}
c.One(Confirmation{c.Publish().DeliveryTag, true})
c.One(Confirmation{c.Publish(ctx).DeliveryTag, true})
}
}

func TestConfirmsIsThreadSafe(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

const count = 1000
const timeout = 5 * time.Second
var (
Expand All @@ -158,7 +182,7 @@ func TestConfirmsIsThreadSafe(t *testing.T) {
c.Listen(l)

for i := 0; i < count; i++ {
go func() { pub <- Confirmation{c.Publish().DeliveryTag, true} }()
go func() { pub <- Confirmation{c.Publish(ctx).DeliveryTag, true} }()
}

for i := 0; i < count; i++ {
Expand All @@ -184,7 +208,7 @@ func TestDeferredConfirmationsConfirm(t *testing.T) {
for i, ack := range []bool{true, false} {
var result bool
deliveryTag := uint64(i + 1)
dc := dcs.Add(deliveryTag)
dc := dcs.Add(context.Background(), deliveryTag)
wg.Add(1)
go func() {
result = dc.Wait()
Expand All @@ -202,9 +226,9 @@ func TestDeferredConfirmationsConfirmMultiple(t *testing.T) {
dcs := newDeferredConfirmations()
var wg sync.WaitGroup
var result bool
dc1 := dcs.Add(1)
dc2 := dcs.Add(2)
dc3 := dcs.Add(3)
dc1 := dcs.Add(context.Background(), 1)
dc2 := dcs.Add(context.Background(), 2)
dc3 := dcs.Add(context.Background(), 3)
wg.Add(1)
go func() {
result = dc1.Wait() && dc2.Wait() && dc3.Wait()
Expand All @@ -221,9 +245,9 @@ func TestDeferredConfirmationsClose(t *testing.T) {
dcs := newDeferredConfirmations()
var wg sync.WaitGroup
var result bool
dc1 := dcs.Add(1)
dc2 := dcs.Add(2)
dc3 := dcs.Add(3)
dc1 := dcs.Add(context.Background(), 1)
dc2 := dcs.Add(context.Background(), 2)
dc3 := dcs.Add(context.Background(), 3)
wg.Add(1)
go func() {
result = !dc1.Wait() && !dc2.Wait() && !dc3.Wait()
Expand All @@ -235,3 +259,45 @@ func TestDeferredConfirmationsClose(t *testing.T) {
t.Fatal("expected to receive false for nacked confirmations, received true")
}
}

func TestDeferredConfirmationsContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

dcs := newDeferredConfirmations()
var wg sync.WaitGroup
var result bool
dc1 := dcs.Add(ctx, 1)
dc2 := dcs.Add(ctx, 2)
dc3 := dcs.Add(ctx, 3)
wg.Add(1)
go func() {
result = !dc1.Wait() && !dc2.Wait() && !dc3.Wait()
wg.Done()
}()
wg.Wait()
if !result {
t.Fatal("expected to receive false for timeout confirmations, received true")
}
}

func TestDeferredConfirmationsContextTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()

dcs := newDeferredConfirmations()
var wg sync.WaitGroup
var result bool
dc1 := dcs.Add(ctx, 1)
dc2 := dcs.Add(ctx, 2)
dc3 := dcs.Add(ctx, 3)
wg.Add(1)
go func() {
result = !dc1.Wait() && !dc2.Wait() && !dc3.Wait()
wg.Done()
}()
wg.Wait()
if !result {
t.Fatal("expected to receive false for timeout confirmations, received true")
}
}
8 changes: 7 additions & 1 deletion example_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package amqp091_test

import (
"context"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -242,7 +243,12 @@ func (client *Client) UnsafePush(data []byte) error {
if !client.isReady {
return errNotConnected
}
return client.channel.Publish(

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

return client.channel.PublishWithContext(
ctx,
"", // Exchange
client.queueName, // Routing key
false, // Mandatory
Expand Down