diff --git a/src/persistence/EntityPersistExecutor.ts b/src/persistence/EntityPersistExecutor.ts index f41d4dcaad..67a5616093 100644 --- a/src/persistence/EntityPersistExecutor.ts +++ b/src/persistence/EntityPersistExecutor.ts @@ -81,10 +81,28 @@ export class EntityPersistExecutor { if (entityTarget === Object) throw new CannotDetermineEntityError(this.mode) + let metadata = this.connection.getMetadata(entityTarget) + + // Check for single table inheritance and find the correct metadata in that case. + // Goal is to use the correct discriminator as we could have a repository + // for an (abstract) base class and thus the target would not match. + if ( + metadata.inheritancePattern === "STI" && + metadata.childEntityMetadatas.length > 0 + ) { + const matchingChildMetadata = + metadata.childEntityMetadatas.find( + (meta) => + entity.constructor === meta.target, + ) + if (matchingChildMetadata) { + metadata = matchingChildMetadata + } + } + subjects.push( new Subject({ - metadata: - this.connection.getMetadata(entityTarget), + metadata, entity: entity, canBeInserted: this.mode === "save", canBeUpdated: this.mode === "save", diff --git a/test/github-issues/2927/entity/Content.ts b/test/github-issues/2927/entity/Content.ts new file mode 100644 index 0000000000..dbc77c0e70 --- /dev/null +++ b/test/github-issues/2927/entity/Content.ts @@ -0,0 +1,28 @@ +import { + Column, + Entity, + PrimaryGeneratedColumn, + TableInheritance, +} from "../../../../src" + +export enum ContentType { + Photo = "photo", + Post = "post", + SpecialPhoto = "special_photo", +} + +@Entity() +@TableInheritance({ + pattern: "STI", + column: { type: "enum", name: "content_type", enum: ContentType }, +}) +export class Content { + @PrimaryGeneratedColumn() + id: number + + @Column() + title: string + + @Column() + description: string +} diff --git a/test/github-issues/2927/entity/Photo.ts b/test/github-issues/2927/entity/Photo.ts new file mode 100644 index 0000000000..579bdd1477 --- /dev/null +++ b/test/github-issues/2927/entity/Photo.ts @@ -0,0 +1,8 @@ +import { ChildEntity, Column } from "../../../../src" +import { Content, ContentType } from "./Content" + +@ChildEntity(ContentType.Photo) +export class Photo extends Content { + @Column() + size: number +} diff --git a/test/github-issues/2927/entity/Post.ts b/test/github-issues/2927/entity/Post.ts new file mode 100644 index 0000000000..40dcb66a35 --- /dev/null +++ b/test/github-issues/2927/entity/Post.ts @@ -0,0 +1,8 @@ +import { ChildEntity, Column } from "../../../../src" +import { Content, ContentType } from "./Content" + +@ChildEntity(ContentType.Post) +export class Post extends Content { + @Column() + viewCount: number +} diff --git a/test/github-issues/2927/entity/SpecialPhoto.ts b/test/github-issues/2927/entity/SpecialPhoto.ts new file mode 100644 index 0000000000..40b88b4585 --- /dev/null +++ b/test/github-issues/2927/entity/SpecialPhoto.ts @@ -0,0 +1,9 @@ +import { ChildEntity, Column } from "../../../../src" +import { ContentType } from "./Content" +import { Photo } from "./Photo" + +@ChildEntity(ContentType.SpecialPhoto) +export class SpecialPhoto extends Photo { + @Column() + specialProperty: number +} diff --git a/test/github-issues/2927/issue-2927.ts b/test/github-issues/2927/issue-2927.ts new file mode 100644 index 0000000000..9eaf3f6358 --- /dev/null +++ b/test/github-issues/2927/issue-2927.ts @@ -0,0 +1,103 @@ +import "reflect-metadata" +import { + createTestingConnections, + closeTestingConnections, + reloadTestingDatabases, +} from "../../utils/test-utils" +import { DataSource } from "../../../src/data-source/index" +import { expect } from "chai" +import { Content } from "./entity/Content" +import { Photo } from "./entity/Photo" +import { SpecialPhoto } from "./entity/SpecialPhoto" +import { Post } from "./entity/Post" + +describe("github issues > #2927 When using base class' custom repository, the discriminator is ignored", () => { + let dataSources: DataSource[] + before( + async () => + (dataSources = await createTestingConnections({ + entities: [__dirname + "/entity/*{.js,.ts}"], + schemaCreate: true, + dropSchema: true, + })), + ) + beforeEach(() => reloadTestingDatabases(dataSources)) + after(() => closeTestingConnections(dataSources)) + + it("should use the correct subclass for inheritance when saving and retrieving concrete instance", () => + Promise.all( + dataSources.map(async (dataSource) => { + const entityManager = dataSource.createEntityManager() + const repository = entityManager.getRepository(Content) + + // Create and save a new Photo. + const photo = new Photo() + photo.title = "some title" + photo.description = "some description" + photo.size = 42 + await repository.save(photo) + + // Retrieve it back from the DB. + const contents = await repository.find() + expect(contents.length).to.equal(1) + expect(contents[0] instanceof Photo).to.equal(true) + const fetchedPhoto = contents[0] as Photo + expect(fetchedPhoto).to.eql(photo) + }), + )) + + it("should work for deeply nested classes", () => + Promise.all( + dataSources.map(async (dataSource) => { + const entityManager = dataSource.createEntityManager() + const repository = entityManager.getRepository(Content) + + // Create and save a new SpecialPhoto. + const specialPhoto = new SpecialPhoto() + specialPhoto.title = "some title" + specialPhoto.description = "some description" + specialPhoto.size = 42 + specialPhoto.specialProperty = 420 + await repository.save(specialPhoto) + + // Retrieve it back from the DB. + const contents = await repository.find() + expect(contents.length).to.equal(1) + expect(contents[0] instanceof SpecialPhoto).to.equal(true) + const fetchedSpecialPhoto = contents[0] as SpecialPhoto + expect(fetchedSpecialPhoto).to.eql(specialPhoto) + }), + )) + + it("should work for saving and fetching different subclasses", () => + Promise.all( + dataSources.map(async (dataSource) => { + const entityManager = dataSource.createEntityManager() + const repository = entityManager.getRepository(Content) + + // Create and save a new Post. + const post = new Post() + post.title = "some title" + post.description = "some description" + post.viewCount = 69 + + // Create and save a new SpecialPhoto. + const specialPhoto = new SpecialPhoto() + specialPhoto.title = "some title" + specialPhoto.description = "some description" + specialPhoto.size = 42 + specialPhoto.specialProperty = 420 + + await repository.save([post, specialPhoto]) + + // Retrieve them back from the DB. + const contents = await repository.find() + expect(contents.length).to.equal(2) + expect(contents.find((content) => content instanceof Post)).not + .to.be.undefined + expect( + contents.find((content) => content instanceof SpecialPhoto), + ).not.to.be.undefined + }), + )) +})