Skip to content

Commit

Permalink
Improve Multi perf on SQL state stores with 1 op only (#3300)
Browse files Browse the repository at this point in the history
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
  • Loading branch information
ItalyPaleAle committed Jan 4, 2024
1 parent 60cd144 commit 94c5618
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 190 deletions.
47 changes: 29 additions & 18 deletions common/component/postgresql/v1/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,29 +438,40 @@ func (p *PostgreSQL) doDelete(parentCtx context.Context, db pginterfaces.DBQueri
}

func (p *PostgreSQL) Multi(parentCtx context.Context, request *state.TransactionalStateRequest) error {
_, err := pgtransactions.ExecuteInTransaction[struct{}](parentCtx, p.logger, p.db, p.metadata.Timeout, func(ctx context.Context, tx pgx.Tx) (res struct{}, err error) {
for _, o := range request.Operations {
switch x := o.(type) {
case state.SetRequest:
err = p.doSet(parentCtx, tx, &x)
if request == nil {
return nil
}

// If there's only 1 operation, skip starting a transaction
switch len(request.Operations) {
case 0:
return nil
case 1:
return p.execMultiOperation(parentCtx, request.Operations[0], p.db)
default:
_, err := pgtransactions.ExecuteInTransaction[struct{}](parentCtx, p.logger, p.db, p.metadata.Timeout, func(ctx context.Context, tx pgx.Tx) (res struct{}, err error) {
for _, op := range request.Operations {
err = p.execMultiOperation(ctx, op, tx)
if err != nil {
return res, err
}

case state.DeleteRequest:
err = p.doDelete(parentCtx, tx, &x)
if err != nil {
return res, err
}

default:
return res, fmt.Errorf("unsupported operation: %s", o.Operation())
}
}

return res, nil
})
return err
return res, nil
})
return err
}
}

func (p *PostgreSQL) execMultiOperation(ctx context.Context, op state.TransactionalStateOperation, db pginterfaces.DBQuerier) error {
switch x := op.(type) {
case state.SetRequest:
return p.doSet(ctx, db, &x)
case state.DeleteRequest:
return p.doDelete(ctx, db, &x)
default:
return fmt.Errorf("unsupported operation: %s", op.Operation())
}
}

func (p *PostgreSQL) CleanupExpired() error {
Expand Down
107 changes: 72 additions & 35 deletions common/component/postgresql/v1/postgresql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,9 @@ func TestMultiWithNoRequests(t *testing.T) {
m, _ := mockDatabase(t)
defer m.db.Close()

m.db.ExpectBegin()
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()

var operations []state.TransactionalStateOperation

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
Operations: nil,
})

// Assert
Expand All @@ -66,24 +59,46 @@ func TestValidSetRequest(t *testing.T) {
defer m.db.Close()

setReq := createSetRequest()
operations := []state.TransactionalStateOperation{setReq}
val, _ := json.Marshal(setReq.Value)

m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
WithArgs(setReq.Key, string(val), false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
t.Run("single op", func(t *testing.T) {
operations := []state.TransactionalStateOperation{setReq}

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
m.db.ExpectExec("INSERT INTO").
WithArgs(setReq.Key, string(val), false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
})

// Assert
require.NoError(t, err)
})

// Assert
require.NoError(t, err)
t.Run("multiple ops", func(t *testing.T) {
operations := []state.TransactionalStateOperation{setReq, setReq}

m.db.ExpectBegin()
m.db.ExpectExec("INSERT INTO").
WithArgs(setReq.Key, string(val), false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
m.db.ExpectExec("INSERT INTO").
WithArgs(setReq.Key, string(val), false).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
})

// Assert
require.NoError(t, err)
})
}

func TestInvalidMultiSetRequestNoKey(t *testing.T) {
Expand Down Expand Up @@ -113,23 +128,45 @@ func TestValidMultiDeleteRequest(t *testing.T) {
defer m.db.Close()

deleteReq := createDeleteRequest()
operations := []state.TransactionalStateOperation{deleteReq}

m.db.ExpectBegin()
m.db.ExpectExec("DELETE FROM").
WithArgs(deleteReq.Key).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()
t.Run("single op", func(t *testing.T) {
operations := []state.TransactionalStateOperation{deleteReq}

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
m.db.ExpectExec("DELETE FROM").
WithArgs(deleteReq.Key).
WillReturnResult(pgxmock.NewResult("DELETE", 1))

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
})

// Assert
require.NoError(t, err)
})

// Assert
require.NoError(t, err)
t.Run("multiple ops", func(t *testing.T) {
operations := []state.TransactionalStateOperation{deleteReq, deleteReq}

m.db.ExpectBegin()
m.db.ExpectExec("DELETE FROM").
WithArgs(deleteReq.Key).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
m.db.ExpectExec("DELETE FROM").
WithArgs(deleteReq.Key).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
m.db.ExpectCommit()
// There's also a rollback called after a commit, which is expected and will not have effect
m.db.ExpectRollback()

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Operations: operations,
})

// Assert
require.NoError(t, err)
})
}

func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
Expand All @@ -140,7 +177,7 @@ func TestInvalidMultiDeleteRequestNoKey(t *testing.T) {
m.db.ExpectBegin()
m.db.ExpectRollback()

operations := []state.TransactionalStateOperation{state.DeleteRequest{}} // Delete request without key is not valid for Delete operation
operations := []state.TransactionalStateOperation{state.DeleteRequest{}, state.DeleteRequest{}} // Delete request without key is not valid for Delete operation

// Act
err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{
Expand Down
56 changes: 33 additions & 23 deletions state/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,40 +768,50 @@ func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte
return key, value, &etag, expireTime, nil
}

// Multi handles multiple transactions.
// TransactionalStore Interface.
// Multi handles multiple operations in batch.
// Implements the TransactionalStore Interface.
func (m *MySQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
tx, err := m.db.Begin()
if err != nil {
return err
if request == nil {
return nil
}
defer func() {
rollbackErr := tx.Rollback()
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
}()

for _, o := range request.Operations {
switch req := o.(type) {
case state.SetRequest:
err = m.setValue(ctx, tx, &req)
if err != nil {
return err
// If there's only 1 operation, skip starting a transaction
switch len(request.Operations) {
case 0:
return nil
case 1:
return m.execMultiOperation(ctx, request.Operations[0], m.db)
default:
tx, err := m.db.Begin()
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback()
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
m.logger.Errorf("Error rolling back transaction: %v", rollbackErr)
}
}()

case state.DeleteRequest:
err = m.deleteValue(ctx, tx, &req)
for _, op := range request.Operations {
err = m.execMultiOperation(ctx, op, tx)
if err != nil {
return err
}

default:
return fmt.Errorf("unsupported operation: %s", req.Operation())
}
return tx.Commit()
}
}

return tx.Commit()
func (m *MySQL) execMultiOperation(ctx context.Context, op state.TransactionalStateOperation, db querier) error {
switch req := op.(type) {
case state.SetRequest:
return m.setValue(ctx, db, &req)
case state.DeleteRequest:
return m.deleteValue(ctx, db, &req)
default:
return fmt.Errorf("unsupported operation: %s", op.Operation())
}
}

// Close implements io.Closer.
Expand Down

0 comments on commit 94c5618

Please sign in to comment.