Relationship fails when attribute name != column name #21

Open
shish opened this issue Jul 3, 2023 · 3 comments


shish commented Jul 3, 2023

My model has a `Survey` class whose `owner_id` attribute maps to a differently named column (`user_id`) for historical reasons:

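```python
# models.py
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user"
    user_id: Mapped[int] = mapped_column("id", primary_key=True)
    username: Mapped[str]

class Survey(Base):
    __tablename__ = "survey"
    survey_id: Mapped[int] = mapped_column("id", primary_key=True)
    name: Mapped[str]
    # Python attribute "owner_id" is stored in the SQL column "user_id"
    owner_id: Mapped[int] = mapped_column("user_id", ForeignKey("user.id"))
    owner: Mapped[User] = relationship("User", backref="surveys", lazy=True)
```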
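```python
import typing

import strawberry
from sqlalchemy import select
from strawberry.types import Info
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper

import models

strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()

@strawberry_sqlalchemy_mapper.type(models.User)
class User:
    pass

@strawberry_sqlalchemy_mapper.type(models.Survey)
class Survey:
    pass

@strawberry.type
class Query:
    @strawberry.field
    def survey(self, info: Info, survey_id: int) -> typing.Optional[Survey]:
        db = info.context["db"]
        return db.execute(select(models.Survey).where(models.Survey.survey_id == survey_id)).scalars().first()
```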

In `relationship_resolver_for`, the code calls `getattr(self, sql_column_name)` where it should call `getattr(self, python_attr_name)`, so this query fails:

```graphql
query MyQuery {
  survey(surveyId: 1) {
    name
    owner {
      username
    }
  }
}
```

```
  File ".../strawberry_sqlalchemy_mapper/mapper.py", line 409, in <listcomp>
    getattr(self, local.key)
AttributeError: 'Survey' object has no attribute 'user_id'
```

TimDumol self-assigned this Jul 3, 2023

cpsnowden commented Sep 5, 2023

@TimDumol, we ran into this issue ourselves. The errors come from two places where the relationship's key columns are read from the row using the `sql_column_name` rather than the `python_attr_name`:

StrawberrySQLAlchemyLoader#loader_for

```python
def group_by_remote_key(row: Any) -> Tuple:
    return tuple(
        [
            getattr(row, remote.key)  # <- uses sql_column_name
            for _, remote in relationship.local_remote_pairs
        ]
    )
```

StrawberrySQLAlchemyMapper#relationship_resolver_for

```python
relationship_key = tuple(
    [
        getattr(self, local.key)  # <- uses sql_column_name
        for local, _ in relationship.local_remote_pairs
    ]
)
```

We have a temporary workaround: we override both methods and build a column-name-to-attribute-name map from the relevant relationship mapper. We're keen to have a central fix for this, though.

I'm happy to contribute a fix if we can agree an approach.

Example fix:

```python
# Imports mirror what strawberry_sqlalchemy_mapper already uses;
# _IS_GENERATED_RESOLVER_KEY is the constant defined in the library's mapper module.
from collections import defaultdict
from typing import Any, Awaitable, Callable, List, Mapping, Tuple, cast

from sqlalchemy import inspect, select, tuple_
from sqlalchemy.orm import Mapper, RelationshipProperty
from sqlalchemy.orm.state import InstanceState
from strawberry.dataloader import DataLoader
from strawberry.types import Info


def build_get_col(mapper: Mapper) -> Callable[[Any, str], Any]:
    # Mapper.c is keyed by Python attribute name, so invert it to map each
    # SQL column name back to its attribute, e.g. "user_id" -> "owner_id"
    col_to_attr = {
        mapper.c[attr_name].name: attr_name
        for attr_name in mapper.attrs.keys()
        if attr_name in mapper.c
    }

    def get_col(row: Any, col: str) -> Any:
        attr = col_to_attr[col]
        return getattr(row, attr)

    return get_col


## StrawberrySQLAlchemyLoader
def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
    """
    Retrieve or create a DataLoader for the given relationship
    """
    try:
        return self._loaders[relationship]
    except KeyError:
        related_model = relationship.entity.entity
        # get_col created here, from the related (remote-side) mapper
        get_col = build_get_col(relationship.mapper)

        async def load_fn(keys: List[Tuple]) -> List[Any]:
            query = select(related_model).filter(
                tuple_(
                    *[remote for _, remote in relationship.local_remote_pairs]
                ).in_(keys)
            )
            if relationship.order_by:
                query = query.order_by(*relationship.order_by)
            rows = self.bind.scalars(query).all()

            def group_by_remote_key(row: Any) -> Tuple:
                return tuple(
                    [
                        get_col(row, remote.key)  # attribute name, not column name
                        for _, remote in relationship.local_remote_pairs
                    ]
                )

            grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
            for row in rows:
                grouped_keys[group_by_remote_key(row)].append(row)
            if relationship.uselist:
                return [grouped_keys[key] for key in keys]
            else:
                return [
                    grouped_keys[key][0] if grouped_keys[key] else None
                    for key in keys
                ]

        self._loaders[relationship] = DataLoader(load_fn=load_fn)
        return self._loaders[relationship]


## StrawberrySQLAlchemyMapper
def relationship_resolver_for(
    self, relationship: RelationshipProperty
) -> Callable[..., Awaitable[Any]]:
    """
    Return an async field resolver for the given relationship,
    so as to avoid n+1 query problem.
    """
    # get_col created here, from the parent (local-side) mapper
    get_col = build_get_col(relationship.parent)

    async def resolve(self, info: Info):
        instance_state = cast(InstanceState, inspect(self))
        if relationship.key not in instance_state.unloaded:
            related_objects = getattr(self, relationship.key)
        else:
            relationship_key = tuple(
                [
                    get_col(self, local.key)  # attribute name, not column name
                    for local, _ in relationship.local_remote_pairs
                ]
            )
            if any(item is None for item in relationship_key):
                if relationship.uselist:
                    return []
                else:
                    return None
            if isinstance(info.context, dict):
                loader = info.context["sqlalchemy_loader"]
            else:
                loader = info.context.sqlalchemy_loader
            related_objects = await loader.loader_for(relationship).load(
                relationship_key
            )
        return related_objects

    setattr(resolve, _IS_GENERATED_RESOLVER_KEY, True)

    return resolve
```

TimDumol removed their assignment Sep 5, 2023

TimDumol (collaborator) commented Sep 5, 2023

Hi @cpsnowden, sorry, I totally forgot I'd assigned myself to this. Your proposed fix looks good to me. Feel free to open a PR!

cpsnowden commented

Thanks @TimDumol, I see that @gravy-jones-locker is addressing this in #25.
