Skip to content

Commit

Permalink
Make our own low-level driver for testing (#4193)
Browse files Browse the repository at this point in the history
Closes #4146.
  • Loading branch information
noisersup committed Mar 22, 2024
1 parent 87e7082 commit 445c9b3
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .golangci.yml
Expand Up @@ -36,6 +36,7 @@ linters-settings:
files:
- $all
- "!**/internal/bson/*_test.go"
- "!**/internal/driver/*.go"
- "!**/internal/util/testutil/*.go"
- "!**/internal/wire/*.go"
deny:
Expand All @@ -44,6 +45,7 @@ linters-settings:
files:
- $all
- "!**/internal/bson/*.go"
- "!**/internal/driver/*.go"
deny:
- pkg: github.com/cristalhq/bson
- pkg: github.com/cristalhq/bson/bsonproto
Expand Down
73 changes: 69 additions & 4 deletions internal/driver/driver.go
Expand Up @@ -22,6 +22,7 @@ import (
"log/slog"
"net"
"net/url"
"sync/atomic"

"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
"github.com/FerretDB/FerretDB/internal/wire"
Expand Down Expand Up @@ -120,7 +121,7 @@ func (c *Conn) Read() (*wire.MsgHeader, wire.MsgBody, error) {
c.l.Debug(
fmt.Sprintf("<<<\n%s", body.String()),
slog.Int("length", int(header.MessageLength)),
slog.Int("id", int(header.ResponseTo)),
slog.Int("id", int(header.RequestID)),
slog.Int("response_to", int(header.ResponseTo)),
slog.String("opcode", header.OpCode.String()),
)
Expand All @@ -133,7 +134,7 @@ func (c *Conn) Write(header *wire.MsgHeader, body wire.MsgBody) error {
c.l.Debug(
fmt.Sprintf(">>>\n%s", body.String()),
slog.Int("length", int(header.MessageLength)),
slog.Int("id", int(header.ResponseTo)),
slog.Int("id", int(header.RequestID)),
slog.Int("response_to", int(header.ResponseTo)),
slog.String("opcode", header.OpCode.String()),
)
Expand Down Expand Up @@ -164,8 +165,72 @@ func (c *Conn) WriteRaw(b []byte) error {
return nil
}

// lastRequestID stores incremented value of last recorded request header ID.
var lastRequestID atomic.Int32

// Request sends the given request to the connection and returns the response.
// If header MessageLength or RequestID is not specified, it assings the proper values.
// For header.OpCode the wire.OpCodeMsg is used as default.
//
// It returns errors only for request/response parsing issues, or connection issues.
// All of the driver level errors are stored inside response.
func (c *Conn) Request(ctx context.Context, header *wire.MsgHeader, body wire.MsgBody) (*wire.MsgHeader, wire.MsgBody, error) {
// TODO https://github.com/FerretDB/FerretDB/issues/4146
panic("not implemented")
if header.MessageLength == 0 {
msgBin, err := body.MarshalBinary()
if err != nil {
return nil, nil, lazyerrors.Error(err)
}

header.MessageLength = int32(len(msgBin) + wire.MsgHeaderLen)
}

if header.OpCode == 0 {
header.OpCode = wire.OpCodeMsg
}

if header.RequestID == 0 {
header.RequestID = lastRequestID.Add(1)
}

if header.ResponseTo != 0 {
return nil, nil, lazyerrors.Errorf("setting response_to is not allowed")
}

if m, ok := body.(*wire.OpMsg); ok {
if m.Flags != 0 {
return nil, nil, lazyerrors.Errorf("unsupported request flags %s", m.Flags)
}
}

if err := c.Write(header, body); err != nil {
return nil, nil, lazyerrors.Error(err)
}

resHeader, resBody, err := c.Read()
if err != nil {
return nil, nil, lazyerrors.Error(err)
}

if resHeader.ResponseTo != header.RequestID {
c.l.Error(
"response_to is not equal to request_id",
slog.Int("request_id", int(header.RequestID)),
slog.Int("response_id", int(resHeader.RequestID)),
slog.Int("response_to", int(resHeader.ResponseTo)),
)

return nil, nil, lazyerrors.Errorf(
"response_to is not equal to request_id (response_to=%d; expected=%d)",
resHeader.ResponseTo,
header.RequestID,
)
}

if m, ok := resBody.(*wire.OpMsg); ok {
if m.Flags != 0 {
return nil, nil, lazyerrors.Errorf("unsupported response flags %s", m.Flags)
}
}

return resHeader, resBody, nil
}
161 changes: 159 additions & 2 deletions internal/driver/driver_test.go
Expand Up @@ -17,9 +17,15 @@ package driver
import (
"testing"

"github.com/cristalhq/bson/bsonproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/FerretDB/FerretDB/internal/bson"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/must"
"github.com/FerretDB/FerretDB/internal/util/testutil"
"github.com/FerretDB/FerretDB/internal/wire"
)

func TestDriver(t *testing.T) {
Expand All @@ -33,6 +39,157 @@ func TestDriver(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, c.Close()) })

// TODO https://github.com/FerretDB/FerretDB/issues/4146
_ = c
dbName := testutil.DatabaseName(t)

lsid := bson.Binary{
B: []byte{
0xa3, 0x19, 0xf2, 0xb4, 0xa1, 0x75, 0x40, 0xc7,
0xb8, 0xe7, 0xa3, 0xa3, 0x2e, 0xc2, 0x56, 0xbe,
},
Subtype: bsonproto.BinaryUUID,
}

expectedBatches := make([]bson.Array, 3)

err = expectedBatches[0].Add(must.NotFail(bson.NewDocument("_id", int32(0), "w", int32(2), "v", int32(1))))
require.NoError(t, err)
err = expectedBatches[1].Add(must.NotFail(bson.NewDocument("_id", int32(1), "v", int32(2))))
require.NoError(t, err)
err = expectedBatches[2].Add(must.NotFail(bson.NewDocument("_id", int32(2), "v", int32(3))))
require.NoError(t, err)

t.Run("Drop", func(t *testing.T) {
dropCmd := must.NotFail(bson.NewDocument(
"dropDatabase", int32(1),
"lsid", must.NotFail(bson.NewDocument("id", lsid)),
"$db", dbName,
))

body, err := wire.NewOpMsg(must.NotFail(dropCmd.Encode()))
require.NoError(t, err)

resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body)
require.NoError(t, err)
assert.NotZero(t, resHeader.RequestID)

resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode()
require.NoError(t, err)

ok := resMsg.Get("ok").(float64)
require.Equal(t, float64(1), ok)
})

t.Run("Insert", func(t *testing.T) {
insertCmd := must.NotFail(bson.NewDocument(
"insert", "values",
"documents", must.NotFail(bson.NewArray(
must.NotFail(bson.NewDocument("w", int32(2), "v", int32(1), "_id", int32(0))),
must.NotFail(bson.NewDocument("v", int32(2), "_id", int32(1))),
must.NotFail(bson.NewDocument("v", int32(3), "_id", int32(2))),
)),
"ordered", true,
"lsid", must.NotFail(bson.NewDocument("id", lsid)),
"$db", dbName,
))

body, err := wire.NewOpMsg(must.NotFail(insertCmd.Encode()))
require.NoError(t, err)

resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body)
require.NoError(t, err)
assert.NotZero(t, resHeader.RequestID)

resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode()
require.NoError(t, err)

ok := resMsg.Get("ok").(float64)
require.Equal(t, float64(1), ok)

n := resMsg.Get("n").(int32)
require.Equal(t, int32(3), n)
})

var cursorID int64

t.Run("Find", func(t *testing.T) {
findCmd := must.NotFail(bson.NewDocument(
"find", "values",
"filter", must.NotFail(bson.NewDocument()),
"sort", must.NotFail(bson.NewDocument("_id", int32(1))),
"lsid", must.NotFail(bson.NewDocument("id", lsid)),
"batchSize", int32(1),
"$db", dbName,
))

body, err := wire.NewOpMsg(must.NotFail(findCmd.Encode()))
require.NoError(t, err)

resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body)
require.NoError(t, err)
assert.NotZero(t, resHeader.RequestID)

resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode()
require.NoError(t, err)

cursor, err := resMsg.Get("cursor").(bson.RawDocument).Decode()
require.NoError(t, err)

firstBatch := cursor.Get("firstBatch").(bson.RawArray)
cursorID = cursor.Get("id").(int64)

testutil.AssertEqual(t, must.NotFail(expectedBatches[0].Convert()), must.NotFail(firstBatch.Convert()))
require.NotZero(t, cursorID)
})

getMoreCmd := must.NotFail(bson.NewDocument(
"getMore", cursorID,
"collection", "values",
"lsid", must.NotFail(bson.NewDocument("id", lsid)),
"batchSize", int32(1),
"$db", dbName,
))

t.Run("GetMore", func(t *testing.T) {
for i := 1; i < 3; i++ {
body, err := wire.NewOpMsg(must.NotFail(getMoreCmd.Encode()))
require.NoError(t, err)

resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body)
require.NoError(t, err)
assert.NotZero(t, resHeader.RequestID)

resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode()
require.NoError(t, err)

cursor, err := resMsg.Get("cursor").(bson.RawDocument).Decode()
require.NoError(t, err)

nextBatch := cursor.Get("nextBatch").(bson.RawArray)
newCursorID := cursor.Get("id").(int64)

testutil.AssertEqual(t, must.NotFail(expectedBatches[i].Convert()), must.NotFail(nextBatch.Convert()))
assert.Equal(t, cursorID, newCursorID)
}
})

t.Run("GetMoreEmpty", func(t *testing.T) {
body, err := wire.NewOpMsg(must.NotFail(getMoreCmd.Encode()))
require.NoError(t, err)

resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body)
require.NoError(t, err)
assert.NotZero(t, resHeader.RequestID)

resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode()
require.NoError(t, err)

cursor, err := resMsg.Get("cursor").(bson.RawDocument).Decode()
require.NoError(t, err)

nextBatch := cursor.Get("nextBatch").(bson.RawArray)
newCursorID := cursor.Get("id").(int64)

testutil.AssertEqual(t, types.MakeArray(0), must.NotFail(nextBatch.Convert()))
assert.Zero(t, newCursorID)
})
}

0 comments on commit 445c9b3

Please sign in to comment.