diff --git a/.golangci.yml b/.golangci.yml index 7468b3036976..63813979bbb1 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -36,6 +36,7 @@ linters-settings: files: - $all - "!**/internal/bson/*_test.go" + - "!**/internal/driver/*.go" - "!**/internal/util/testutil/*.go" - "!**/internal/wire/*.go" deny: @@ -44,6 +45,7 @@ linters-settings: files: - $all - "!**/internal/bson/*.go" + - "!**/internal/driver/*.go" deny: - pkg: github.com/cristalhq/bson - pkg: github.com/cristalhq/bson/bsonproto diff --git a/internal/driver/driver.go b/internal/driver/driver.go index f1c8d24b4a1f..dcf7f8ad6fa2 100644 --- a/internal/driver/driver.go +++ b/internal/driver/driver.go @@ -22,6 +22,7 @@ import ( "log/slog" "net" "net/url" + "sync/atomic" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/wire" @@ -120,7 +121,7 @@ func (c *Conn) Read() (*wire.MsgHeader, wire.MsgBody, error) { c.l.Debug( fmt.Sprintf("<<<\n%s", body.String()), slog.Int("length", int(header.MessageLength)), - slog.Int("id", int(header.ResponseTo)), + slog.Int("id", int(header.RequestID)), slog.Int("response_to", int(header.ResponseTo)), slog.String("opcode", header.OpCode.String()), ) @@ -133,7 +134,7 @@ func (c *Conn) Write(header *wire.MsgHeader, body wire.MsgBody) error { c.l.Debug( fmt.Sprintf(">>>\n%s", body.String()), slog.Int("length", int(header.MessageLength)), - slog.Int("id", int(header.ResponseTo)), + slog.Int("id", int(header.RequestID)), slog.Int("response_to", int(header.ResponseTo)), slog.String("opcode", header.OpCode.String()), ) @@ -164,8 +165,72 @@ func (c *Conn) WriteRaw(b []byte) error { return nil } +// lastRequestID stores incremented value of last recorded request header ID. +var lastRequestID atomic.Int32 + // Request sends the given request to the connection and returns the response. +// If header MessageLength or RequestID is not specified, it assings the proper values. +// For header.OpCode the wire.OpCodeMsg is used as default. +// +// It returns errors only for request/response parsing issues, or connection issues. +// All of the driver level errors are stored inside response. func (c *Conn) Request(ctx context.Context, header *wire.MsgHeader, body wire.MsgBody) (*wire.MsgHeader, wire.MsgBody, error) { - // TODO https://github.com/FerretDB/FerretDB/issues/4146 - panic("not implemented") + if header.MessageLength == 0 { + msgBin, err := body.MarshalBinary() + if err != nil { + return nil, nil, lazyerrors.Error(err) + } + + header.MessageLength = int32(len(msgBin) + wire.MsgHeaderLen) + } + + if header.OpCode == 0 { + header.OpCode = wire.OpCodeMsg + } + + if header.RequestID == 0 { + header.RequestID = lastRequestID.Add(1) + } + + if header.ResponseTo != 0 { + return nil, nil, lazyerrors.Errorf("setting response_to is not allowed") + } + + if m, ok := body.(*wire.OpMsg); ok { + if m.Flags != 0 { + return nil, nil, lazyerrors.Errorf("unsupported request flags %s", m.Flags) + } + } + + if err := c.Write(header, body); err != nil { + return nil, nil, lazyerrors.Error(err) + } + + resHeader, resBody, err := c.Read() + if err != nil { + return nil, nil, lazyerrors.Error(err) + } + + if resHeader.ResponseTo != header.RequestID { + c.l.Error( + "response_to is not equal to request_id", + slog.Int("request_id", int(header.RequestID)), + slog.Int("response_id", int(resHeader.RequestID)), + slog.Int("response_to", int(resHeader.ResponseTo)), + ) + + return nil, nil, lazyerrors.Errorf( + "response_to is not equal to request_id (response_to=%d; expected=%d)", + resHeader.ResponseTo, + header.RequestID, + ) + } + + if m, ok := resBody.(*wire.OpMsg); ok { + if m.Flags != 0 { + return nil, nil, lazyerrors.Errorf("unsupported response flags %s", m.Flags) + } + } + + return resHeader, resBody, nil } diff --git a/internal/driver/driver_test.go b/internal/driver/driver_test.go index a66555ee8b93..b7d327abb55c 100644 --- a/internal/driver/driver_test.go +++ b/internal/driver/driver_test.go @@ -17,9 +17,15 @@ package driver import ( "testing" + "github.com/cristalhq/bson/bsonproto" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/FerretDB/FerretDB/internal/bson" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/must" "github.com/FerretDB/FerretDB/internal/util/testutil" + "github.com/FerretDB/FerretDB/internal/wire" ) func TestDriver(t *testing.T) { @@ -33,6 +39,157 @@ func TestDriver(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c.Close()) }) - // TODO https://github.com/FerretDB/FerretDB/issues/4146 - _ = c + dbName := testutil.DatabaseName(t) + + lsid := bson.Binary{ + B: []byte{ + 0xa3, 0x19, 0xf2, 0xb4, 0xa1, 0x75, 0x40, 0xc7, + 0xb8, 0xe7, 0xa3, 0xa3, 0x2e, 0xc2, 0x56, 0xbe, + }, + Subtype: bsonproto.BinaryUUID, + } + + expectedBatches := make([]bson.Array, 3) + + err = expectedBatches[0].Add(must.NotFail(bson.NewDocument("_id", int32(0), "w", int32(2), "v", int32(1)))) + require.NoError(t, err) + err = expectedBatches[1].Add(must.NotFail(bson.NewDocument("_id", int32(1), "v", int32(2)))) + require.NoError(t, err) + err = expectedBatches[2].Add(must.NotFail(bson.NewDocument("_id", int32(2), "v", int32(3)))) + require.NoError(t, err) + + t.Run("Drop", func(t *testing.T) { + dropCmd := must.NotFail(bson.NewDocument( + "dropDatabase", int32(1), + "lsid", must.NotFail(bson.NewDocument("id", lsid)), + "$db", dbName, + )) + + body, err := wire.NewOpMsg(must.NotFail(dropCmd.Encode())) + require.NoError(t, err) + + resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body) + require.NoError(t, err) + assert.NotZero(t, resHeader.RequestID) + + resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode() + require.NoError(t, err) + + ok := resMsg.Get("ok").(float64) + require.Equal(t, float64(1), ok) + }) + + t.Run("Insert", func(t *testing.T) { + insertCmd := must.NotFail(bson.NewDocument( + "insert", "values", + "documents", must.NotFail(bson.NewArray( + must.NotFail(bson.NewDocument("w", int32(2), "v", int32(1), "_id", int32(0))), + must.NotFail(bson.NewDocument("v", int32(2), "_id", int32(1))), + must.NotFail(bson.NewDocument("v", int32(3), "_id", int32(2))), + )), + "ordered", true, + "lsid", must.NotFail(bson.NewDocument("id", lsid)), + "$db", dbName, + )) + + body, err := wire.NewOpMsg(must.NotFail(insertCmd.Encode())) + require.NoError(t, err) + + resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body) + require.NoError(t, err) + assert.NotZero(t, resHeader.RequestID) + + resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode() + require.NoError(t, err) + + ok := resMsg.Get("ok").(float64) + require.Equal(t, float64(1), ok) + + n := resMsg.Get("n").(int32) + require.Equal(t, int32(3), n) + }) + + var cursorID int64 + + t.Run("Find", func(t *testing.T) { + findCmd := must.NotFail(bson.NewDocument( + "find", "values", + "filter", must.NotFail(bson.NewDocument()), + "sort", must.NotFail(bson.NewDocument("_id", int32(1))), + "lsid", must.NotFail(bson.NewDocument("id", lsid)), + "batchSize", int32(1), + "$db", dbName, + )) + + body, err := wire.NewOpMsg(must.NotFail(findCmd.Encode())) + require.NoError(t, err) + + resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body) + require.NoError(t, err) + assert.NotZero(t, resHeader.RequestID) + + resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode() + require.NoError(t, err) + + cursor, err := resMsg.Get("cursor").(bson.RawDocument).Decode() + require.NoError(t, err) + + firstBatch := cursor.Get("firstBatch").(bson.RawArray) + cursorID = cursor.Get("id").(int64) + + testutil.AssertEqual(t, must.NotFail(expectedBatches[0].Convert()), must.NotFail(firstBatch.Convert())) + require.NotZero(t, cursorID) + }) + + getMoreCmd := must.NotFail(bson.NewDocument( + "getMore", cursorID, + "collection", "values", + "lsid", must.NotFail(bson.NewDocument("id", lsid)), + "batchSize", int32(1), + "$db", dbName, + )) + + t.Run("GetMore", func(t *testing.T) { + for i := 1; i < 3; i++ { + body, err := wire.NewOpMsg(must.NotFail(getMoreCmd.Encode())) + require.NoError(t, err) + + resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body) + require.NoError(t, err) + assert.NotZero(t, resHeader.RequestID) + + resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode() + require.NoError(t, err) + + cursor, err := resMsg.Get("cursor").(bson.RawDocument).Decode() + require.NoError(t, err) + + nextBatch := cursor.Get("nextBatch").(bson.RawArray) + newCursorID := cursor.Get("id").(int64) + + testutil.AssertEqual(t, must.NotFail(expectedBatches[i].Convert()), must.NotFail(nextBatch.Convert())) + assert.Equal(t, cursorID, newCursorID) + } + }) + + t.Run("GetMoreEmpty", func(t *testing.T) { + body, err := wire.NewOpMsg(must.NotFail(getMoreCmd.Encode())) + require.NoError(t, err) + + resHeader, resBody, err := c.Request(ctx, new(wire.MsgHeader), body) + require.NoError(t, err) + assert.NotZero(t, resHeader.RequestID) + + resMsg, err := must.NotFail(resBody.(*wire.OpMsg).RawDocument()).Decode() + require.NoError(t, err) + + cursor, err := resMsg.Get("cursor").(bson.RawDocument).Decode() + require.NoError(t, err) + + nextBatch := cursor.Get("nextBatch").(bson.RawArray) + newCursorID := cursor.Get("id").(int64) + + testutil.AssertEqual(t, types.MakeArray(0), must.NotFail(nextBatch.Convert())) + assert.Zero(t, newCursorID) + }) }