From 51c88611c5fd83d95271d9700c2ced7126048c91 Mon Sep 17 00:00:00 2001 From: Martin Bechtle Date: Mon, 23 Nov 2020 11:46:07 +0100 Subject: [PATCH] SQS Parallel processing support (#275) Signed-off-by: Martin Bechtle --- .gitignore | 5 +- README.md | 3 +- component/async/amqp/amqp.go | 4 + component/async/amqp/amqp_test.go | 4 +- component/async/async.go | 1 + component/async/component.go | 48 ++++++++- component/async/component_test.go | 91 +++++++++++++--- component/async/kafka/group/group.go | 4 + component/async/kafka/group/group_test.go | 15 +-- component/async/kafka/simple/simple.go | 4 + component/async/kafka/simple/simple_test.go | 4 +- component/async/sqs/option.go | 26 ++--- component/async/sqs/option_test.go | 62 +++-------- component/async/sqs/sqs.go | 47 ++++---- component/async/sqs/sqs_test.go | 8 +- examples/sqs-simple/main.go | 108 +++++++++++++++++++ test/docker/aws/consumer_integration_test.go | 14 ++- 17 files changed, 328 insertions(+), 120 deletions(-) create mode 100644 examples/sqs-simple/main.go diff --git a/.gitignore b/.gitignore index c87902091..270b7d544 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,7 @@ gotestsum-report.xml gotestsum-report.xml # coverage file -coverage.txt \ No newline at end of file +coverage.txt + +# patron binaries +/cmd/patron/patron diff --git a/README.md b/README.md index 6ba1e07a5..8b4458a47 100644 --- a/README.md +++ b/README.md @@ -152,7 +152,8 @@ Detailed examples can be found in the [examples](/examples) folder with the foll - [Kafka Component, HTTP Component, HTTP Authentication, Kafka Tracing](/examples/kafka/main.go) - [Kafka Component, AMQP Tracing](/examples/amqp/main.go) - [AMQP Component, AWS SNS](/examples/sns/main.go) -- [AWS SQS](/examples/sqs/main.go) +- [AWS SQS consumer performing gRPC request](/examples/sqs/main.go) +- [AWS SQS consumer, highly customised](/examples/sqs-simple/main.go) - [gRPC](/examples/grpc/main.go) ## Processors diff --git a/component/async/amqp/amqp.go b/component/async/amqp/amqp.go index ccc0a86ac..2ddd25e2c 100644 --- a/component/async/amqp/amqp.go +++ b/component/async/amqp/amqp.go @@ -157,6 +157,10 @@ type consumer struct { conn *amqp.Connection } +func (c *consumer) OutOfOrder() bool { + return true +} + // Consume starts of consuming a AMQP queue. func (c *consumer) Consume(ctx context.Context) (<-chan async.Message, <-chan error, error) { deliveries, err := c.consume() diff --git a/component/async/amqp/amqp_test.go b/component/async/amqp/amqp_test.go index 2e34fd51d..741afa4b1 100644 --- a/component/async/amqp/amqp_test.go +++ b/component/async/amqp/amqp_test.go @@ -10,6 +10,7 @@ import ( "github.com/opentracing/opentracing-go/mocktracer" "github.com/streadway/amqp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var validExch, _ = NewExchange("e", amqp.ExchangeDirect) @@ -130,7 +131,8 @@ func TestFactory_Create(t *testing.T) { assert.Nil(t, got) } else { assert.NoError(t, err) - assert.NotNil(t, got) + require.NotNil(t, got) + assert.True(t, got.OutOfOrder()) } }) } diff --git a/component/async/async.go b/component/async/async.go index 63fad8acf..3fa381c31 100644 --- a/component/async/async.go +++ b/component/async/async.go @@ -43,6 +43,7 @@ type ConsumerFactory interface { // Consumer interface which every specific consumer has to implement. type Consumer interface { Consume(context.Context) (<-chan Message, <-chan error, error) + OutOfOrder() bool Close() error } diff --git a/component/async/component.go b/component/async/component.go index 95ec2d2b8..fb4276e9a 100644 --- a/component/async/component.go +++ b/component/async/component.go @@ -40,6 +40,9 @@ type Component struct { cf ConsumerFactory retries int retryWait time.Duration + concurrency int + jobs chan Message + jobErr chan error } // Builder gathers all required properties in order to construct a component @@ -51,6 +54,7 @@ type Builder struct { cf ConsumerFactory retries uint retryWait time.Duration + concurrency uint } // New initializes a new builder for a component with the given name @@ -95,6 +99,15 @@ func (cb *Builder) WithRetries(retries uint) *Builder { return cb } +// WithConcurrency specifies the number of worker goroutines for processing messages in parallel +// default value is '1' +// do NOT enable concurrency value for in-order consumers, such as Kafka or FIFO SQS +func (cb *Builder) WithConcurrency(concurrency uint) *Builder { + log.Infof(propSetMSG, "concurrency", cb.name) + cb.concurrency = concurrency + return cb +} + // WithRetryWait specifies the duration for the component to wait between retries // default value is '0' // it will append an error to the builder if the value is smaller than '0'. @@ -121,6 +134,15 @@ func (cb *Builder) Create() (*Component, error) { failStrategy: cb.failStrategy, retries: int(cb.retries), retryWait: cb.retryWait, + concurrency: int(cb.concurrency), + jobs: make(chan Message), + jobErr: make(chan error), + } + + if cb.concurrency > 1 { + for w := 1; w <= c.concurrency; w++ { + go c.worker() + } } return c, nil @@ -145,11 +167,15 @@ func (c *Component) Run(ctx context.Context) error { } } + close(c.jobs) return err } func (c *Component) processing(ctx context.Context) error { cns, err := c.cf.Create() + if c.concurrency > 1 && !cns.OutOfOrder() { + return fmt.Errorf("async component creation: cannot create in-order component with concurrency > 1") + } if err != nil { return fmt.Errorf("failed to create consumer: %w", err) } @@ -169,7 +195,7 @@ func (c *Component) processing(ctx context.Context) error { select { case msg := <-chMsg: log.FromContext(msg.Context()).Debug("consumer received a new message") - err := c.processMessage(msg) + err := c.dispatchMessage(msg) if err != nil { return err } @@ -180,19 +206,37 @@ func (c *Component) processing(ctx context.Context) error { return cns.Close() case err := <-chErr: return fmt.Errorf("an error occurred during message consumption: %w", err) + case err := <-c.jobErr: + return fmt.Errorf("an error occurred during concurrent message consumption: %w", err) } } } +func (c *Component) dispatchMessage(msg Message) error { + if c.concurrency > 1 { + c.jobs <- msg + return nil + } + return c.processMessage(msg) +} + func (c *Component) processMessage(msg Message) error { err := c.proc(msg) if err != nil { return c.executeFailureStrategy(msg, err) } - return msg.Ack() } +func (c *Component) worker() { + for msg := range c.jobs { + err := c.processMessage(msg) + if err != nil { + c.jobErr <- err + } + } +} + var errInvalidFS = errors.New("invalid failure strategy") func (c *Component) executeFailureStrategy(msg Message, err error) error { diff --git a/component/async/component_test.go b/component/async/component_test.go index df6964384..8ecbb6e89 100644 --- a/component/async/component_test.go +++ b/component/async/component_test.go @@ -3,11 +3,12 @@ package async import ( "context" "errors" - "strings" + "sync" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNew(t *testing.T) { @@ -75,12 +76,13 @@ func TestNew(t *testing.T) { } type proxyBuilder struct { - proc mockProcessor - cnr mockConsumer - cf ConsumerFactory - fs FailStrategy - retries int - retryWait time.Duration + proc mockProcessor + cnr mockConsumer + cf ConsumerFactory + fs FailStrategy + retries int + retryWait time.Duration + concurrency uint } func run(ctx context.Context, t *testing.T, builder *proxyBuilder) error { @@ -92,11 +94,33 @@ func run(ctx context.Context, t *testing.T, builder *proxyBuilder) error { WithFailureStrategy(builder.fs). WithRetries(uint(builder.retries)). WithRetryWait(builder.retryWait). + WithConcurrency(builder.concurrency). Create() assert.NoError(t, err) return cmp.Run(ctx) } +// TestCreate_ReturnsError expects an error when concurrency > 1 and component does not allow out of order processing +func TestCreate_ReturnsError(t *testing.T) { + cnr := mockConsumer{} + builder := proxyBuilder{ + cnr: cnr, + cf: &mockConsumerFactory{c: &cnr}, + concurrency: 2, + } + cmp, err := New("test", builder.cf, builder.proc.Process). + WithFailureStrategy(builder.fs). + WithRetries(uint(builder.retries)). + WithRetryWait(builder.retryWait). + WithConcurrency(builder.concurrency). + Create() + require.NotNil(t, cmp) + require.NoError(t, err) + got := cmp.processing(context.Background()) + want := "async component creation: cannot create in-order component with concurrency > 1" + assert.EqualError(t, got, want) +} + // TestRun_ReturnsError expects a consumer consume Error func TestRun_ReturnsError(t *testing.T) { builder := proxyBuilder{ @@ -104,8 +128,7 @@ func TestRun_ReturnsError(t *testing.T) { } err := run(context.Background(), t, &builder) - assert.Error(t, err) - assert.True(t, strings.Contains(err.Error(), errConsumer.Error())) + assert.True(t, errors.Is(err, errConsumer)) assert.Equal(t, 0, builder.proc.execs) } @@ -144,7 +167,7 @@ func TestRun_Process_Error_NackExitStrategy(t *testing.T) { err := run(ctx, t, &builder) assert.Error(t, err) - assert.True(t, strings.Contains(err.Error(), errProcess.Error())) + assert.True(t, errors.Is(err, errProcess)) assert.Equal(t, 1, builder.proc.execs) } @@ -202,7 +225,31 @@ func TestRun_ProcessError_WithNackError(t *testing.T) { err := run(ctx, t, &builder) assert.Error(t, err) - assert.True(t, strings.Contains(err.Error(), errNack.Error())) + assert.True(t, errors.Is(err, errNack)) + assert.Equal(t, 1, builder.proc.execs) +} + +// TestRun_ParallelProcessError_WithNackError expects a PROC ERROR +// same as TestRun_ProcessError_WithNackError, just with concurrency +func TestRun_ParallelProcessError_WithNackError(t *testing.T) { + builder := proxyBuilder{ + proc: mockProcessor{errReturn: true}, + cnr: mockConsumer{ + chMsg: make(chan Message, 10), + chErr: make(chan error, 10), + outOfOrder: true, + }, + fs: NackStrategy, + concurrency: 10, + } + + ctx := context.Background() + builder.cnr.chMsg <- &mockMessage{ctx: ctx, nackError: true} + + err := run(ctx, t, &builder) + + assert.Error(t, err) + assert.True(t, errors.Is(err, errNack)) assert.Equal(t, 1, builder.proc.execs) } @@ -260,7 +307,7 @@ func TestRun_ProcessError_WithAckError(t *testing.T) { err := run(ctx, t, &builder) assert.Error(t, err) - assert.True(t, strings.Contains(err.Error(), errAck.Error())) + assert.True(t, errors.Is(err, errAck)) assert.Equal(t, 1, builder.proc.execs) } @@ -298,8 +345,7 @@ func TestRun_ConsumeError(t *testing.T) { builder.cnr.chErr <- errConsumer err := run(ctx, t, &builder) - assert.Error(t, err) - assert.True(t, strings.Contains(err.Error(), errConsumer.Error())) + assert.True(t, errors.Is(err, errConsumer)) assert.Equal(t, 0, builder.proc.execs) } @@ -398,7 +444,7 @@ func (mm *mockMessage) Context() context.Context { } // Decode is not called in our tests, because the mockProcessor will ignore the message decoding -func (mm *mockMessage) Decode(v interface{}) error { +func (mm *mockMessage) Decode(interface{}) error { return nil } @@ -430,19 +476,28 @@ func (mm *mockMessage) Payload() []byte { type mockProcessor struct { errReturn bool + mux sync.Mutex execs int } var errProcess = errors.New("PROC ERROR") -func (mp *mockProcessor) Process(msg Message) error { +func (mp *mockProcessor) Process(Message) error { + mp.mux.Lock() mp.execs++ + mp.mux.Unlock() if mp.errReturn { return errProcess } return nil } +func (mp *mockProcessor) GetExecs() int { + mp.mux.Lock() + defer mp.mux.Unlock() + return mp.execs +} + type mockConsumerFactory struct { c Consumer errRet bool @@ -464,9 +519,11 @@ type mockConsumer struct { clsError bool chMsg chan Message chErr chan error + outOfOrder bool } -func (mc *mockConsumer) SetTimeout(timeout time.Duration) { +func (mc *mockConsumer) OutOfOrder() bool { + return mc.outOfOrder } var errConsumer = errors.New("CONSUMER ERROR") diff --git a/component/async/kafka/group/group.go b/component/async/kafka/group/group.go index 7d89c77c4..dd2ea5ad7 100644 --- a/component/async/kafka/group/group.go +++ b/component/async/kafka/group/group.go @@ -45,6 +45,10 @@ func New(name, group string, topics, brokers []string, oo ...kafka.OptionFunc) ( return &Factory{name: name, group: group, topics: topics, brokers: brokers, oo: oo}, nil } +func (c *consumer) OutOfOrder() bool { + return false +} + // Create a new consumer. func (f *Factory) Create() (async.Consumer, error) { diff --git a/component/async/kafka/group/group_test.go b/component/async/kafka/group/group_test.go index efe869078..2568a6a63 100644 --- a/component/async/kafka/group/group_test.go +++ b/component/async/kafka/group/group_test.go @@ -7,11 +7,13 @@ import ( "time" "github.com/Shopify/sarama" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/beatlabs/patron/component/async" "github.com/beatlabs/patron/component/async/kafka" "github.com/beatlabs/patron/encoding" "github.com/beatlabs/patron/encoding/json" - "github.com/stretchr/testify/assert" ) func TestNew(t *testing.T) { @@ -121,12 +123,13 @@ func TestFactory_Create(t *testing.T) { assert.Nil(t, got) } else { assert.NoError(t, err) - assert.NotNil(t, got) + require.NotNil(t, got) consumer, ok := got.(*consumer) assert.True(t, ok, "consumer is not of type group.consumer") assert.Equal(t, tt.fields.brokers, consumer.config.Brokers) assert.Equal(t, tt.fields.topics, consumer.topics) assert.True(t, strings.HasSuffix(consumer.config.SaramaConfig.ClientID, tt.fields.clientName)) + assert.False(t, got.OutOfOrder()) } }) } @@ -154,12 +157,12 @@ type mockConsumerSession struct{} func (m *mockConsumerSession) Claims() map[string][]int32 { return nil } func (m *mockConsumerSession) MemberID() string { return "" } func (m *mockConsumerSession) GenerationID() int32 { return 0 } -func (m *mockConsumerSession) MarkOffset(topic string, partition int32, offset int64, metadata string) { +func (m *mockConsumerSession) MarkOffset(string, int32, int64, string) { } -func (m *mockConsumerSession) ResetOffset(topic string, partition int32, offset int64, metadata string) { +func (m *mockConsumerSession) ResetOffset(string, int32, int64, string) { } -func (m *mockConsumerSession) MarkMessage(msg *sarama.ConsumerMessage, metadata string) {} -func (m *mockConsumerSession) Context() context.Context { return context.Background() } +func (m *mockConsumerSession) MarkMessage(*sarama.ConsumerMessage, string) {} +func (m *mockConsumerSession) Context() context.Context { return context.Background() } func TestHandler_ConsumeClaim(t *testing.T) { diff --git a/component/async/kafka/simple/simple.go b/component/async/kafka/simple/simple.go index eaff4aad9..cf66ae099 100644 --- a/component/async/kafka/simple/simple.go +++ b/component/async/kafka/simple/simple.go @@ -40,6 +40,10 @@ func New(name, topic string, brokers []string, oo ...kafka.OptionFunc) (*Factory return &Factory{name: name, topic: topic, brokers: brokers, oo: oo}, nil } +func (c *consumer) OutOfOrder() bool { + return false +} + // Create a new consumer. func (f *Factory) Create() (async.Consumer, error) { diff --git a/component/async/kafka/simple/simple_test.go b/component/async/kafka/simple/simple_test.go index 7fb1e8695..477c5311b 100644 --- a/component/async/kafka/simple/simple_test.go +++ b/component/async/kafka/simple/simple_test.go @@ -5,6 +5,7 @@ import ( "github.com/beatlabs/patron/component/async/kafka" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNew(t *testing.T) { @@ -91,7 +92,8 @@ func TestFactory_Create(t *testing.T) { assert.Nil(t, got) } else { assert.NoError(t, err) - assert.NotNil(t, got) + require.NotNil(t, got) + assert.False(t, got.OutOfOrder()) } }) } diff --git a/component/async/sqs/option.go b/component/async/sqs/option.go index e0650117f..acc0d74dc 100644 --- a/component/async/sqs/option.go +++ b/component/async/sqs/option.go @@ -13,53 +13,49 @@ type OptionFunc func(*Factory) error // MaxMessages option for setting the max number of messages fetched. // Allowed values are between 1 and 10. +// If this option is unused, it defaults to 3. +// If messages can be processed very quickly, maxing out this value is fine, otherwise having a high value is risky as it might trigger the visibility timeout. +// Having a value too small isn't recommended either, as it increases the number of SQS API requests, thus AWS costs. func MaxMessages(maxMessages int64) OptionFunc { return func(f *Factory) error { if maxMessages <= 0 || maxMessages > 10 { return errors.New("max messages should be between 1 and 10") } - f.maxMessages = maxMessages + f.maxMessages = &maxMessages return nil } } // PollWaitSeconds sets the wait time for the long polling mechanism in seconds. // Allowed values are between 0 and 20. 0 enables short polling. +// If this option is unused, it defaults to the queue's default poll settings. func PollWaitSeconds(pollWaitSeconds int64) OptionFunc { return func(f *Factory) error { if pollWaitSeconds < 0 || pollWaitSeconds > 20 { return errors.New("poll wait seconds should be between 0 and 20") } - f.pollWaitSeconds = pollWaitSeconds + f.pollWaitSeconds = &pollWaitSeconds return nil } } // VisibilityTimeout sets the time a message is invisible after it has been requested. +// This is a built in resiliency mechanism so that, should the consumer fail to acknowledge the message within such timeout, +// it will become visible again and thus available for retries. // Allowed values are between 0 and and 12 hours in seconds. +// If this option is unused, it defaults to the queue's default visibility settings. func VisibilityTimeout(visibilityTimeout int64) OptionFunc { return func(f *Factory) error { if visibilityTimeout < 0 || visibilityTimeout > twelveHoursInSeconds { return fmt.Errorf("visibility timeout should be between 0 and %d seconds", twelveHoursInSeconds) } - f.visibilityTimeout = visibilityTimeout - return nil - } -} - -// Buffer sets the concurrency of the messages processing. -// 0 means wait for the previous messages to be processed. -func Buffer(buffer int) OptionFunc { - return func(f *Factory) error { - if buffer < 0 { - return errors.New("buffer should be greater or equal to zero") - } - f.buffer = buffer + f.visibilityTimeout = &visibilityTimeout return nil } } // QueueStatsInterval sets the interval at which we retrieve queue stats. +// If this option is unused, it defaults to 20 seconds. func QueueStatsInterval(interval time.Duration) OptionFunc { return func(f *Factory) error { if interval == 0 { diff --git a/component/async/sqs/option_test.go b/component/async/sqs/option_test.go index 17035e213..a27aa4cd1 100644 --- a/component/async/sqs/option_test.go +++ b/component/async/sqs/option_test.go @@ -4,27 +4,28 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMaxMessages(t *testing.T) { type args struct { - maxMessages int64 + maxMessages *int64 } tests := map[string]struct { args args expectedErr string }{ "success": { - args: args{maxMessages: 5}, + args: args{maxMessages: aws.Int64(5)}, }, "zero message size": { - args: args{maxMessages: 0}, + args: args{maxMessages: aws.Int64(0)}, expectedErr: "max messages should be between 1 and 10", }, "over max message size": { - args: args{maxMessages: 11}, + args: args{maxMessages: aws.Int64(11)}, expectedErr: "max messages should be between 1 and 10", }, } @@ -32,7 +33,7 @@ func TestMaxMessages(t *testing.T) { t.Run(name, func(t *testing.T) { f, err := NewFactory(&stubQueue{}, "queue") require.NoError(t, err) - err = MaxMessages(tt.args.maxMessages)(f) + err = MaxMessages(*tt.args.maxMessages)(f) if tt.expectedErr != "" { assert.EqualError(t, err, tt.expectedErr) } else { @@ -45,21 +46,21 @@ func TestMaxMessages(t *testing.T) { func TestPollWaitSeconds(t *testing.T) { type args struct { - waitSeconds int64 + waitSeconds *int64 } tests := map[string]struct { args args expectedErr string }{ "success": { - args: args{waitSeconds: 5}, + args: args{waitSeconds: aws.Int64(5)}, }, "negative message size": { - args: args{waitSeconds: -1}, + args: args{waitSeconds: aws.Int64(-1)}, expectedErr: "poll wait seconds should be between 0 and 20", }, "over max wait seconds": { - args: args{waitSeconds: 21}, + args: args{waitSeconds: aws.Int64(21)}, expectedErr: "poll wait seconds should be between 0 and 20", }, } @@ -67,7 +68,7 @@ func TestPollWaitSeconds(t *testing.T) { t.Run(name, func(t *testing.T) { f, err := NewFactory(&stubQueue{}, "queue") require.NoError(t, err) - err = PollWaitSeconds(tt.args.waitSeconds)(f) + err = PollWaitSeconds(*tt.args.waitSeconds)(f) if tt.expectedErr != "" { assert.EqualError(t, err, tt.expectedErr) } else { @@ -80,21 +81,21 @@ func TestPollWaitSeconds(t *testing.T) { func TestVisibilityTimeout(t *testing.T) { type args struct { - timeout int64 + timeout *int64 } tests := map[string]struct { args args expectedErr string }{ "success": { - args: args{timeout: 5}, + args: args{timeout: aws.Int64(5)}, }, "negative message size": { - args: args{timeout: -1}, + args: args{timeout: aws.Int64(-1)}, expectedErr: "visibility timeout should be between 0 and 43200 seconds", }, "over max wait seconds": { - args: args{timeout: twelveHoursInSeconds + 1}, + args: args{timeout: aws.Int64(twelveHoursInSeconds + 1)}, expectedErr: "visibility timeout should be between 0 and 43200 seconds", }, } @@ -102,7 +103,7 @@ func TestVisibilityTimeout(t *testing.T) { t.Run(name, func(t *testing.T) { f, err := NewFactory(&stubQueue{}, "queue") require.NoError(t, err) - err = VisibilityTimeout(tt.args.timeout)(f) + err = VisibilityTimeout(*tt.args.timeout)(f) if tt.expectedErr != "" { assert.EqualError(t, err, tt.expectedErr) } else { @@ -113,37 +114,6 @@ func TestVisibilityTimeout(t *testing.T) { } } -func TestBuffer(t *testing.T) { - type args struct { - buffer int - } - tests := map[string]struct { - args args - expectedErr string - }{ - "success": { - args: args{buffer: 5}, - }, - "negative message size": { - args: args{buffer: -1}, - expectedErr: "buffer should be greater or equal to zero", - }, - } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - f, err := NewFactory(&stubQueue{}, "queue") - require.NoError(t, err) - err = Buffer(tt.args.buffer)(f) - if tt.expectedErr != "" { - assert.EqualError(t, err, tt.expectedErr) - } else { - assert.NoError(t, err) - assert.Equal(t, f.buffer, tt.args.buffer) - } - }) - } -} - func TestQueueStatsInterval(t *testing.T) { type args struct { interval time.Duration diff --git a/component/async/sqs/sqs.go b/component/async/sqs/sqs.go index 958612e00..4dc522c07 100644 --- a/component/async/sqs/sqs.go +++ b/component/async/sqs/sqs.go @@ -135,9 +135,9 @@ type Factory struct { queueName string queue sqsiface.SQSAPI queueURL string - maxMessages int64 - pollWaitSeconds int64 - visibilityTimeout int64 + maxMessages *int64 + pollWaitSeconds *int64 + visibilityTimeout *int64 buffer int statsInterval time.Duration } @@ -160,14 +160,12 @@ func NewFactory(queue sqsiface.SQSAPI, queueName string, oo ...OptionFunc) (*Fac } f := &Factory{ - queueName: queueName, - queueURL: *url.QueueUrl, - queue: queue, - maxMessages: 10, - pollWaitSeconds: 20, - visibilityTimeout: 30, - buffer: 0, - statsInterval: 10 * time.Second, + queueName: queueName, + queueURL: *url.QueueUrl, + queue: queue, + maxMessages: aws.Int64(3), + buffer: 0, + statsInterval: 10 * time.Second, } for _, o := range oo { @@ -188,7 +186,6 @@ func (f *Factory) Create() (async.Consumer, error) { queueURL: f.queueURL, maxMessages: f.maxMessages, pollWaitSeconds: f.pollWaitSeconds, - buffer: f.buffer, visibilityTimeout: f.visibilityTimeout, statsInterval: f.statsInterval, }, nil @@ -198,18 +195,21 @@ type consumer struct { queueName string queueURL string queue sqsiface.SQSAPI - maxMessages int64 - pollWaitSeconds int64 - visibilityTimeout int64 - buffer int + maxMessages *int64 + pollWaitSeconds *int64 + visibilityTimeout *int64 statsInterval time.Duration cnl context.CancelFunc } +func (c *consumer) OutOfOrder() bool { + return true +} + // Consume messages from SQS and send them to the channel. func (c *consumer) Consume(ctx context.Context) (<-chan async.Message, <-chan error, error) { - chMsg := make(chan async.Message, c.buffer) - chErr := make(chan error, c.buffer) + chMsg := make(chan async.Message) + chErr := make(chan error) sqsCtx, cnl := context.WithCancel(ctx) c.cnl = cnl @@ -218,12 +218,12 @@ func (c *consumer) Consume(ctx context.Context) (<-chan async.Message, <-chan er if sqsCtx.Err() != nil { return } - log.Debugf("polling SQS queue %s for messages", c.queueName) + log.Debugf("consume: polling SQS queue %s for %d messages", c.queueName, *c.maxMessages) output, err := c.queue.ReceiveMessageWithContext(sqsCtx, &sqs.ReceiveMessageInput{ - QueueUrl: aws.String(c.queueURL), - MaxNumberOfMessages: aws.Int64(c.maxMessages), - WaitTimeSeconds: aws.Int64(c.pollWaitSeconds), - VisibilityTimeout: aws.Int64(c.visibilityTimeout), + QueueUrl: &c.queueURL, + MaxNumberOfMessages: c.maxMessages, + WaitTimeSeconds: c.pollWaitSeconds, + VisibilityTimeout: c.visibilityTimeout, AttributeNames: aws.StringSlice([]string{ sqsAttributeSentTimestamp, }), @@ -239,6 +239,7 @@ func (c *consumer) Consume(ctx context.Context) (<-chan async.Message, <-chan er return } + log.Debugf("Consume: received %d messages", len(output.Messages)) messageCountInc(c.queueName, fetchedMessageState, len(output.Messages)) for _, msg := range output.Messages { diff --git a/component/async/sqs/sqs_test.go b/component/async/sqs/sqs_test.go index 91f6b9dd9..f461a4a19 100644 --- a/component/async/sqs/sqs_test.go +++ b/component/async/sqs/sqs_test.go @@ -92,12 +92,12 @@ func TestFactory_Create(t *testing.T) { assert.NotNil(t, cons.queue) assert.Equal(t, "queueName", cons.queueName) assert.Equal(t, "URL", cons.queueURL) - assert.Equal(t, int64(10), cons.maxMessages) - assert.Equal(t, int64(20), cons.pollWaitSeconds) - assert.Equal(t, int64(30), cons.visibilityTimeout) - assert.Equal(t, 0, cons.buffer) + assert.Equal(t, aws.Int64(3), cons.maxMessages) + assert.Nil(t, cons.pollWaitSeconds) + assert.Nil(t, cons.visibilityTimeout) assert.Equal(t, 10*time.Second, cons.statsInterval) assert.Nil(t, cons.cnl) + assert.True(t, cons.OutOfOrder()) } func Test_consumer_Consume(t *testing.T) { diff --git a/examples/sqs-simple/main.go b/examples/sqs-simple/main.go new file mode 100644 index 000000000..7196d2617 --- /dev/null +++ b/examples/sqs-simple/main.go @@ -0,0 +1,108 @@ +package main + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/beatlabs/patron" + "github.com/beatlabs/patron/component/async" + patronsqs "github.com/beatlabs/patron/component/async/sqs" + "github.com/beatlabs/patron/log" +) + +type sqsConfig struct { + endpoint string + name string + region string +} + +// Make sure localstack is running locally, or point to actual queue on AWS +var sampleConfig = sqsConfig{ + endpoint: "http://localhost:4566", + name: "sandbox-payin", + region: "eu-west-1", +} + +func init() { + err := os.Setenv("PATRON_LOG_LEVEL", "debug") + if err != nil { + fmt.Printf("failed to set log level env var: %v", err) + os.Exit(1) + } +} + +func main() { + name := "sqs" + version := "1.0.0" + + service, err := patron.New(name, version) + if err != nil { + fmt.Printf("failed to set up service: %v", err) + os.Exit(1) + } + ctx := context.Background() + + sqsComponent, err := sampleSqs() + if err != nil { + log.Fatalf("failed to create sqs component: %v", err) + } + + err = service.WithComponents(sqsComponent).Run(ctx) + if err != nil { + log.Fatalf("failed to create and run service: %v", err) + } +} + +func sampleSqs() (*async.Component, error) { + sess, err := session.NewSession(&aws.Config{ + Endpoint: &sampleConfig.endpoint, + Region: &sampleConfig.region, + }) + if err != nil { + return nil, err + } + sqsClient := sqs.New(sess) + + factory, err := patronsqs.NewFactory( + sqsClient, + sampleConfig.name, + // Optionally override the queue's default polling setting. + // Long polling is highly recommended to avoid large costs on AWS. + // See https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-short-and-long-polling.html + // It's probably best to not specify any value: the default value on the queue will be used. + patronsqs.PollWaitSeconds(20), + // Optionally override the queue's default visibility timeout. + // See https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-visibility-timeout.html + // Again, a sensible default should be configured on the queue, but there might be specific use case where you want to override. + patronsqs.VisibilityTimeout(30), + // Optionally change the number of messages fetched by each worker. + // The default is 3. + patronsqs.MaxMessages(5), + ) + if err != nil { + return nil, err + } + + // Note: the retry count is not increased on an error processing a message, but rather consuming from the queue. + // If the max number if retries is reached, the service will terminate. + // The max number of retires of a message is determined by the SQS queue, not the consumer. + return async.New("sqs", factory, messageHandler). + // Note that NackExitStrategy does not work with concurrency, so we need to pick either Nack or Ack Strategy + // Ack strategy is not recommended for SQS: we want failed messages to end up in the dead letter queue + WithFailureStrategy(async.NackStrategy). + WithRetries(3). + WithRetryWait(30 * time.Second). + WithConcurrency(10). + Create() +} + +func messageHandler(message async.Message) error { + log.Info("Received message, payload:", string(message.Payload())) + time.Sleep(3 * time.Second) // useful to see concurrency in action + return nil +} diff --git a/test/docker/aws/consumer_integration_test.go b/test/docker/aws/consumer_integration_test.go index 1e17694d9..9bd93c12f 100644 --- a/test/docker/aws/consumer_integration_test.go +++ b/test/docker/aws/consumer_integration_test.go @@ -11,7 +11,7 @@ import ( patronSQS "github.com/beatlabs/patron/client/sqs" sqsConsumer "github.com/beatlabs/patron/component/async/sqs" "github.com/beatlabs/patron/correlation" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/opentracing/opentracing-go/mocktracer" "github.com/stretchr/testify/assert" @@ -37,7 +37,13 @@ func Test_SQS_Consume(t *testing.T) { defer mtr.Reset() opentracing.SetGlobalTracer(mtr) - factory, err := sqsConsumer.NewFactory(api, queueName) + factory, err := sqsConsumer.NewFactory( + api, + queueName, + sqsConsumer.MaxMessages(10), + sqsConsumer.PollWaitSeconds(20), + sqsConsumer.VisibilityTimeout(30), + ) require.NoError(t, err) cns, err := factory.Create() require.NoError(t, err) @@ -70,7 +76,9 @@ func Test_SQS_Consume(t *testing.T) { } }() - assert.Equal(t, sent, <-chReceived) + received := <-chReceived + + assert.ElementsMatch(t, sent, received) assert.Len(t, mtr.FinishedSpans(), 3) for _, span := range mtr.FinishedSpans() {