Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(client): middleware args consistency #15897

Merged
merged 6 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 16 additions & 18 deletions packages/client/src/runtime/core/model/aggregates/aggregate.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import type { Client } from '../../../getPrismaClient'
import type { ModelAction } from '../applyModel'
import type { UserArgs } from '../UserArgs'
import { aggregateMap } from './utils/aggregateMap'
Expand All @@ -7,11 +6,11 @@ import { aggregateMap } from './utils/aggregateMap'
* Transforms the `userArgs` for the `.aggregate` shorthand. It is an API sugar
* for not having to do things like: `{select: {_avg: {select: {age: true}}}}`.
* The goal here is to desugar it into something that is understood by the QE.
* @param userArgs to transform
* @param args to transform
* @returns
*/
export function desugarUserArgs(userArgs: UserArgs) {
const _userArgs = desugarCountInUserArgs(userArgs)
export function desugarUserArgs(args?: UserArgs) {
const _userArgs = desugarCountInUserArgs(args ?? {})
const userArgsEntries = Object.entries(_userArgs)

return userArgsEntries.reduce(
Expand All @@ -31,26 +30,26 @@ export function desugarUserArgs(userArgs: UserArgs) {

/**
* Desugar `userArgs` when it contains `{_count: true}`.
* @param userArgs the user input
* @param args the user input
* @returns
*/
function desugarCountInUserArgs(userArgs: UserArgs) {
if (typeof userArgs['_count'] === 'boolean') {
return { ...userArgs, _count: { _all: userArgs['_count'] } }
function desugarCountInUserArgs(args: UserArgs) {
if (typeof args['_count'] === 'boolean') {
return { ...args, _count: { _all: args['_count'] } }
}

return userArgs
return args
}

/**
* Creates an unpacker that adds sugar to the basic result of the QE. An
* unpacker helps to transform a result before returning it to the user.
* @param userArgs the user input
* @param args the user input
* @returns
*/
export function createUnpacker(userArgs: UserArgs) {
export function createUnpacker(args?: UserArgs) {
return (data: object) => {
if (typeof userArgs['_count'] === 'boolean') {
if (typeof args?.['_count'] === 'boolean') {
data['_count'] = data['_count']['_all']
}

Expand All @@ -61,17 +60,16 @@ export function createUnpacker(userArgs: UserArgs) {
/**
* Executes the `.aggregate` action on a model.
* @see {desugarUserArgs}
* @param client to provide dmmf information
* @param userArgs the user input to desugar
* @param args the user input to desugar
* @param modelAction a callback action that triggers request execution
* @returns
*/
export function aggregate(client: Client, userArgs: UserArgs | undefined, modelAction: ModelAction) {
const aggregateArgs = desugarUserArgs(userArgs ?? {})
const aggregateUnpacker = createUnpacker(userArgs ?? {})
export function aggregate(args: UserArgs | undefined, modelAction: ModelAction) {
const aggregateUnpacker = createUnpacker(args)

return modelAction({
action: 'aggregate',
unpacker: aggregateUnpacker,
})(aggregateArgs)
argsMapper: desugarUserArgs,
})(args)
}
55 changes: 39 additions & 16 deletions packages/client/src/runtime/core/model/aggregates/count.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,50 @@
import type { Client } from '../../../getPrismaClient'
import type { ModelAction } from '../applyModel'
import type { UserArgs } from '../UserArgs'
import { aggregate } from './aggregate'
import { createUnpacker as createUnpackerAggregate, desugarUserArgs as desugarUserArgsAggregate } from './aggregate'

/**
* Executes the `.count` action on a model via {@link aggregate}.
* @param client to provide dmmf information
* @param userArgs the user input to desugar
* @param modelAction a callback action that triggers request execution
* Transforms the `userArgs` for the `.count` shorthand. It is an API sugar. It
* reuses the logic from the `.aggregate` shorthand to add additional handling.
* The goal here is to desugar it into something that is understood by the QE.
* @param args to transform
* @returns
*/
function desugarUserArgs(args?: UserArgs) {
const { select, ..._userArgs } = args ?? {} // exclude select

if (typeof select === 'object') {
return desugarUserArgsAggregate({ ..._userArgs, _count: select })
} else {
return desugarUserArgsAggregate({ ..._userArgs, _count: { _all: true } })
}
}

/**
* Creates an unpacker that adds sugar to the basic result of the QE. An
* unpacker helps to transform a result before returning it to the user.
* @param args the user input
* @returns
*/
export function count(client: Client, userArgs: UserArgs | undefined, modelAction: ModelAction) {
const { select, ..._userArgs } = userArgs ?? {} // exclude select
export function createUnpacker(args?: UserArgs) {
const { select } = args ?? {}

// count is an aggregate, we reuse that but hijack its unpacker
if (typeof select === 'object') {
// we transpose the original select field into the _count field
return aggregate(client, { ..._userArgs, _count: select }, (p) =>
modelAction({ ...p, action: 'count', unpacker: (data) => p.unpacker?.(data)['_count'] }),
) // for count selects, return the relevant part of the result
return (data: object) => createUnpackerAggregate(args)(data)['_count']
} else {
return aggregate(client, { ..._userArgs, _count: { _all: true } }, (p) =>
modelAction({ ...p, action: 'count', unpacker: (data) => p.unpacker?.(data)['_count']['_all'] }),
) // for simple counts, just return the result that is a number
return (data: object) => createUnpackerAggregate(args)(data)['_count']['_all']
}
}

/**
* Executes the `.count` action on a model via {@link aggregate}.
* @param args the user input to desugar
* @param modelAction a callback action that triggers request execution
* @returns
*/
export function count(args: UserArgs | undefined, modelAction: ModelAction) {
return modelAction({
action: 'count',
unpacker: createUnpacker(args),
argsMapper: desugarUserArgs,
})(args)
}
30 changes: 13 additions & 17 deletions packages/client/src/runtime/core/model/aggregates/groupBy.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import type { Client } from '../../../getPrismaClient'
import type { ModelAction } from '../applyModel'
import type { UserArgs } from '../UserArgs'
import { desugarUserArgs as desugarUserArgsAggregate } from './aggregate'
Expand All @@ -8,15 +7,15 @@ import { desugarUserArgs as desugarUserArgsAggregate } from './aggregate'
* It reuses the logic from the `.aggregate` shorthand and adds additional
* handling for the `by` clause. The goal here is to desugar it into something
* that is understood by the QE.
* @param userArgs to transform
* @param args to transform
* @returns
*/
function desugarUserArgs(userArgs: UserArgs) {
const _userArgs = desugarUserArgsAggregate(userArgs)
function desugarUserArgs(args?: UserArgs) {
millsp marked this conversation as resolved.
Show resolved Hide resolved
const _userArgs = desugarUserArgsAggregate(args ?? {})
millsp marked this conversation as resolved.
Show resolved Hide resolved

// we desugar the array into { [key]: boolean }
if (Array.isArray(userArgs['by'])) {
for (const key of userArgs['by']) {
if (Array.isArray(_userArgs.by)) {
for (const key of _userArgs.by) {
if (typeof key === 'string') {
_userArgs['select'][key] = true
}
Expand All @@ -29,12 +28,12 @@ function desugarUserArgs(userArgs: UserArgs) {
/**
* Creates an unpacker that adds sugar to the basic result of the QE. An
* unpacker helps to transform a result before returning it to the user.
* @param userArgs the user input
* @param args the user input
* @returns
*/
export function createUnpacker(userArgs: UserArgs) {
export function createUnpacker(args?: UserArgs) {
return (data: object[]) => {
if (typeof userArgs['_count'] === 'boolean') {
if (typeof args?.['_count'] === 'boolean') {
data.forEach((row) => {
row['_count'] = row['_count']['_all']
})
Expand All @@ -46,17 +45,14 @@ export function createUnpacker(userArgs: UserArgs) {

/**
* Executes the `.groupBy` action on a model by reusing {@link aggregate}.
* @param client to provide dmmf information
* @param userArgs the user input to desugar
* @param args the user input to desugar
* @param modelAction a callback action that triggers request execution
* @returns
*/
export function groupBy(client: Client, userArgs: UserArgs | undefined, modelAction: ModelAction) {
const groupByArgs = desugarUserArgs(userArgs ?? {})
const groupByUnpacker = createUnpacker(userArgs ?? {})

export function groupBy(args: UserArgs | undefined, modelAction: ModelAction) {
return modelAction({
action: 'groupBy',
unpacker: groupByUnpacker,
})(groupByArgs)
unpacker: createUnpacker(args),
argsMapper: desugarUserArgs,
})(args)
}
7 changes: 3 additions & 4 deletions packages/client/src/runtime/core/model/applyAggregates.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@ import type { UserArgs } from './UserArgs'
* short, we manipulate the user input which is designed to have DX to transform
* it into something that the engines understand. Similarly, we take the engine
* output for that input and produce something that is easier to work with.
* @param client to provide dmmf information
* @param action that tells which aggregate action to execute
* @param modelAction a callback action that triggers request execution
* @returns
*/
export function applyAggregates(client: Client, action: Action, modelAction: ModelAction) {
// we effectively take over the aggregate api to perform data changes
if (action === 'aggregate') return (userArgs?: UserArgs) => aggregate(client, userArgs, modelAction)
if (action === 'count') return (userArgs?: UserArgs) => count(client, userArgs, modelAction)
if (action === 'groupBy') return (userArgs?: UserArgs) => groupBy(client, userArgs, modelAction)
if (action === 'aggregate') return (userArgs?: UserArgs) => aggregate(userArgs, modelAction)
if (action === 'count') return (userArgs?: UserArgs) => count(userArgs, modelAction)
if (action === 'groupBy') return (userArgs?: UserArgs) => groupBy(userArgs, modelAction)

return undefined
}
7 changes: 7 additions & 0 deletions packages/client/src/runtime/getPrismaClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import { PrismaClientValidationError } from '.'
import { $extends, Extension } from './core/extensions/$extends'
import { MetricsClient } from './core/metrics/MetricsClient'
import { applyModels } from './core/model/applyModels'
import { UserArgs } from './core/model/UserArgs'
import { createPrismaPromise } from './core/request/createPrismaPromise'
import type {
InteractiveTransactionOptions,
Expand Down Expand Up @@ -181,6 +182,8 @@ export type InternalRequestParams = {
unpacker?: Unpacker // TODO what is this
lock?: PromiseLike<void>
otelParentCtx?: Context
/** Used to "desugar" a user input into an "expanded" one */
argsMapper?: (args?: UserArgs) => UserArgs
} & Omit<QueryMiddlewareParams, 'runInTransaction'>

// only used by the .use() hooks
Expand Down Expand Up @@ -1106,6 +1109,7 @@ new PrismaClient({
action,
model,
headers,
argsMapper,
transaction,
lock,
unpacker,
Expand All @@ -1115,6 +1119,9 @@ new PrismaClient({
this._dmmf = await this._getDmmf({ clientMethod, callsite })
}

// execute argument transformation before execution
args = argsMapper ? argsMapper(args) : args

let rootField: string | undefined
const operation = actionOperationMap[action]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { defineMatrix } from '../../_utils/defineMatrix'

export default defineMatrix(() => [
[
{
provider: 'sqlite',
},
],
])
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { idForProvider } from '../../../_utils/idForProvider'
import testMatrix from '../_matrix'

export default testMatrix.setupSchema(({ provider }) => {
return /* Prisma */ `
generator client {
provider = "prisma-client-js"
}

datasource db {
provider = "${provider}"
url = env("DATABASE_URI_${provider}")
}

model Resource {
id ${idForProvider(provider)}
}
`
})