Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use context for Publish methods #96

Merged
merged 10 commits into from
Jul 13, 2022
15 changes: 10 additions & 5 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package amqp091

import (
"context"
"errors"
"reflect"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -1358,10 +1360,9 @@ confirmations start at 1. Exit when all publishings are confirmed.

When Publish does not return an error and the channel is in confirm mode, the
internal counter for DeliveryTags with the first confirmation starts at 1.

*/
func (ch *Channel) Publish(exchange, key string, mandatory, immediate bool, msg Publishing) error {
_, err := ch.PublishWithDeferredConfirm(exchange, key, mandatory, immediate, msg)
func (ch *Channel) Publish(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) error {
_, err := ch.PublishWithDeferredConfirm(ctx, exchange, key, mandatory, immediate, msg)
return err
}

Expand All @@ -1371,7 +1372,11 @@ DeferredConfirmation, allowing the caller to wait on the publisher confirmation
for this message. If the channel has not been put into confirm mode,
the DeferredConfirmation will be nil.
*/
func (ch *Channel) PublishWithDeferredConfirm(exchange, key string, mandatory, immediate bool, msg Publishing) (*DeferredConfirmation, error) {
func (ch *Channel) PublishWithDeferredConfirm(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) (*DeferredConfirmation, error) {
if ctx == nil {
return nil, errors.New("amqp091-go: nil Context")
}
lukebakken marked this conversation as resolved.
Show resolved Hide resolved

if err := msg.Headers.Validate(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -1405,7 +1410,7 @@ func (ch *Channel) PublishWithDeferredConfirm(exchange, key string, mandatory, i
}

if ch.confirming {
return ch.confirms.Publish(), nil
return ch.confirms.Publish(ctx), nil
}

return nil, nil
Expand Down
45 changes: 33 additions & 12 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package amqp091

import (
"bytes"
"context"
"io"
"reflect"
"testing"
Expand Down Expand Up @@ -398,6 +399,10 @@ func TestOpenFailedVhost(t *testing.T) {
}

func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

rwc, srv := newSession(t)
defer rwc.Close()

Expand Down Expand Up @@ -450,16 +455,16 @@ func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) {

go func() {
var e error
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 1")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 1")}); e != nil {
t.Errorf("publish error: %v", err)
}
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 2")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 2")}); e != nil {
t.Errorf("publish error: %v", err)
}
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 3")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 3")}); e != nil {
t.Errorf("publish error: %v", err)
}
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 4")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 4")}); e != nil {
t.Errorf("publish error: %v", err)
}
}()
Expand All @@ -473,16 +478,16 @@ func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) {

go func() {
var e error
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 5")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 5")}); e != nil {
t.Errorf("publish error: %v", err)
}
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 6")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 6")}); e != nil {
t.Errorf("publish error: %v", err)
}
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 7")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 7")}); e != nil {
t.Errorf("publish error: %v", err)
}
if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 8")}); e != nil {
if e = ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("pub 8")}); e != nil {
t.Errorf("publish error: %v", err)
}
}()
Expand All @@ -496,6 +501,10 @@ func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) {
}

func TestDeferredConfirmations(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

rwc, srv := newSession(t)
defer rwc.Close()

Expand Down Expand Up @@ -529,7 +538,7 @@ func TestDeferredConfirmations(t *testing.T) {

var results []*DeferredConfirmation
for i := 1; i < 5; i++ {
dc, err := ch.PublishWithDeferredConfirm("", "q", false, false, Publishing{Body: []byte("pub")})
dc, err := ch.PublishWithDeferredConfirm(ctx, "", "q", false, false, Publishing{Body: []byte("pub")})
if err != nil {
t.Fatalf("failed to PublishWithDeferredConfirm: %v", err)
}
Expand Down Expand Up @@ -663,6 +672,10 @@ func TestNotifyClosesAllChansAfterConnectionClose(t *testing.T) {

// Should not panic when sending bodies split at different boundaries
func TestPublishBodySliceIssue74(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

rwc, srv := newSession(t)
defer rwc.Close()

Expand Down Expand Up @@ -698,7 +711,7 @@ func TestPublishBodySliceIssue74(t *testing.T) {

for i := 0; i < publishings; i++ {
go func(ii int) {
if err := ch.Publish("", "q", false, false, Publishing{Body: base[0:ii]}); err != nil {
if err := ch.Publish(ctx, "", "q", false, false, Publishing{Body: base[0:ii]}); err != nil {
t.Errorf("publish error: %v", err)
}
}(i)
Expand All @@ -709,6 +722,10 @@ func TestPublishBodySliceIssue74(t *testing.T) {

// Should not panic when server and client have frame_size of 0
func TestPublishZeroFrameSizeIssue161(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

rwc, srv := newSession(t)
defer rwc.Close()

Expand Down Expand Up @@ -746,7 +763,7 @@ func TestPublishZeroFrameSizeIssue161(t *testing.T) {

for i := 0; i < publishings; i++ {
go func() {
if err := ch.Publish("", "q", false, false, Publishing{Body: []byte("anything")}); err != nil {
if err := ch.Publish(ctx, "", "q", false, false, Publishing{Body: []byte("anything")}); err != nil {
t.Errorf("publish error: %v", err)
}
}()
Expand All @@ -756,6 +773,10 @@ func TestPublishZeroFrameSizeIssue161(t *testing.T) {
}

func TestPublishAndShutdownDeadlockIssue84(t *testing.T) {
deadLine, _ := t.Deadline()
ctx, cancel := context.WithDeadline(context.Background(), deadLine)
defer cancel()

rwc, srv := newSession(t)
defer rwc.Close()

Expand All @@ -779,7 +800,7 @@ func TestPublishAndShutdownDeadlockIssue84(t *testing.T) {

defer time.AfterFunc(500*time.Millisecond, func() { t.Fatalf("Publish deadlock") }).Stop()
for {
if err := ch.Publish("exchange", "q", false, false, Publishing{Body: []byte("test")}); err != nil {
if err := ch.Publish(ctx, "exchange", "q", false, false, Publishing{Body: []byte("test")}); err != nil {
t.Log("successfully caught disconnect error", err)
return
}
Expand Down
17 changes: 9 additions & 8 deletions confirms.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package amqp091

import (
"context"
"sync"
)

Expand Down Expand Up @@ -38,12 +39,12 @@ func (c *confirms) Listen(l chan Confirmation) {
}

// Publish increments the publishing counter
func (c *confirms) Publish() *DeferredConfirmation {
func (c *confirms) Publish(ctx context.Context) *DeferredConfirmation {
c.publishedMut.Lock()
defer c.publishedMut.Unlock()

c.published++
return c.deferredConfirmations.Add(c.published)
return c.deferredConfirmations.Add(ctx, c.published)
}

// confirm confirms one publishing, increments the expecting delivery tag, and
Expand Down Expand Up @@ -124,12 +125,12 @@ func newDeferredConfirmations() *deferredConfirmations {
}
}

func (d *deferredConfirmations) Add(tag uint64) *DeferredConfirmation {
func (d *deferredConfirmations) Add(ctx context.Context, tag uint64) *DeferredConfirmation {
d.m.Lock()
defer d.m.Unlock()

dc := &DeferredConfirmation{DeliveryTag: tag}
dc.wg.Add(1)
dc.ctx, dc.cancel = context.WithCancel(ctx)
d.confirmations[tag] = dc
return dc
}
Expand All @@ -144,7 +145,7 @@ func (d *deferredConfirmations) Confirm(confirmation Confirmation) {
return
}
dc.confirmation = confirmation
dc.wg.Done()
dc.cancel()
delete(d.confirmations, confirmation.DeliveryTag)
}

Expand All @@ -155,7 +156,7 @@ func (d *deferredConfirmations) ConfirmMultiple(confirmation Confirmation) {
for k, v := range d.confirmations {
if k <= confirmation.DeliveryTag {
v.confirmation = Confirmation{DeliveryTag: k, Ack: confirmation.Ack}
v.wg.Done()
v.cancel()
delete(d.confirmations, k)
}
}
Expand All @@ -168,13 +169,13 @@ func (d *deferredConfirmations) Close() {

for k, v := range d.confirmations {
v.confirmation = Confirmation{DeliveryTag: k, Ack: false}
v.wg.Done()
v.cancel()
delete(d.confirmations, k)
}
}

// Waits for publisher confirmation. Returns true if server successfully received the publishing.
func (d *DeferredConfirmation) Wait() bool {
d.wg.Wait()
<-d.ctx.Done()
lukebakken marked this conversation as resolved.
Show resolved Hide resolved
return d.confirmation.Ack
}