Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

add type constraints for various sql fragment functions #8688

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/lesswrong/lib/make_voteable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ export const apolloCacheVoteablePossibleTypes = () => {
}
}

const currentUserVoteResolver = (
resolver: SqlResolverJoin["resolver"],
): SqlResolver => ({field, currentUserField, join}) => join({
const currentUserVoteResolver = <N extends CollectionNameString>(
resolver: SqlResolverJoin<'Votes'>["resolver"],
): SqlResolver<N> => ({field, currentUserField, join}) => join({
table: "Votes",
type: "left",
on: {
Expand Down
17 changes: 9 additions & 8 deletions packages/lesswrong/lib/sql/ProjectionContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
}

absoluteField(name: string, jsonifyStarSelector = true) {
if (name !== "*") {
name = `"${name}"`;
let absoluteName: string = name;
if (absoluteName !== "*") {
absoluteName = `"${name}"`;
}
const absoluteField = `"${this.primaryPrefix}".${name}`;
const absoluteField = `"${this.primaryPrefix}".${absoluteName}`;
return name.indexOf("*") > -1 && jsonifyStarSelector
? `ROW_TO_JSON(${absoluteField})`
: absoluteField;
Expand Down Expand Up @@ -217,9 +218,9 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
}
}

addJoin({resolver, ...joinBase}: SqlResolverJoin) {
addJoin<J extends CollectionNameString>({resolver, ...joinBase}: SqlResolverJoin<J>) {
const spec = this.getJoinSpec(joinBase);
const subField = (name: string) => `"${spec.prefix}"."${name}"`;
const subField = (name: FieldName<J>) => `"${spec.prefix}"."${name}"`;
return resolver(subField);
}

Expand All @@ -245,7 +246,7 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
* Get the arguments to pass to `sqlResolver` functions defined in
* collection schemas.
*/
getSqlResolverArgs(): SqlResolverArgs {
getSqlResolverArgs(): SqlResolverArgs<N> {
return {
field: this.absoluteField.bind(this),
currentUserField: this.currentUserField.bind(this),
Expand All @@ -261,7 +262,7 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
: randomId(5, this.randIntCallback);
}

private getJoinSpec({table, type, on}: SqlJoinBase): SqlJoinSpec {
private getJoinSpec<J extends CollectionNameString>({table, type, on}: SqlJoinBase<J>): SqlJoinSpec {
for (const join of this.joins) {
if (
join.table === table &&
Expand All @@ -279,7 +280,7 @@ class ProjectionContext<N extends CollectionNameString = CollectionNameString> {
private compileJoin({table, type = "inner", on, prefix}: SqlJoinSpec): string {
const selectors: string[] = [];
for (const field in on) {
selectors.push(`"${prefix}"."${field}" = ${on[field]}`);
selectors.push(`"${prefix}"."${field}" = ${on[(field as keyof typeof on)]}`);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's pretty sad that TS can't infer this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah :(

}
const selector = selectors.join(" AND ");
return `${type.toUpperCase()} JOIN "${table}" "${prefix}" ON ${selector}`;
Expand Down
2 changes: 1 addition & 1 deletion packages/lesswrong/lib/sql/tests/testHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ export const TestCollection4 = {
graphqlArguments: "testCollection2Id: String",
resolver: async () => null,
sqlResolver: ({resolverArg, join}) => join({
table: "TestCollection2",
table: "TestCollection2" as CollectionNameString,
type: "left",
on: {
_id: resolverArg("testCollection2Id"),
Expand Down
2 changes: 2 additions & 0 deletions packages/lesswrong/lib/types/collectionTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ interface HasIdType {
_id: string
}

type HasIdCollectionNames = Exclude<Extract<ObjectsByCollectionName[CollectionNameString], HasIdType>['__collectionName'], undefined>;

// Common base type for everything with a userId field
interface HasUserIdType {
userId: string | null
Expand Down
28 changes: 15 additions & 13 deletions packages/lesswrong/lib/types/schemaTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,33 @@ interface CollectionFieldPermissions {

type FormInputType = 'text' | 'number' | 'url' | 'email' | 'textarea' | 'checkbox' | 'checkboxgroup' | 'radiogroup' | 'select' | 'datetime' | 'date' | keyof ComponentTypes;

type SqlFieldFunction = (fieldName: string) => string;
type FieldName<N extends CollectionNameString> = (keyof ObjectsByCollectionName[N] & string) | '*';

type SqlJoinBase = {
table: string,
type SqlFieldFunction<N extends CollectionNameString> = (fieldName: FieldName<N>) => string;

type SqlJoinBase<N extends CollectionNameString> = {
table: N,
type?: "inner" | "full" | "left" | "right",
on: Record<string, string>,
on: Partial<Record<FieldName<N>, string>>,
}

type SqlResolverJoin = SqlJoinBase & {
resolver: (field: SqlFieldFunction) => string,
type SqlResolverJoin<N extends CollectionNameString> = SqlJoinBase<N> & {
resolver: (field: SqlFieldFunction<N>) => string,
};

type SqlJoinSpec = SqlJoinBase & {
type SqlJoinSpec<N extends CollectionNameString = CollectionNameString> = SqlJoinBase<N> & {
prefix: string,
};

type SqlResolverArgs = {
field: SqlFieldFunction,
currentUserField: SqlFieldFunction,
join: (args: SqlResolverJoin) => string,
type SqlResolverArgs<N extends CollectionNameString> = {
field: SqlFieldFunction<N>,
currentUserField: SqlFieldFunction<'Users'>,
join: <J extends CollectionNameString>(args: SqlResolverJoin<J>) => string,
arg: (value: unknown) => string,
resolverArg: (name: string) => string,
}

type SqlResolver = (args: SqlResolverArgs) => string;
type SqlResolver<N extends CollectionNameString> = (args: SqlResolverArgs<N>) => string;

interface CollectionFieldSpecification<N extends CollectionNameString> extends CollectionFieldPermissions {
type?: any,
Expand All @@ -65,7 +67,7 @@ interface CollectionFieldSpecification<N extends CollectionNameString> extends C
addOriginalField?: boolean,
arguments?: string|null,
resolver: (root: ObjectsByCollectionName[N], args: any, context: ResolverContext, info?: any)=>any,
sqlResolver?: SqlResolver,
sqlResolver?: SqlResolver<N>,
},
blackbox?: boolean,
denormalized?: boolean,
Expand Down
6 changes: 3 additions & 3 deletions packages/lesswrong/lib/utils/schemaUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ export const foreignKeyField = <CollectionName extends CollectionNameString>({
nullable,
}),
...(autoJoin ? {
sqlResolver: ({field, join}: SqlResolverArgs) => join({
sqlResolver: ({field, join}: SqlResolverArgs<CollectionName>) => join<HasIdCollectionNames>({
table: collectionName,
type: nullable ? "left" : "inner",
on: {
_id: field(idFieldName),
_id: field(idFieldName as FieldName<CollectionName>),
},
resolver: (foreignField) => foreignField("*"),
})
Expand Down Expand Up @@ -207,7 +207,7 @@ export const simplSchemaToGraphQLtype = (type: any): string|null => {

interface ResolverOnlyFieldArgs<N extends CollectionNameString> extends CollectionFieldSpecification<N> {
resolver: (doc: ObjectsByCollectionName[N], args: any, context: ResolverContext) => any,
sqlResolver?: SqlResolver,
sqlResolver?: SqlResolver<N>,
graphQLtype?: string|GraphQLScalarType|null,
graphqlArguments?: string|null,
}
Expand Down
10 changes: 5 additions & 5 deletions packages/lesswrong/unitTests/sql/ProjectionContext.tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ describe("ProjectionContext", () => {
});
it("detects duplicate joins", () => {
const context = new ProjectionContext(TestCollection);
const join1: SqlResolverJoin = {
table: "TestTable2",
const join1: SqlResolverJoin<CollectionNameString> = {
table: "TestTable2" as CollectionNameString,
type: "left",
on: {
_id: "_id",
},
resolver: () => "",
};
const join2: SqlResolverJoin = {
table: "TestTable2",
const join2: SqlResolverJoin<CollectionNameString> = {
table: "TestTable2" as CollectionNameString,
type: "left",
on: {
userId: "userId",
},
} as AnyBecauseHard,
resolver: () => "",
};
expect(context.getJoins()).toHaveLength(0);
Expand Down