Skip to content

Commit

Permalink
Merge pull request #8688 from ForumMagnum/sql-fragments-typed-branch
Browse files Browse the repository at this point in the history
add type constraints for various sql fragment functions
  • Loading branch information
oetherington committed Jan 30, 2024
2 parents ed9d13f + 88cd004 commit ef61491
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 33 deletions.
6 changes: 3 additions & 3 deletions packages/lesswrong/lib/make_voteable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ export const apolloCacheVoteablePossibleTypes = () => {
}
}

const currentUserVoteResolver = (
resolver: SqlResolverJoin["resolver"],
): SqlResolver => ({field, currentUserField, join}) => join({
const currentUserVoteResolver = <N extends CollectionNameString>(
resolver: SqlResolverJoin<'Votes'>["resolver"],
): SqlResolver<N> => ({field, currentUserField, join}) => join({
table: "Votes",
type: "left",
on: {
Expand Down
17 changes: 9 additions & 8 deletions packages/lesswrong/lib/sql/ProjectionContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
}

absoluteField(name: string, jsonifyStarSelector = true) {
if (name !== "*") {
name = `"${name}"`;
let absoluteName: string = name;
if (absoluteName !== "*") {
absoluteName = `"${name}"`;
}
const absoluteField = `"${this.primaryPrefix}".${name}`;
const absoluteField = `"${this.primaryPrefix}".${absoluteName}`;
return name.indexOf("*") > -1 && jsonifyStarSelector
? `ROW_TO_JSON(${absoluteField})`
: absoluteField;
Expand Down Expand Up @@ -217,9 +218,9 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
}
}

addJoin({resolver, ...joinBase}: SqlResolverJoin) {
addJoin<J extends CollectionNameString>({resolver, ...joinBase}: SqlResolverJoin<J>) {
const spec = this.getJoinSpec(joinBase);
const subField = (name: string) => `"${spec.prefix}"."${name}"`;
const subField = (name: FieldName<J>) => `"${spec.prefix}"."${name}"`;
return resolver(subField);
}

Expand All @@ -245,7 +246,7 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
* Get the arguments to pass to `sqlResolver` functions defined in
* collection schemas.
*/
getSqlResolverArgs(): SqlResolverArgs {
getSqlResolverArgs(): SqlResolverArgs<N> {
return {
field: this.absoluteField.bind(this),
currentUserField: this.currentUserField.bind(this),
Expand All @@ -261,7 +262,7 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
: randomId(5, this.randIntCallback);
}

private getJoinSpec({table, type, on}: SqlJoinBase): SqlJoinSpec {
private getJoinSpec<J extends CollectionNameString>({table, type, on}: SqlJoinBase<J>): SqlJoinSpec {
for (const join of this.joins) {
if (
join.table === table &&
Expand All @@ -279,7 +280,7 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
private compileJoin({table, type = "inner", on, prefix}: SqlJoinSpec): string {
const selectors: string[] = [];
for (const field in on) {
selectors.push(`"${prefix}"."${field}" = ${on[field]}`);
selectors.push(`"${prefix}"."${field}" = ${on[(field as keyof typeof on)]}`);
}
const selector = selectors.join(" AND ");
return `${type.toUpperCase()} JOIN "${table}" "${prefix}" ON ${selector}`;
Expand Down
2 changes: 1 addition & 1 deletion packages/lesswrong/lib/sql/tests/testHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ export const TestCollection4 = {
graphqlArguments: "testCollection2Id: String",
resolver: async () => null,
sqlResolver: ({resolverArg, join}) => join({
table: "TestCollection2",
table: "TestCollection2" as CollectionNameString,
type: "left",
on: {
_id: resolverArg("testCollection2Id"),
Expand Down
2 changes: 2 additions & 0 deletions packages/lesswrong/lib/types/collectionTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ interface HasIdType {
_id: string
}

type HasIdCollectionNames = Exclude<Extract<ObjectsByCollectionName[CollectionNameString], HasIdType>['__collectionName'], undefined>;

// Common base type for everything with a userId field
interface HasUserIdType {
userId: string | null
Expand Down
28 changes: 15 additions & 13 deletions packages/lesswrong/lib/types/schemaTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,33 @@ interface CollectionFieldPermissions {

type FormInputType = 'text' | 'number' | 'url' | 'email' | 'textarea' | 'checkbox' | 'checkboxgroup' | 'radiogroup' | 'select' | 'datetime' | 'date' | keyof ComponentTypes;

type SqlFieldFunction = (fieldName: string) => string;
type FieldName<N extends CollectionNameString> = (keyof ObjectsByCollectionName[N] & string) | '*';

type SqlJoinBase = {
table: string,
type SqlFieldFunction<N extends CollectionNameString> = (fieldName: FieldName<N>) => string;

type SqlJoinBase<N extends CollectionNameString> = {
table: N,
type?: "inner" | "full" | "left" | "right",
on: Record<string, string>,
on: Partial<Record<FieldName<N>, string>>,
}

type SqlResolverJoin = SqlJoinBase & {
resolver: (field: SqlFieldFunction) => string,
type SqlResolverJoin<N extends CollectionNameString> = SqlJoinBase<N> & {
resolver: (field: SqlFieldFunction<N>) => string,
};

type SqlJoinSpec = SqlJoinBase & {
type SqlJoinSpec<N extends CollectionNameString = CollectionNameString> = SqlJoinBase<N> & {
prefix: string,
};

type SqlResolverArgs = {
field: SqlFieldFunction,
currentUserField: SqlFieldFunction,
join: (args: SqlResolverJoin) => string,
type SqlResolverArgs<N extends CollectionNameString> = {
field: SqlFieldFunction<N>,
currentUserField: SqlFieldFunction<'Users'>,
join: <J extends CollectionNameString>(args: SqlResolverJoin<J>) => string,
arg: (value: unknown) => string,
resolverArg: (name: string) => string,
}

type SqlResolver = (args: SqlResolverArgs) => string;
type SqlResolver<N extends CollectionNameString> = (args: SqlResolverArgs<N>) => string;

interface CollectionFieldSpecification<N extends CollectionNameString> extends CollectionFieldPermissions {
type?: any,
Expand All @@ -65,7 +67,7 @@ interface CollectionFieldSpecification<N extends CollectionNameString> extends C
addOriginalField?: boolean,
arguments?: string|null,
resolver: (root: ObjectsByCollectionName[N], args: any, context: ResolverContext, info?: any)=>any,
sqlResolver?: SqlResolver,
sqlResolver?: SqlResolver<N>,
},
blackbox?: boolean,
denormalized?: boolean,
Expand Down
6 changes: 3 additions & 3 deletions packages/lesswrong/lib/utils/schemaUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ export const foreignKeyField = <CollectionName extends CollectionNameString>({
nullable,
}),
...(autoJoin ? {
sqlResolver: ({field, join}: SqlResolverArgs) => join({
sqlResolver: ({field, join}: SqlResolverArgs<CollectionName>) => join<HasIdCollectionNames>({
table: collectionName,
type: nullable ? "left" : "inner",
on: {
_id: field(idFieldName),
_id: field(idFieldName as FieldName<CollectionName>),
},
resolver: (foreignField) => foreignField("*"),
})
Expand Down Expand Up @@ -207,7 +207,7 @@ export const simplSchemaToGraphQLtype = (type: any): string|null => {

interface ResolverOnlyFieldArgs<N extends CollectionNameString> extends CollectionFieldSpecification<N> {
resolver: (doc: ObjectsByCollectionName[N], args: any, context: ResolverContext) => any,
sqlResolver?: SqlResolver,
sqlResolver?: SqlResolver<N>,
graphQLtype?: string|GraphQLScalarType|null,
graphqlArguments?: string|null,
}
Expand Down
10 changes: 5 additions & 5 deletions packages/lesswrong/unitTests/sql/ProjectionContext.tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ describe("ProjectionContext", () => {
});
it("detects duplicate joins", () => {
const context = new ProjectionContext(TestCollection);
const join1: SqlResolverJoin = {
table: "TestTable2",
const join1: SqlResolverJoin<CollectionNameString> = {
table: "TestTable2" as CollectionNameString,
type: "left",
on: {
_id: "_id",
},
resolver: () => "",
};
const join2: SqlResolverJoin = {
table: "TestTable2",
const join2: SqlResolverJoin<CollectionNameString> = {
table: "TestTable2" as CollectionNameString,
type: "left",
on: {
userId: "userId",
},
} as AnyBecauseHard,
resolver: () => "",
};
expect(context.getJoins()).toHaveLength(0);
Expand Down

0 comments on commit ef61491

Please sign in to comment.