Skip to content

Commit

Permalink
Merge pull request #312 from cschleiden/coroutine-defer
Browse files Browse the repository at this point in the history
Do not deadlock with deferred Yield
  • Loading branch information
cschleiden committed Jan 25, 2024
2 parents 1375fc2 + e6df955 commit 2e4b887
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 17 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ require (
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
go.tmz.dev/musttag v0.7.2 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/goleak v1.3.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
go.uber.org/zap v1.24.0 // indirect
golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,8 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
Expand Down
51 changes: 34 additions & 17 deletions internal/sync/coroutine.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package sync

import (
"errors"
"fmt"
"io"
"log"
"runtime"
"sync/atomic"
"time"
)

const DeadlockDetection = 40 * time.Second

var ErrCoroutineAlreadyFinished = errors.New("coroutine already finished")

type CoroutineCreator interface {
NewCoroutine(ctx Context, fn func(Context) error)
}
Expand Down Expand Up @@ -53,7 +54,8 @@ type coState struct {

err error

logger logger
// logger logger
// idx int

deadlockDetection time.Duration

Expand All @@ -68,6 +70,11 @@ func NewCoroutine(ctx Context, fn func(ctx Context) error) Coroutine {
defer s.finish() // Ensure we always mark the coroutine as finished
defer func() {
if r := recover(); r != nil {
if err, ok := r.(error); ok && errors.Is(err, ErrCoroutineAlreadyFinished) {
// Ignore this specific error
return
}

s.err = fmt.Errorf("panic: %v", r)
}
}()
Expand All @@ -86,21 +93,26 @@ func NewCoroutine(ctx Context, fn func(ctx Context) error) Coroutine {
func newState() *coState {
// i++

return &coState{
c := &coState{
blocking: make(chan bool, 1),
unblock: make(chan bool),
// Only used while debugging issues, default to discarding log messages
logger: log.New(io.Discard, "[co]", log.LstdFlags),
// logger: log.New(os.Stderr, fmt.Sprintf("[co %v]", i), log.Lmsgprefix|log.Ltime),
// idx: i,
deadlockDetection: DeadlockDetection,
}

// Start out as blocked
c.blocked.Store(true)

return c
}

func (s *coState) finish() {
s.finished.Store(true)
s.blocking <- true

s.logger.Println("finish")
// s.logger.Println("finish")
}

func (s *coState) SetCoroutineCreator(creator CoroutineCreator) {
Expand Down Expand Up @@ -136,23 +148,28 @@ func (s *coState) Yield() {
}

func (s *coState) yield(markBlocking bool) {
s.logger.Println("yielding")

s.blocked.Store(true)
// s.logger.Println("yielding")

if markBlocking {
if s.shouldExit.Load() != nil {
// s.logger.Println("yielding, but should exit")
panic(ErrCoroutineAlreadyFinished)
}

s.blocked.Store(true)

s.blocking <- true
}

s.logger.Println("yielded")
// s.logger.Println("yielded")

// Wait for the next Execute() call
<-s.unblock

// Once we're here, another Execute() call has been made. s.blocking is empty

if s.shouldExit.Load() != nil {
s.logger.Println("exiting")
// s.logger.Println("exiting")

// Goexit runs all deferred functions, which includes calling finish() in the main
// execution function. That marks the coroutine as finished and blocking.
Expand All @@ -161,37 +178,37 @@ func (s *coState) yield(markBlocking bool) {

s.blocked.Store(false)

s.logger.Println("done yielding, continuing")
// s.logger.Println("done yielding, continuing")
}

func (s *coState) Execute() {
s.ResetProgress()

if s.Finished() {
s.logger.Println("execute: already finished")
// s.logger.Println("execute: already finished")
return
}

t := time.NewTimer(s.deadlockDetection)
defer t.Stop()

s.logger.Println("execute: unblocking")
// s.logger.Println("execute: unblocking")
s.unblock <- true
s.logger.Println("execute: unblocked")
// s.logger.Println("execute: unblocked")

runtime.Gosched()

// Run until blocked (which is also true when finished)
select {
case <-s.blocking:
s.logger.Println("execute: blocked")
// s.logger.Println("execute: blocked")
case <-t.C:
panic("coroutine timed out")
}
}

func (s *coState) Exit() {
s.logger.Println("exit")
// s.logger.Println("exit")

if s.Finished() {
return
Expand Down
45 changes: 45 additions & 0 deletions internal/sync/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,24 @@ func Test_Coroutine_ContinueAndBlock(t *testing.T) {
require.True(t, reached)
}

func Test_Coroutine_Exit_Before_Yield(t *testing.T) {
c := NewCoroutine(Background(), func(ctx Context) error {
s := getCoState(ctx)

s.Yield()

require.FailNow(t, "should not reach this")

return nil
})

r := runtime.NumGoroutine()
c.Exit()

require.True(t, c.Finished())
require.Equal(t, r-1, runtime.NumGoroutine())
}

func Test_Coroutine_Exit(t *testing.T) {
c := NewCoroutine(Background(), func(ctx Context) error {
s := getCoState(ctx)
Expand All @@ -120,10 +138,37 @@ func Test_Coroutine_Exit(t *testing.T) {
return nil
})

c.Execute()

r := runtime.NumGoroutine()
c.Exit()

require.True(t, c.Finished())
require.Equal(t, r-1, runtime.NumGoroutine())
}

func Test_Coroutine_Exit_with_defer(t *testing.T) {
c := NewCoroutine(Background(), func(ctx Context) error {
s := getCoState(ctx)

defer func() {
s.Yield()
}()

s.Yield()

require.FailNow(t, "should not reach this")

return nil
})

c.Execute()

r := runtime.NumGoroutine()
c.Exit()

require.True(t, c.Finished())
require.NoError(t, c.Error())
require.Equal(t, r-1, runtime.NumGoroutine())
}

Expand Down
34 changes: 34 additions & 0 deletions internal/workflow/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"go.uber.org/goleak"
)

type testHistoryProvider struct {
Expand Down Expand Up @@ -638,6 +639,37 @@ func Test_Executor(t *testing.T) {
require.Equal(t, goRoutines, runtime.NumGoroutine())
},
},
{
name: "Close_removes_any_goroutines_defer",
f: func(t *testing.T, r *registry.Registry, e *executor, i *core.WorkflowInstance, hp *testHistoryProvider) {
wf := func(ctx sync.Context) error {
defer func() {
_, err := wf.SignalWorkflow[any](ctx, "some-id", "signal", nil).Get(ctx)
if err != nil {
panic(err)
}
}()

c := wf.NewSignalChannel[int](ctx, "signal")

// Block workflow
c.Receive(ctx)

return nil
}

r.RegisterWorkflow(wf)

task := startWorkflowTask(i.InstanceID, wf)

_, err := e.ExecuteTask(context.Background(), task)
require.NoError(t, err)

e.Close()

goleak.VerifyNone(t)
},
},
{
name: "Close_removes_any_goroutines_nested",
f: func(t *testing.T, r *registry.Registry, e *executor, i *core.WorkflowInstance, hp *testHistoryProvider) {
Expand Down Expand Up @@ -683,6 +715,8 @@ func Test_Executor(t *testing.T) {
e, err := newExecutor(r, i, hp)
require.NoError(t, err)
tt.f(t, r, e, i, hp)

e.Close()
})
}
}
Expand Down

0 comments on commit 2e4b887

Please sign in to comment.