Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLite & MySQL: improve perf of Set operations with first-write-wins #3159

Merged
merged 7 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
122 changes: 34 additions & 88 deletions state/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,35 +389,6 @@ func (m *MySQL) ensureStateTable(ctx context.Context, schemaName, stateTableName
}
}

// Create the DaprSaveFirstWriteV1 stored procedure
_, err = m.db.ExecContext(ctx, `CREATE PROCEDURE IF NOT EXISTS DaprSaveFirstWriteV1(tableName VARCHAR(255), id VARCHAR(255), value JSON, etag VARCHAR(36), isbinary BOOLEAN, expiredateToken TEXT)
LANGUAGE SQL
MODIFIES SQL DATA
BEGIN
SET @id = id;
SET @value = value;
SET @etag = etag;
SET @isbinary = isbinary;

SET @selectQuery = concat('SELECT COUNT(id) INTO @count FROM ', tableName ,' WHERE id = ? AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)');
PREPARE select_stmt FROM @selectQuery;
EXECUTE select_stmt USING @id;
DEALLOCATE PREPARE select_stmt;

IF @count < 1 THEN
SET @upsertQuery = concat('INSERT INTO ', tableName, ' SET id=?, value=?, eTag=?, isbinary=?, expiredate=', expiredateToken, ' ON DUPLICATE KEY UPDATE value=?, eTag=?, isbinary=?, expiredate=', expiredateToken);
PREPARE upsert_stmt FROM @upsertQuery;
EXECUTE upsert_stmt USING @id, @value, @etag, @isbinary, @value, @etag, @isbinary;
DEALLOCATE PREPARE upsert_stmt;
ELSE
SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Row already exists';
END IF;

END`)
if err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -596,7 +567,6 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
ttlQuery string
params []any
result sql.Result
maxRows int64 = 1
)

var v any
Expand Down Expand Up @@ -624,10 +594,7 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
ttlQuery = "NULL"
}

mustCommit := false
hasEtag := req.ETag != nil && *req.ETag != ""

if hasEtag {
if req.HasETag() {
// When an eTag is provided do an update - not insert
query = `UPDATE ` + m.tableName + `
SET value = ?, eTag = ?, isbinary = ?, expiredate = ` + ttlQuery + `
Expand All @@ -636,73 +603,52 @@ func (m *MySQL) setValue(parentCtx context.Context, querier querier, req *state.
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
params = []any{enc, eTag, isBinary, req.Key, *req.ETag}
} else if req.Options.Concurrency == state.FirstWrite {
// If we're not in a transaction already, start one as we need to ensure consistency
if querier == m.db {
querier, err = m.db.BeginTx(parentCtx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer querier.(*sql.Tx).Rollback()
mustCommit = true
}

// With first-write-wins and no etag, we can insert the row only if it doesn't exist
// Things get a bit tricky when the row exists but it is expired, so it just hasn't been garbage-collected yet
// What we can do in that case is to first check if the row doesn't exist or has expired, and then perform an upsert
// To do that, we use a stored procedure
query = "CALL DaprSaveFirstWriteV1(?, ?, ?, ?, ?, ?)"
params = []any{m.tableName, req.Key, enc, eTag, isBinary, ttlQuery}
// If the operation uses first-write concurrency, we need to handle the special case of a row that has expired but hasn't been garbage collected yet
// In this case, the row should be considered as if it were deleted
query = `REPLACE INTO ` + m.tableName + `
WITH a AS (
SELECT
? AS id,
? AS value,
? AS isbinary,
CURRENT_TIMESTAMP AS insertDate,
CURRENT_TIMESTAMP AS updateDate,
? AS eTag,
` + ttlQuery + ` AS expiredate
FROM ` + m.tableName + `
WHERE NOT EXISTS (
SELECT 1
FROM ` + m.tableName + `
WHERE id = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)
)
)
SELECT * FROM a`
params = []any{req.Key, enc, isBinary, eTag, req.Key}
} else {
// If the key already exists, MySQL reports two rows affected for the upsert
maxRows = 2
query = `INSERT INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
VALUES (?, ?, ?, ?, ` + ttlQuery + `)
ON DUPLICATE KEY UPDATE
value=?, eTag=?, isbinary=?, expiredate=` + ttlQuery
params = []any{req.Key, enc, eTag, isBinary, enc, eTag, isBinary}
query = `REPLACE INTO ` + m.tableName + ` (id, value, eTag, isbinary, expiredate)
VALUES (?, ?, ?, ?, ` + ttlQuery + `)`
params = []any{req.Key, enc, eTag, isBinary}
}

ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
result, err = querier.ExecContext(ctx, query, params...)

if err != nil {
if hasEtag {
return state.NewETagError(state.ETagMismatch, err)
}

return err
}

// Do not count affected rows when using first-write
// Conflicts are handled separately
if hasEtag || req.Options.Concurrency != state.FirstWrite {
var rows int64
rows, err = result.RowsAffected()
if err != nil {
return err
}

if rows == 0 {
err = errors.New("rows affected error: no rows match given key and eTag")
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}

if rows > maxRows {
err = fmt.Errorf("rows affected error: more than %d row affected; actual %d", maxRows, rows)
m.logger.Error(err)
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}

// Commit the transaction if needed
if mustCommit {
err = querier.(*sql.Tx).Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
if rows == 0 && (req.HasETag() || req.Options.Concurrency == state.FirstWrite) {
err = errors.New("rows affected error: no rows match given key and eTag")
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}

return nil
Expand Down
39 changes: 6 additions & 33 deletions state/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func TestMultiCommitSetsAndDeletes(t *testing.T) {
defer m.mySQL.Close()

m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("DELETE FROM").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()

Expand Down Expand Up @@ -255,24 +255,8 @@ func TestSetHandlesErr(t *testing.T) {
m, _ := mockDatabase(t)
defer m.mySQL.Close()

t.Run("error occurs when update with tag", func(t *testing.T) {
m.mock1.ExpectExec("UPDATE state").WillReturnError(errors.New("error"))

eTag := "946af561"
request := createSetRequest()
request.ETag = &eTag

// Act
err := m.mySQL.Set(context.Background(), &request)

// Assert
assert.Error(t, err)
assert.IsType(t, &state.ETagError{}, err)
assert.Equal(t, err.(*state.ETagError).Kind(), state.ETagMismatch)
})

t.Run("error occurs when insert", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnError(errors.New("error"))
m.mock1.ExpectExec("REPLACE INTO state").WillReturnError(errors.New("error"))
request := createSetRequest()

// Act
Expand All @@ -284,7 +268,7 @@ func TestSetHandlesErr(t *testing.T) {
})

t.Run("insert on conflict", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnResult(sqlmock.NewResult(1, 2))
m.mock1.ExpectExec("REPLACE INTO state").WillReturnResult(sqlmock.NewResult(1, 2))
request := createSetRequest()

// Act
Expand All @@ -294,17 +278,6 @@ func TestSetHandlesErr(t *testing.T) {
assert.NoError(t, err)
})

t.Run("too many rows error", func(t *testing.T) {
m.mock1.ExpectExec("INSERT INTO state").WillReturnResult(sqlmock.NewResult(1, 3))
request := createSetRequest()

// Act
err := m.mySQL.Set(context.Background(), &request)

// Assert
assert.Error(t, err)
})

t.Run("no rows effected error", func(t *testing.T) {
m.mock1.ExpectExec("UPDATE state").WillReturnResult(sqlmock.NewResult(1, 0))

Expand Down Expand Up @@ -716,7 +689,7 @@ func TestValidSetRequest(t *testing.T) {
}

m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()

// Act
Expand Down Expand Up @@ -805,9 +778,9 @@ func TestMultiOperationOrder(t *testing.T) {

// expected to run the operations in sequence
m.mock1.ExpectBegin()
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("DELETE FROM").WithArgs("k1").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("INSERT INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(0, 1))
m.mock1.ExpectCommit()

// Act
Expand Down
72 changes: 25 additions & 47 deletions state/sqlite/sqlite_dbaccess.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,52 +333,40 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state

// Only check for etag if FirstWrite specified (ref oracledatabaseaccess)
var (
res sql.Result
mustCommit bool
stmt string
res sql.Result
stmt string
)
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()

// Sprintf is required for the table name because sql.DB does not substitute parameters for table names.
// The same applies to the DATETIME function's seconds parameter (which comes from an integer anyway).
if !req.HasETag() {
switch {
case !req.HasETag() && req.Options.Concurrency == state.FirstWrite:
// If the operation uses first-write concurrency, we need to handle the special case of a row that has expired but hasn't been garbage collected yet
// In this case, the row should be considered as if it were deleted
// With SQLite, the only way we can handle that is by performing a SELECT query first
if req.Options.Concurrency == state.FirstWrite {
// If we're not in a transaction already, start one as we need to ensure consistency
if db == a.db {
db, err = a.db.BeginTx(parentCtx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer db.(*sql.Tx).Rollback()
mustCommit = true
}

// Check if there's already a row with the given key that has not expired yet
var count int
stmt = `SELECT COUNT(key)
stmt = `WITH a AS (
SELECT
?, ?, ?, ?, ` + expiration + `, CURRENT_TIMESTAMP
FROM ` + a.metadata.TableName + `
WHERE key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
err = db.QueryRowContext(parentCtx, stmt, req.Key).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check for existing row with first-write concurrency: %w", err)
}

// If the row exists, then we just return an etag error
// Otherwise, we can fall through and continue with an INSERT OR REPLACE statement
if count > 0 {
return state.NewETagError(state.ETagMismatch, nil)
}
}
WHERE NOT EXISTS (
SELECT 1
FROM ` + a.metadata.TableName + `
WHERE key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)
)
)
INSERT OR REPLACE INTO ` + a.metadata.TableName + `
SELECT * FROM a`
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)

case !req.HasETag():
stmt = "INSERT OR REPLACE INTO " + a.metadata.TableName + `
(key, value, is_binary, etag, update_time, expiration_time)
VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP, ` + expiration + `)`
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)
} else {
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag)

default:
stmt = `UPDATE ` + a.metadata.TableName + ` SET
value = ?,
etag = ?,
Expand All @@ -389,8 +377,6 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
key = ?
AND etag = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
defer cancel()
res, err = db.ExecContext(ctx, stmt, requestValue, newEtag, isBinary, req.Key, *req.ETag)
}
if err != nil {
Expand All @@ -403,20 +389,12 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
return err
}
if rows == 0 {
if req.HasETag() {
if req.HasETag() || req.Options.Concurrency == state.FirstWrite {
return state.NewETagError(state.ETagMismatch, nil)
}
return errors.New("no item was updated")
}

// Commit the transaction if needed
if mustCommit {
err = db.(*sql.Tx).Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
}

return nil
}

Expand Down