Skip to content

Commit

Permalink
fix(storage): postgresql webauthn tbl invalid aaguid constraint (#5183)
Browse files Browse the repository at this point in the history
This fixes an issue with the PostgreSQL schema where the webauthn tables aaguid column had a NOT NULL constraint erroneously.

Fixes #5182

Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
  • Loading branch information
james-d-elliott committed Apr 8, 2023
1 parent 3b52ddb commit fa250ea
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 78 deletions.
9 changes: 9 additions & 0 deletions internal/model/schema_migration.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package model

import (
"strings"
)

// SchemaMigration represents an intended migration.
type SchemaMigration struct {
Version int
Expand All @@ -9,6 +13,11 @@ type SchemaMigration struct {
Query string
}

// NotEmpty returns true if the SchemaMigration is not an empty string.
func (m SchemaMigration) NotEmpty() bool {
return len(strings.TrimSpace(m.Query)) != 0
}

// Before returns the version the schema should be at Before the migration is applied.
func (m SchemaMigration) Before() (before int) {
if m.Up {
Expand Down
70 changes: 47 additions & 23 deletions internal/storage/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"embed"
"errors"
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
Expand All @@ -15,8 +16,12 @@ import (
var migrationsFS embed.FS

func latestMigrationVersion(providerName string) (version int, err error) {
entries, err := migrationsFS.ReadDir("migrations")
if err != nil {
var (
entries []fs.DirEntry
migration model.SchemaMigration
)

if entries, err = migrationsFS.ReadDir("migrations"); err != nil {
return -1, err
}

Expand All @@ -25,21 +30,20 @@ func latestMigrationVersion(providerName string) (version int, err error) {
continue
}

m, err := scanMigration(entry.Name())
if err != nil {
if migration, err = scanMigration(entry.Name()); err != nil {
return -1, err
}

if m.Provider != providerName {
if migration.Provider != providerName && migration.Provider != providerAll {
continue
}

if !m.Up {
if !migration.Up {
continue
}

if m.Version > version {
version = m.Version
if migration.Version > version {
version = migration.Version
}
}

Expand All @@ -50,12 +54,17 @@ func latestMigrationVersion(providerName string) (version int, err error) {
// target versions. If the target version is -1 this indicates the latest version. If the target version is 0
// this indicates the database zero state.
func loadMigrations(providerName string, prior, target int) (migrations []model.SchemaMigration, err error) {
if prior == target && (prior != -1 || target != -1) {
if prior == target {
return nil, ErrMigrateCurrentVersionSameAsTarget
}

entries, err := migrationsFS.ReadDir("migrations")
if err != nil {
var (
migrationsAll []model.SchemaMigration
migration model.SchemaMigration
entries []fs.DirEntry
)

if entries, err = migrationsFS.ReadDir("migrations"); err != nil {
return nil, err
}

Expand All @@ -66,16 +75,36 @@ func loadMigrations(providerName string, prior, target int) (migrations []model.
continue
}

migration, err := scanMigration(entry.Name())
if err != nil {
if migration, err = scanMigration(entry.Name()); err != nil {
return nil, err
}

if skipMigration(providerName, up, target, prior, &migration) {
continue
}

migrations = append(migrations, migration)
if migration.Provider == providerAll {
migrationsAll = append(migrationsAll, migration)
} else {
migrations = append(migrations, migration)
}
}

// Add "all" migrations for versions that don't exist.
for _, am := range migrationsAll {
found := false

for _, m := range migrations {
if m.Version == am.Version {
found = true

break
}
}

if !found {
migrations = append(migrations, am)
}
}

if up {
Expand Down Expand Up @@ -103,7 +132,7 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
return true
}

if target != -1 && (migration.Version > target || migration.Version <= prior) {
if migration.Version > target || migration.Version <= prior {
// Skip if the migration version is greater than the target or less than or equal to the previous version.
return true
}
Expand All @@ -113,12 +142,6 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
return true
}

if migration.Version == 1 && target == -1 {
// Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data
// preventing a successful migration.
return true
}

if migration.Version <= target || migration.Version > prior {
// Skip the migration if we want to go down and the migration version is less than or equal to the target
// or greater than the previous version.
Expand All @@ -141,8 +164,9 @@ func scanMigration(m string) (migration model.SchemaMigration, err error) {
Provider: result[reMigration.SubexpIndex("Provider")],
}

data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
if err != nil {
var data []byte

if data, err = migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)); err != nil {
return model.SchemaMigration{}, err
}

Expand Down
47 changes: 0 additions & 47 deletions internal/storage/migrations/V0007.ConsistencyFixes.postgres.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -56,30 +56,8 @@ ALTER TABLE totp_configurations
DROP INDEX IF EXISTS totp_configurations_username_key1;
DROP INDEX IF EXISTS totp_configurations_username_key;

ALTER TABLE totp_configurations
RENAME TO _bkp_UP_V0007_totp_configurations;

CREATE TABLE IF NOT EXISTS totp_configurations (
id SERIAL CONSTRAINT totp_configurations_pkey PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
username VARCHAR(100) NOT NULL,
issuer VARCHAR(100),
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
digits INTEGER NOT NULL DEFAULT 6,
period INTEGER NOT NULL DEFAULT 30,
secret BYTEA NOT NULL
);

CREATE UNIQUE INDEX totp_configurations_username_key ON totp_configurations (username);

INSERT INTO totp_configurations (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
SELECT created_at, last_used_at, username, issuer, algorithm, digits, period, secret
FROM _bkp_UP_V0007_totp_configurations
ORDER BY id;

DROP TABLE IF EXISTS _bkp_UP_V0007_totp_configurations;

ALTER TABLE webauthn_devices
DROP CONSTRAINT IF EXISTS webauthn_devices_username_description_key1,
DROP CONSTRAINT IF EXISTS webauthn_devices_kid_key1,
Expand All @@ -97,34 +75,9 @@ DROP INDEX IF EXISTS webauthn_devices_username_description_key;
DROP INDEX IF EXISTS webauthn_devices_kid_key;
DROP INDEX IF EXISTS webauthn_devices_lookup_key;

ALTER TABLE webauthn_devices
RENAME TO _bkp_UP_V0007_webauthn_devices;

CREATE TABLE IF NOT EXISTS webauthn_devices (
id SERIAL CONSTRAINT webauthn_devices_pkey PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
rpid TEXT,
username VARCHAR(100) NOT NULL,
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
kid VARCHAR(512) NOT NULL,
public_key BYTEA NOT NULL,
attestation_type VARCHAR(32),
transport VARCHAR(20) DEFAULT '',
aaguid CHAR(36) NOT NULL,
sign_count INTEGER DEFAULT 0,
clone_warning BOOLEAN NOT NULL DEFAULT FALSE
);

CREATE UNIQUE INDEX webauthn_devices_kid_key ON webauthn_devices (kid);
CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (username, description);

INSERT INTO webauthn_devices (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
SELECT created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning
FROM _bkp_UP_V0007_webauthn_devices;

DROP TABLE IF EXISTS _bkp_UP_V0007_webauthn_devices;

ALTER TABLE oauth2_consent_session
DROP CONSTRAINT oauth2_consent_session_subject_fkey,
DROP CONSTRAINT oauth2_consent_session_preconfiguration_fkey;
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
ALTER TABLE webauthn_devices
ALTER COLUMN aaguid DROP NOT NULL;

UPDATE webauthn_devices
SET aaguid = NULL
WHERE aaguid = '' OR aaguid = '00000000-00000000-00000000-00000000';
43 changes: 42 additions & 1 deletion internal/storage/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

const (
// This is the latest schema version for the purpose of tests.
LatestVersion = 8
LatestVersion = 9
)

func TestShouldObtainCorrectUpMigrations(t *testing.T) {
Expand Down Expand Up @@ -44,6 +44,47 @@ func TestShouldObtainCorrectDownMigrations(t *testing.T) {
}
}

func TestMigrationShouldGetSpecificMigrationIfAvaliable(t *testing.T) {
upMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 8, 9)
require.NoError(t, err)
require.Len(t, upMigrationsPostgreSQL, 1)

assert.True(t, upMigrationsPostgreSQL[0].Up)
assert.Equal(t, 9, upMigrationsPostgreSQL[0].Version)
assert.Equal(t, providerPostgres, upMigrationsPostgreSQL[0].Provider)

upMigrationsSQLite, err := loadMigrations(providerSQLite, 8, 9)
require.NoError(t, err)
require.Len(t, upMigrationsSQLite, 1)

assert.True(t, upMigrationsSQLite[0].Up)
assert.Equal(t, 9, upMigrationsSQLite[0].Version)
assert.Equal(t, providerAll, upMigrationsSQLite[0].Provider)

downMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 9, 8)
require.NoError(t, err)
require.Len(t, downMigrationsPostgreSQL, 1)

assert.False(t, downMigrationsPostgreSQL[0].Up)
assert.Equal(t, 9, downMigrationsPostgreSQL[0].Version)
assert.Equal(t, providerAll, downMigrationsPostgreSQL[0].Provider)

downMigrationsSQLite, err := loadMigrations(providerSQLite, 9, 8)
require.NoError(t, err)
require.Len(t, downMigrationsSQLite, 1)

assert.False(t, downMigrationsSQLite[0].Up)
assert.Equal(t, 9, downMigrationsSQLite[0].Version)
assert.Equal(t, providerAll, downMigrationsSQLite[0].Provider)
}

func TestMigrationShouldReturnErrorOnSame(t *testing.T) {
migrations, err := loadMigrations(providerPostgres, 1, 1)

assert.EqualError(t, err, "current version is same as migration target, no action being taken")
assert.Nil(t, migrations)
}

func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) {
migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest)
require.NoError(t, err)
Expand Down
16 changes: 9 additions & 7 deletions internal/storage/sql_provider_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,16 @@ func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection
}

func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
if migration.NotEmpty() {
if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}

if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1.
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
return err
if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1.
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
return err
}
}
}

Expand Down

0 comments on commit fa250ea

Please sign in to comment.