Skip to content

Commit

Permalink
Add support to disable read batch
Browse files Browse the repository at this point in the history
Add support to disable read batch
  • Loading branch information
cnderrauber committed Aug 31, 2023
1 parent 46a4b2c commit a788a32
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 20 deletions.
17 changes: 13 additions & 4 deletions udp/batchconn.go
Expand Up @@ -4,6 +4,7 @@
package udp

import (
"io"
"net"
"runtime"
"sync"
Expand All @@ -28,6 +29,7 @@ type BatchReader interface {
type BatchPacketConn interface {
BatchWriter
BatchReader
io.Closer
}

// BatchConn uses ipv4/v6.NewPacketConn to wrap a net.PacketConn to write/read messages in batch,
Expand All @@ -48,7 +50,7 @@ type BatchConn struct {
closed atomic.Bool
}

// NewBatchConn creates a *BatchCon from net.PacketConn with batch configs.
// NewBatchConn creates a *BatchConn from net.PacketConn with batch configs.
func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval time.Duration) *BatchConn {
bc := &BatchConn{
PacketConn: conn,
Expand Down Expand Up @@ -92,6 +94,14 @@ func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval ti
// Close batchConn and the underlying PacketConn
func (c *BatchConn) Close() error {
c.closed.Store(true)
c.batchWriteMutex.Lock()
if c.batchWritePos > 0 {
_ = c.flush()
}
c.batchWriteMutex.Unlock()
if c.batchConn != nil {
return c.batchConn.Close()
}
return c.PacketConn.Close()
}

Expand All @@ -100,15 +110,14 @@ func (c *BatchConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if c.batchConn == nil {
return c.PacketConn.WriteTo(b, addr)
}
return c.writeBatch(b, addr)
return c.enqueueMessage(b, addr)
}

func (c *BatchConn) writeBatch(buf []byte, raddr net.Addr) (int, error) {
func (c *BatchConn) enqueueMessage(buf []byte, raddr net.Addr) (int, error) {
var err error
c.batchWriteMutex.Lock()
defer c.batchWriteMutex.Unlock()

// c.writeCounter++
msg := &c.batchWriteMessages[c.batchWritePos]
// reset buffers
msg.Buffers = msg.Buffers[:1]
Expand Down
7 changes: 4 additions & 3 deletions udp/conn.go
Expand Up @@ -118,7 +118,8 @@ func (l *listener) Addr() net.Addr {
// it will use ReadBatch/WriteBatch to improve throughput for UDP.
type BatchIOConfig struct {
Enable bool
// ReadBatchSize indicates the maximum number of packets to be read in one batch
// ReadBatchSize indicates the maximum number of packets to be read in one batch, a batch size less than 2 means
// disable read batch.
ReadBatchSize int
// WriteBatchSize indicates the maximum number of packets to be written in one batch
WriteBatchSize int
Expand Down Expand Up @@ -158,7 +159,7 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener
lc.Backlog = defaultListenBacklog
}

if lc.Batch.Enable && (lc.Batch.ReadBatchSize <= 0 || lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) {
if lc.Batch.Enable && (lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) {
return nil, ErrInvalidBatchConfig
}

Expand Down Expand Up @@ -218,7 +219,7 @@ func (l *listener) readLoop() {
defer l.readWG.Done()
defer close(l.readDoneCh)

if br, ok := l.pConn.(BatchReader); ok {
if br, ok := l.pConn.(BatchReader); ok && l.readBatchSize > 1 {
l.readBatch(br)
} else {
l.read()
Expand Down
53 changes: 40 additions & 13 deletions udp/conn_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -488,6 +489,8 @@ func TestBatchIO(t *testing.T) {
WriteBatchSize: 3,
WriteBatchInterval: 5 * time.Millisecond,
},
ReadBufferSize: 64 * 1024,
WriteBufferSize: 64 * 1024,
}

laddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 15678}
Expand All @@ -496,30 +499,54 @@ func TestBatchIO(t *testing.T) {
t.Fatal(err)
}

acceptQc := make(chan struct{})
var serverConnWg sync.WaitGroup
serverConnWg.Add(1)
go func() {
defer close(acceptQc)
var exit atomic.Bool
defer func() {
defer serverConnWg.Done()
exit.Store(true)
}()
for {
buf := make([]byte, 1400)
conn, err := listener.Accept()
if errors.Is(err, ErrClosedListener) {
conn, lerr := listener.Accept()
if errors.Is(lerr, ErrClosedListener) {
break
}
assert.NoError(t, err)
assert.NoError(t, lerr)
serverConnWg.Add(1)
go func() {
defer func() { _ = conn.Close() }()
for {
n, err := conn.Read(buf)
assert.NoError(t, err)
_, err = conn.Write(buf[:n])
assert.NoError(t, err)
defer func() {
_ = conn.Close()
serverConnWg.Done()
}()
for !exit.Load() {
_ = conn.SetReadDeadline(time.Now().Add(time.Second))
n, rerr := conn.Read(buf)
if rerr != nil {
assert.ErrorContains(t, rerr, "timeout")
} else {
_, rerr = conn.Write(buf[:n])
assert.NoError(t, rerr)
}
}
}()
}
}()

raddr, _ := listener.Addr().(*net.UDPAddr)

// test flush by WriteBatchInterval expired
readBuf := make([]byte, 1400)
cli, err := net.DialUDP("udp", nil, raddr)
assert.NoError(t, err)
flushStr := "flushbytimer"
_, err = cli.Write([]byte("flushbytimer"))
assert.NoError(t, err)
n, err := cli.Read(readBuf)
assert.NoError(t, err)
assert.Equal(t, flushStr, string(readBuf[:n]))

wgs := sync.WaitGroup{}
cc := 3
wgs.Add(cc)
Expand All @@ -532,7 +559,7 @@ func TestBatchIO(t *testing.T) {
client, err := net.DialUDP("udp", nil, raddr)
assert.NoError(t, err)
defer func() { _ = client.Close() }()
for i := 0; i < 100; i++ {
for i := 0; i < 1; i++ {
_, err := client.Write([]byte(sendStr))
assert.NoError(t, err)
err = client.SetReadDeadline(time.Now().Add(time.Second))
Expand All @@ -546,5 +573,5 @@ func TestBatchIO(t *testing.T) {
wgs.Wait()

_ = listener.Close()
<-acceptQc
serverConnWg.Wait()
}

0 comments on commit a788a32

Please sign in to comment.