Skip to content

Commit

Permalink
feat(adapter-pg): define foreign keys, use single queries
Browse files Browse the repository at this point in the history
Many of the multi-queries result in race conditions that can be avoided
if only a single query is used.
  • Loading branch information
kevinji committed Apr 26, 2024
1 parent 2159de2 commit 2b7d139
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 90 deletions.
33 changes: 18 additions & 15 deletions packages/adapter-pg/schema.sql
@@ -1,4 +1,15 @@
\set ON_ERROR_STOP true
BEGIN TRANSACTION;

CREATE TABLE users
(
id SERIAL,
name VARCHAR(255),
email VARCHAR(255),
"emailVerified" TIMESTAMPTZ,
image TEXT,

PRIMARY KEY (id)
);

CREATE TABLE verification_token
(
Expand All @@ -24,26 +35,18 @@ CREATE TABLE accounts
session_state TEXT,
token_type TEXT,

PRIMARY KEY (id)
PRIMARY KEY (id),
FOREIGN KEY ("userId") REFERENCES users(id) ON DELETE CASCADE
);

CREATE TABLE sessions
(
id SERIAL,
"sessionToken" VARCHAR(255) NOT NULL,
"userId" INTEGER NOT NULL,
expires TIMESTAMPTZ NOT NULL,
"sessionToken" VARCHAR(255) NOT NULL,

PRIMARY KEY (id)
PRIMARY KEY ("sessionToken"),
FOREIGN KEY ("userId") REFERENCES users(id) ON DELETE CASCADE
);

CREATE TABLE users
(
id SERIAL,
name VARCHAR(255),
email VARCHAR(255),
"emailVerified" TIMESTAMPTZ,
image TEXT,

PRIMARY KEY (id)
);
COMMIT;
156 changes: 81 additions & 75 deletions packages/adapter-pg/src/index.ts
Expand Up @@ -31,12 +31,45 @@ export function mapExpiresAt(account: any): any {
}
}

// SAFETY: `idKey` must be a literal string
// SAFETY: `keys` must be a subset of literal `obj` keys
// e.g. ["name", "email"]
function createParameterizedUpdate(
obj: Record<string, any>,
idKey: string,
keys: string[]
): [string, any[]] {
let updatedCols = []
let values = [obj[idKey]]
let index = 2 // $1 is for the ID
for (const key of keys) {
if (Object.prototype.hasOwnProperty.call(obj, key)) {
updatedCols.push(`"${key}" = $${index}`)
values.push(obj[key])
++index
}
}
return [updatedCols.join(", "), values]
}

/**
* ## Setup
*
* The SQL schema for the tables used by this adapter is as follows. Learn more about the models at our doc page on [Database Models](https://authjs.dev/getting-started/adapters#models).
*
* ```sql
* BEGIN TRANSACTION;
*
* CREATE TABLE users (
* id SERIAL,
* name VARCHAR(255),
* email VARCHAR(255),
* "emailVerified" TIMESTAMPTZ,
* image TEXT,
*
* PRIMARY KEY (id)
* );
*
* CREATE TABLE verification_token (
* identifier TEXT NOT NULL,
* expires TIMESTAMPTZ NOT NULL,
Expand All @@ -59,27 +92,20 @@ export function mapExpiresAt(account: any): any {
* session_state TEXT,
* token_type TEXT,
*
* PRIMARY KEY (id)
* PRIMARY KEY (id),
* FOREIGN KEY ("userId") REFERENCES users(id) ON DELETE CASCADE
* );
*
* CREATE TABLE sessions (
* id SERIAL,
* "sessionToken" VARCHAR(255) NOT NULL,
* "userId" INTEGER NOT NULL,
* expires TIMESTAMPTZ NOT NULL,
* "sessionToken" VARCHAR(255) NOT NULL,
*
* PRIMARY KEY (id)
* PRIMARY KEY ("sessionToken"),
* FOREIGN KEY ("userId") REFERENCES users(id) ON DELETE CASCADE
* );
*
* CREATE TABLE users (
* id SERIAL,
* name VARCHAR(255),
* email VARCHAR(255),
* "emailVerified" TIMESTAMPTZ,
* image TEXT,
*
* PRIMARY KEY (id)
* );
* COMMIT;
* ```
*
* ```ts title="auth.ts"
Expand Down Expand Up @@ -108,7 +134,7 @@ export function mapExpiresAt(account: any): any {
* ```
*
*/
export default function PostgresAdapter(client: Pool): Adapter {
export default function PostgresAdapter(pool: Pool): Adapter {
return {
async createVerificationToken(
verificationToken: VerificationToken
Expand All @@ -118,7 +144,7 @@ export default function PostgresAdapter(client: Pool): Adapter {
INSERT INTO verification_token ( identifier, expires, token )
VALUES ($1, $2, $3)
`
await client.query(sql, [identifier, expires, token])
await pool.query(sql, [identifier, expires, token])
return verificationToken
},
async useVerificationToken({
Expand All @@ -130,8 +156,8 @@ export default function PostgresAdapter(client: Pool): Adapter {
}): Promise<VerificationToken> {
const sql = `delete from verification_token
where identifier = $1 and token = $2
RETURNING identifier, expires, token `
const result = await client.query(sql, [identifier, token])
RETURNING identifier, expires, token`
const result = await pool.query(sql, [identifier, token])
return result.rowCount !== 0 ? result.rows[0] : null
},

Expand All @@ -141,26 +167,21 @@ export default function PostgresAdapter(client: Pool): Adapter {
INSERT INTO users (name, email, "emailVerified", image)
VALUES ($1, $2, $3, $4)
RETURNING id, name, email, "emailVerified", image`
const result = await client.query(sql, [
name,
email,
emailVerified,
image,
])
const result = await pool.query(sql, [name, email, emailVerified, image])
return result.rows[0]
},
async getUser(id) {
const sql = `select * from users where id = $1`
try {
const result = await client.query(sql, [id])
const result = await pool.query(sql, [id])
return result.rowCount === 0 ? null : result.rows[0]
} catch (e) {
return null
}
},
async getUserByEmail(email) {
const sql = `select * from users where email = $1`
const result = await client.query(sql, [email])
const result = await pool.query(sql, [email])
return result.rowCount !== 0 ? result.rows[0] : null
},
async getUserByAccount({
Expand All @@ -174,34 +195,29 @@ export default function PostgresAdapter(client: Pool): Adapter {
and
a."providerAccountId" = $2`

const result = await client.query(sql, [provider, providerAccountId])
const result = await pool.query(sql, [provider, providerAccountId])
return result.rowCount !== 0 ? result.rows[0] : null
},
async updateUser(user: Partial<AdapterUser>): Promise<AdapterUser> {
const fetchSql = `select * from users where id = $1`
const query1 = await client.query(fetchSql, [user.id])
const oldUser = query1.rows[0]

const newUser = {
...oldUser,
...user,
}

const { id, name, email, emailVerified, image } = newUser
async updateUser(
user: Partial<AdapterUser> & Pick<AdapterUser, "id">
): Promise<AdapterUser> {
const [updatedCols, values] = createParameterizedUpdate(user, "id", [
"name",
"email",
"emailVerified",
"image",
])
const updateSql = `
UPDATE users set
name = $2, email = $3, "emailVerified" = $4, image = $5
${updatedCols}
where id = $1
RETURNING name, id, email, "emailVerified", image
`
const query2 = await client.query(updateSql, [
id,
name,
email,
emailVerified,
image,
])
return query2.rows[0]
const query = await pool.query(updateSql, values)
if (query.rows.length === 0) {
throw Error(`userId {user.id} does not exist`)
}
return query.rows[0]
},
async linkAccount(account) {
const sql = `
Expand Down Expand Up @@ -249,18 +265,18 @@ export default function PostgresAdapter(client: Pool): Adapter {
account.token_type,
]

const result = await client.query(sql, params)
const result = await pool.query(sql, params)
return mapExpiresAt(result.rows[0])
},
async createSession({ sessionToken, userId, expires }) {
if (userId === undefined) {
throw Error(`userId is undef in createSession`)
}
const sql = `insert into sessions ("userId", expires, "sessionToken")
const sql = `insert into sessions ("sessionToken", "userId", expires)
values ($1, $2, $3)
RETURNING id, "sessionToken", "userId", expires`
RETURNING "sessionToken", "userId", expires`

const result = await client.query(sql, [userId, expires, sessionToken])
const result = await pool.query(sql, [sessionToken, userId, expires])
return result.rows[0]
},

Expand All @@ -271,7 +287,7 @@ export default function PostgresAdapter(client: Pool): Adapter {
if (sessionToken === undefined) {
return null
}
const result1 = await client.query(
const result1 = await pool.query(
`select * from sessions where "sessionToken" = $1`,
[sessionToken]
)
Expand All @@ -280,7 +296,7 @@ export default function PostgresAdapter(client: Pool): Adapter {
}
let session: AdapterSession = result1.rows[0]

const result2 = await client.query("select * from users where id = $1", [
const result2 = await pool.query("select * from users where id = $1", [
session.userId,
])
if (result2.rowCount === 0) {
Expand All @@ -295,44 +311,34 @@ export default function PostgresAdapter(client: Pool): Adapter {
async updateSession(
session: Partial<AdapterSession> & Pick<AdapterSession, "sessionToken">
): Promise<AdapterSession | null | undefined> {
const { sessionToken } = session
const result1 = await client.query(
`select * from sessions where "sessionToken" = $1`,
[sessionToken]
const [updatedCols, values] = createParameterizedUpdate(
session,
"sessionToken",
["expires"]
)
if (result1.rowCount === 0) {
return null
}
const originalSession: AdapterSession = result1.rows[0]

const newSession: AdapterSession = {
...originalSession,
...session,
}
const sql = `
UPDATE sessions set
expires = $2
${updatedCols}
where "sessionToken" = $1
RETURNING "sessionToken", "userId", expires
`
const result = await client.query(sql, [
newSession.sessionToken,
newSession.expires,
])
const result = await pool.query(sql, values)
if (result.rows.length === 0) {
return null
}
return result.rows[0]
},
async deleteSession(sessionToken) {
const sql = `delete from sessions where "sessionToken" = $1`
await client.query(sql, [sessionToken])
await pool.query(sql, [sessionToken])
},
async unlinkAccount(partialAccount) {
const { provider, providerAccountId } = partialAccount
const sql = `delete from accounts where "providerAccountId" = $1 and provider = $2`
await client.query(sql, [providerAccountId, provider])
await pool.query(sql, [providerAccountId, provider])
},
async deleteUser(userId: string) {
await client.query(`delete from users where id = $1`, [userId])
await client.query(`delete from sessions where "userId" = $1`, [userId])
await client.query(`delete from accounts where "userId" = $1`, [userId])
await pool.query(`delete from users where id = $1`, [userId])
},
}
}

0 comments on commit 2b7d139

Please sign in to comment.