Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: timeout and shutdown #121

Merged
merged 2 commits into from Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 22 additions & 34 deletions main.go
Expand Up @@ -3,8 +3,10 @@ package main
import (
"bufio"
"context"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/url"
Expand Down Expand Up @@ -135,8 +137,9 @@ func main() {

deathNote := sync.Map{}

connectionAccepted := make(chan net.Addr)
connectionLost := make(chan net.Addr)
// Buffered so we don't block the main process.
connectionAccepted := make(chan net.Addr, 10)
connectionLost := make(chan net.Addr, 10)

go processRequests(&deathNote, connectionAccepted, connectionLost)

Expand Down Expand Up @@ -206,8 +209,10 @@ func processRequests(deathNote *sync.Map, connectionAccepted chan<- net.Addr, co
}

if err != nil {
log.Println(err)
break
if !errors.Is(err, io.EOF) {
log.Println(err)
}
return
}
}
}(conn)
Expand All @@ -216,45 +221,28 @@ func processRequests(deathNote *sync.Map, connectionAccepted chan<- net.Addr, co

func waitForPruneCondition(ctx context.Context, connectionAccepted <-chan net.Addr, connectionLost <-chan net.Addr) {
connectionCount := 0
never := make(chan time.Time, 1)
defer close(never)

handleConnectionAccepted := func(addr net.Addr) {
log.Printf("New client connected: %s", addr)
connectionCount++
}

select {
case <-time.After(connectionTimeout):
panic("Timed out waiting for the first connection")
case addr := <-connectionAccepted:
handleConnectionAccepted(addr)
case <-ctx.Done():
log.Println("Signal received")
return
}

timer := time.NewTimer(connectionTimeout)
for {
var noConnectionTimeout <-chan time.Time
if connectionCount == 0 {
noConnectionTimeout = time.After(reconnectionTimeout)
} else {
noConnectionTimeout = never
}

select {
case addr := <-connectionAccepted:
handleConnectionAccepted(addr)
break
log.Printf("New client connected: %s", addr)
connectionCount++
if connectionCount == 1 {
if !timer.Stop() {
<-timer.C
}
}
case addr := <-connectionLost:
log.Printf("Client disconnected: %s", addr.String())
connectionCount--
break
if connectionCount == 0 {
timer.Reset(reconnectionTimeout)
}
case <-ctx.Done():
log.Println("Signal received")
return
case <-noConnectionTimeout:
log.Println("Timed out waiting for re-connection")
case <-timer.C:
log.Println("Timeout waiting for connection")
return
}
}
Expand Down
21 changes: 11 additions & 10 deletions main_test.go
Expand Up @@ -6,10 +6,10 @@ import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -64,26 +64,27 @@ func TestInitialTimeout(t *testing.T) {
// reset connectionTimeout
connectionTimeout = testConnectionTimeout

origWriter := log.Default().Writer()
defer func() {
log.SetOutput(origWriter)
}()
var buf bytes.Buffer
log.SetOutput(&buf)

acc := make(chan net.Addr)
lost := make(chan net.Addr)

done := make(chan string)

go func() {
defer func() {
err := recover().(string)
done <- err
}()
waitForPruneCondition(context.Background(), acc, lost)
done <- buf.String()
}()

select {
case p := <-done:
if !strings.Contains(p, "first connection") {
t.Fail()
}
require.Contains(t, p, "Timeout waiting for connection")
case <-time.After(7 * time.Second):
t.Fail()
t.Fatal("Timeout waiting prune condition")
}
}

Expand Down