forked from kolypto/py-nplus1loader
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bulk_load.py
197 lines (158 loc) · 8.38 KB
/
bulk_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
""" Bulk load an attribute for multiple instances at once """
from typing import Tuple, Iterable, Callable, Optional
from funcy import chunks
from sqlalchemy import Column, tuple_
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm.util import identity_key
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.orm.state import InstanceState
from sqlalchemy.orm import Mapper, Session, Query, defaultload, joinedload
# Signature of a function that may alter the Query before it is executed.
# It is called as
#   alter_query(query: Query, mapper: Mapper, attr_name: str, is_relationship: bool) -> Query
# and the Query it returns is used in place of the original one.
QueryAlterator = Callable[[Query, Mapper, str, bool], Query]
# These functions implement bulk lazy loading.
# In other words: given a list of instances, they load one particular attribute's value on all of them.
def bulk_load_attribute_for_instance_states(session: Session, mapper: Mapper,
                                            states: Iterable[InstanceState], attr_name: str,
                                            alter_query: Optional[QueryAlterator],
                                            ):
    """ Given a list of partially-loaded instances, augment them with an attribute `attr_name` by loading it from the DB

    Instances are augmented in chunks, not all at once.

    Args:
        session: The Session to use for loading
        mapper: The Mapper all those instances are handled with
        states: The instance states to augment
        attr_name: The attribute to load
        alter_query: An optional function to alter the query that loads the attribute

    Raises:
        KeyError: `attr_name` is neither a column nor a relationship of `mapper`
    """
    # Are we dealing with a column, or with a relationship?
    if attr_name in mapper.columns:
        loader_func = _bulk_load_column_for_instance_states
    elif attr_name in mapper.relationships:
        loader_func = _bulk_load_relationship_for_instance_states
    else:
        # Neither a column nor a relationship. What is it?
        raise KeyError(attr_name)

    # We're going to make SQL queries, so we have to temporarily disable Session's autoflush.
    # If we don't, it may try to save any unsaved instances.
    with session.no_autoflush:
        # Iterate those instances in bite-size chunks
        # `500` is the number SqlAlchemy uses internally with SelectInLoader
        for states_chunk in chunks(500, states):
            # First, collect primary keys from those incomplete instances.
            # This is a real list, not a generator: the identities end up inside an IN(...) expression,
            # and a list can be safely handed over (and re-read) by the loader functions.
            # States without an identity (`None`) have no row to load from, so they are left out.
            identities = [state.identity
                          for state in states_chunk
                          if state.identity is not None]

            # Nothing to load for this chunk: do not issue a query with an empty IN(...)
            if not identities:
                continue

            # Now, augment those instances by loading the missing attribute `attr_name` from the database
            loader_func(session, mapper, identities, attr_name, alter_query)
def _bulk_load_column_for_instance_states(session: Session, mapper: Mapper,
                                          identities: Iterable[Tuple], attr_name: str,
                                          alter_query: Optional[QueryAlterator]):
    """ Load a column attribute for a list of instance states where the attribute is unloaded

    Args:
        session: The Session to use for loading; its identity map is used to find the instances
        mapper: The Mapper all those instances are handled with
        identities: Identities (primary key tuples) of the instances to augment
        attr_name: Name of the column attribute to load
        alter_query: An optional function to alter the loading query.
            It is called with `is_relationship=False`.
    """
    Model = mapper.class_
    attr: Column = mapper.columns[attr_name]

    # Using those identities (primary keys), load the missing attribute
    q = load_by_primary_keys(session, mapper, identities, attr)

    # Alter the query
    if alter_query:
        q = alter_query(q, mapper, attr_name, False)

    # Having the missing attribute's value loaded, assign it to every instance in the session
    for identity, attr_value in q:
        # Build the identity key the way SqlAlchemy likes it:
        # (Model, primary-key, None)
        key = identity_key(Model, identity)

        # We do not iterate the Session to find an instance that matches the primary key.
        # Instead, we take it directly using the `identity_map`.
        # The instance may no longer be there (e.g. it was expunged or garbage-collected in the meantime):
        # in this case there is nothing to augment, so the row is skipped instead of failing with a KeyError.
        instance = session.identity_map.get(key)
        if instance is None:
            continue

        # Set the value of the missing attribute.
        # This is how it immediately becomes loaded.
        # The value is set as the committed (loaded) state: no history events are recorded for it.
        set_committed_value(instance, attr_name, attr_value)
def _bulk_load_relationship_for_instance_states(session: Session, mapper: Mapper,
                                                identities: Iterable[Tuple], attr_name: str,
                                                alter_query: Optional[QueryAlterator]):
    """ Load a relationship attribute for a list of instance states where the attribute is unloaded

    Args:
        session: The Session to use for loading
        mapper: The Mapper all those instances are handled with
        identities: Identities (primary key tuples) of the instances to augment
        attr_name: Name of the relationship attribute to load
        alter_query: An optional function to alter the loading query.
            It is called with `is_relationship=True`.
    """
    Model = mapper.class_
    relationship = mapper.all_orm_descriptors[attr_name]

    # Prepare the primary key
    pk_columns = get_primary_key_columns(mapper)

    # load_only() wants attribute names, not column keys.
    # The two differ when an attribute is mapped to a column with another name,
    # so we ask the mapper which property every primary key column is mapped to.
    pk_attr_names = [mapper.get_property_by_column(col).key
                     for col in pk_columns]

    # Using those identities (primary keys), load the missing attribute from the DB and put it into instances
    #
    # Note that we won't do anything manually here. We just make a query, and seemingly throw it away.
    # But what happens here is that we have a model that's partially loaded:
    #   defaultload(Model).load_only(primary key fields)
    # This tells SqlAlchemy that the query contains `Model` instances.
    # Then we load a relationship using joinedload().
    #
    # Because all those instances are already in SqlAlchemy's Session which maintains an identity map,
    # when those additional relationships are loaded from the database... they will automatically augment
    # the instances that are already in the session.
    #
    # So here we rely on the fact that as soon as other, yet unloaded, fields become available,
    # SqlAlchemy adds them to existing instances (!)
    #
    # Magic.
    q = session.query(Model).options(
        defaultload(Model).load_only(*pk_attr_names),
        joinedload(relationship)
    ).filter(
        build_primary_key_condition(pk_columns, identities)
    )

    # Alter the query
    if alter_query:
        q = alter_query(q, mapper, attr_name, True)

    # Finally, execute the query.
    # The result is discarded on purpose: loading it is what augments the instances.
    q.all()
def load_by_primary_keys(session: Session, mapper: Mapper, identities: Iterable[Tuple], *entities) -> Query:
    """ Build a Query that selects many rows of a mapper by their primary keys

    Args:
        session: The Session to use for loading
        mapper: The mapper to take the primary key columns from
        identities: An iterable of identities (primary key tuples)
        entities: Additional entities to load with session.query(...)

    Returns:
        A Query (not executed yet).
        First field "pk": the identity tuple (the primary key)
        Other fields: the *entities you wanted loaded
    """
    pk_columns = get_primary_key_columns(mapper)

    # What we select.
    #
    # We need the primary key, as well as the requested entities, so it looks like we need
    #   pk_col1, pk_col2, ..., attr_value
    # But then in Python we would have to slice every row.
    # Because Postgres supports tuples, we select a tuple of the primary key instead:
    #   (pk_col1, pk_col2, ...), attr_value
    # Just two columns, one being a composite primary key.
    # It perfectly matches SqlAlchemy's instance identity, which is a tuple of primary keys.
    pk_tuple = tuple_(*pk_columns).label('pk')
    query = session.query(pk_tuple, *entities)

    # How we filter.
    #
    # We're going to load N instances by their primary keys.
    # We could have done it like this:
    #   WHERE (pk_col1=:val AND pk_col2=:val) OR (pk_col1=:val AND pk_col2=:val) OR ...
    # but once again, tuples are very convenient and readable:
    #   WHERE (pk_col1, pk_col2) IN ((:val, :val), (:val, :val), ...)
    #
    # Thanks for this neat trick, @vdmit11 :)
    pk_condition = build_primary_key_condition(pk_columns, identities)
    return query.filter(pk_condition)
def build_primary_key_condition(pk_columns: Tuple[Column], identities: Iterable[Tuple]) -> BinaryExpression:
    """ Build an IN(...) condition for a primary key to select many instances at once

    Args:
        pk_columns: The columns to filter with
        identities: An iterable of identities (primary key values).
            One-shot iterators (e.g. generators) are accepted: they are read into a list first.

    Returns:
        The condition to use with Query.filter()

    This condition builder uses tuples:
        tuple(primary-key-columns) = [ tuple(primary-key-values), ... ]

    The resulting query looks like this:
        WHERE (pk_col1, pk_col2) IN ((:val, :val), (:val, :val), ...)
    """
    from collections.abc import Iterator

    # The argument is declared as an Iterable, but in_() is only known to be safe with real sequences:
    # a generator may be rejected by some SqlAlchemy versions (NOTE(review): confirm for the pinned version).
    # Only one-shot iterators are materialized; lists, tuples, and anything else are passed through untouched.
    if isinstance(identities, Iterator):
        identities = list(identities)

    return tuple_(*pk_columns).in_(identities)
def get_primary_key_columns(mapper: Mapper) -> Tuple[Column]:
    """ Get the primary key columns of a Mapper

    Args:
        mapper: The Mapper to take the primary key from.
            If you only have a Model, get its mapper first
            (NOTE(review): the original docs suggest `get_mapper(model)`, which is not defined in this module)

    Returns:
        The value of `mapper.primary_key`, as is
    """
    primary_key = mapper.primary_key
    return primary_key