Skip to content

Commit

Permalink
SQLite & MySQL: improve perf of Set operations with first-write-wins (#3159)
Browse files Browse the repository at this point in the history

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
Signed-off-by: Bernd Verst <github@bernd.dev>
Co-authored-by: Bernd Verst <github@bernd.dev>
  • Loading branch information
ItalyPaleAle and berndverst committed Nov 1, 2023
1 parent 7fd5524 commit 7114fd0
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 168 deletions.
122 changes: 34 additions & 88 deletions state/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,35 +389,6 @@ func (m *MySQL) ensureStateTable(ctx context.Context, schemaName, stateTableName
}
}

// Create the DaprSaveFirstWriteV1 stored procedure
_, err = m.db.ExecContext(ctx, `CREATE PROCEDURE IF NOT EXISTS DaprSaveFirstWriteV1(tableName VARCHAR(255), id VARCHAR(255), value JSON, etag VARCHAR(36), isbinary BOOLEAN, expiredateToken TEXT)
LANGUAGE SQL
MODIFIES SQL DATA
BEGIN
SET @id = id;
SET @value = value;
SET @etag = etag;
SET @isbinary = isbinary;
SET @selectQuery = concat('SELECT COUNT(id) INTO @count FROM ', tableName ,' WHERE id = ? AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)');
PREPARE select_stmt FROM @selectQuery;
EXECUTE select_stmt USING @id;
DEALLOCATE PREPARE select_stmt;
IF @count < 1 THEN
SET @upsertQuery = concat('INSERT INTO ', tableName, ' SET id=?, value=?, eTag=?, isbinary=?, expiredate=', expiredateToken, ' ON DUPLICATE KEY UPDATE value=?, eTag=?, isbinary=?, expiredate=', expiredateToken);
PREPARE upsert_stmt FROM @upsertQuery;
EXECUTE upsert_stmt USING @id, @value, @etag, @isbinary, @value, @etag, @isbinary;
DEALLOCATE PREPARE upsert_stmt;
ELSE
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Row already exists';
END IF;
END`)
if err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -596,7 +567,6 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
ttlQuery string
params []any
result sql.Result
maxRows int64 = 1
)

var v any
Expand Down Expand Up @@ -624,10 +594,7 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
ttlQuery = "NULL"
}

mustCommit := false
hasEtag := req.ETag != nil && *req.ETag != ""

if hasEtag {
if req.HasETag() {
// When an eTag is provided do an update - not insert
query = `UPDATE ` + m.tableName + `
SET value = ?, eTag = ?, isbinary = ?, expiredate = ` + ttlQuery + `
Expand All @@ -636,73 +603,52 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
params = []any{enc, eTag, isBinary, req.Key, *req.ETag}
} else if req.Options.Concurrency == state.FirstWrite {
// If we're not in a transaction already, start one as we need to ensure consistency
if querier == m.db {
querier, err = m.db.BeginTx(parentCtx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer querier.(*sql.Tx).Rollback()
mustCommit = true
}

// With first-write-wins and no etag, we can insert the row only if it doesn't exist
// Things get a bit tricky when the row exists but it is expired, so it just hasn't been garbage-collected yet
// What we can do in that case is to first check if the row doesn't exist or has expired, and then perform an upsert
// To do that, we use a stored procedure
query = "CALL DaprSaveFirstWriteV1(?, ?, ?, ?, ?, ?)"
params = []any{m.tableName, req.Key, enc, eTag, isBinary, ttlQuery}
// If the operation uses first-write concurrency, we need to handle the special case of a row that has expired but hasn't been garbage collected yet
// In this case, the row should be considered as if it were deleted
query = `REPLACE INTO ` + m.tableName + `
WITH a AS (
SELECT
? AS id,
? AS value,
? AS isbinary,
CURRENT_TIMESTAMP AS insertDate,
CURRENT_TIMESTAMP AS updateDate,
? AS eTag,
` + ttlQuery + ` AS expiredate
FROM ` + m.tableName + `
WHERE NOT EXISTS (
SELECT 1
FROM ` + m.tableName + `
WHERE id = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)
)
)
SELECT * FROM a`
params = []any{req.Key, enc, isBinary, eTag, req.Key}
} else {
// If the key already exists, MySQL reports two rows affected for the upsert
maxRows = 2
query = `INSERT INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
VALUES (?, ?, ?, ?, ` + ttlQuery + `)
ON DUPLICATE KEY UPDATE
value=?, eTag=?, isbinary=?, expiredate=` + ttlQuery
params = []any{req.Key, enc, eTag, isBinary, enc, eTag, isBinary}
query = `REPLACE INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
VALUES (?, ?, ?, ?, ` + ttlQuery + `)`
params = []any{req.Key, enc, eTag, isBinary}
}

ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
result, err = querier.ExecContext(ctx, query, params...)

if err != nil {
if hasEtag {
return state.NewETagError(state.ETagMismatch, err)
}

return err
}

// Do not count affected rows when using first-write
// Conflicts are handled separately
if hasEtag || req.Options.Concurrency != state.FirstWrite {
var rows int64
rows, err = result.RowsAffected()
if err != nil {
return err
}

if rows == 0 {
err = errors.New("rows affected error: no rows match given key and eTag")
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}

if rows > maxRows {
err = fmt.Errorf("rows affected error: more than %d row affected; actual %d", maxRows, rows)
m.logger.Error(err)
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}

// Commit the transaction if needed
if mustCommit {
err = querier.(*sql.Tx).Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
if rows == 0 && (req.HasETag() || req.Options.Concurrency == state.FirstWrite) {
err = errors.New("rows affected error: no rows match given key and eTag")
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}

return nil
Expand Down
39 changes: 6 additions & 33 deletions state/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func TestMultiCommitSetsAndDeletes(t *testing.T) {
defer m.mySQL.Close()

m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()

Expand Down Expand Up @@ -255,24 +255,8 @@ func TestSetHandlesErr(t *testing.T) {
m, _ := mockDatabase(t)
defer m.mySQL.Close()

t.Run("error occurs when update with tag", func(t *testing.T) {
m.mock1.ExpectExec("UPDATE state").WillReturnError(errors.New("error"))

eTag := "946af561"
request := createSetRequest()
request.ETag = &eTag

// Act
err := m.mySQL.Set(context.Background(), &request)

// Assert
assert.Error(t, err)
assert.IsType(t, &state.ETagError{}, err)
assert.Equal(t, err.(*state.ETagError).Kind(), state.ETagMismatch)
})

t.Run("error occurs when insert", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnError(errors.New("error"))
m.mock1.ExpectExec("REPLACE INTO state").WillReturnError(errors.New("error"))
request := createSetRequest()

// Act
Expand All @@ -284,7 +268,7 @@ func TestSetHandlesErr(t *testing.T) {
})

t.Run("insert on conflict", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnResult(sqlmock.NewResult(1, 2))
m.mock1.ExpectExec("REPLACE INTO state").WillReturnResult(sqlmock.NewResult(1, 2))
request := createSetRequest()

// Act
Expand All @@ -294,17 +278,6 @@ func TestSetHandlesErr(t *testing.T) {
assert.NoError(t, err)
})

t.Run("too many rows error", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnResult(sqlmock.NewResult(1, 3))
request := createSetRequest()

// Act
err := m.mySQL.Set(context.Background(), &request)

// Assert
assert.Error(t, err)
})

t.Run("no rows effected error", func(t *testing.T) {
m.mock1.ExpectExec("UPDATE state").WillReturnResult(sqlmock.NewResult(1, 0))

Expand Down Expand Up @@ -716,7 +689,7 @@ func TestValidSetRequest(t *testing.T) {
}

m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()

// Act
Expand Down Expand Up @@ -805,9 +778,9 @@ func TestMultiOperationOrder(t *testing.T) {

// expected to run the operations in sequence
m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("DELETE FROM").WithArgs("k1").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()

// Act
Expand Down
72 changes: 25 additions & 47 deletions state/sqlite/sqlite_dbaccess.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,52 +333,40 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state

// Only check for etag if FirstWrite specified (ref oracledatabaseaccess)
var (
res sql.Result
mustCommit bool
stmt string
res sql.Result
stmt string
)
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()

// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
// And the same is for DATETIME function's seconds parameter (which is from an integer anyways).
if !req.HasETag() {
switch {
case !req.HasETag() && req.Options.Concurrency == state.FirstWrite:
// If the operation uses first-write concurrency, we need to handle the special case of a row that has expired but hasn't been garbage collected yet
// In this case, the row should be considered as if it were deleted
// With SQLite, the only way we can handle that is by performing a SELECT query first
if req.Options.Concurrency == state.FirstWrite {
// If we're not in a transaction already, start one as we need to ensure consistency
if db == a.db {
db, err = a.db.BeginTx(parentCtx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer db.(*sql.Tx).Rollback()
mustCommit = true
}

// Check if there's already a row with the given key that has not expired yet
var count int
stmt = `SELECT COUNT(key)
stmt = `WITH a AS (
SELECT
?, ?, ?, ?, ` + expiration + `, CURRENT_TIMESTAMP
FROM ` + a.metadata.TableName + `
WHERE key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
err = db.QueryRowContext(parentCtx, stmt, req.Key).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check for existing row with first-write concurrency: %w", err)
}

// If the row exists, then we just return an etag error
// Otherwise, we can fall through and continue with an INSERT OR REPLACE statement
if count > 0 {
return state.NewETagError(state.ETagMismatch, nil)
}
}
WHERE NOT EXISTS (
SELECT 1
FROM ` + a.metadata.TableName + `
WHERE key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)
)
)
INSERT OR REPLACE INTO ` + a.metadata.TableName + `
SELECT * FROM a`
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)

case !req.HasETag():
stmt = "INSERT OR REPLACE INTO " + a.metadata.TableName + `
(key, value, is_binary, etag, update_time, expiration_time)
VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP, ` + expiration + `)`
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)
} else {
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag)

default:
stmt = `UPDATE ` + a.metadata.TableName + ` SET
value = ?,
etag = ?,
Expand All @@ -389,8 +377,6 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
key = ?
AND etag = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()
res, err = db.ExecContext(ctx, stmt, requestValue, newEtag, isBinary, req.Key, *req.ETag)
}
if err != nil {
Expand All @@ -403,20 +389,12 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
return err
}
if rows == 0 {
if req.HasETag() {
if req.HasETag() || req.Options.Concurrency == state.FirstWrite {
return state.NewETagError(state.ETagMismatch, nil)
}
return errors.New("no item was updated")
}

// Commit the transaction if needed
if mustCommit {
err = db.(*sql.Tx).Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
}

return nil
}

Expand Down

0 comments on commit 7114fd0

Please sign in to comment.