-
Notifications
You must be signed in to change notification settings - Fork 11
/
PostgresGroupDAO.scala
358 lines (314 loc) · 20.3 KB
/
PostgresGroupDAO.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
package org.broadinstitute.dsde.workbench.sam.dataAccess
import akka.http.scaladsl.model.StatusCodes
import org.broadinstitute.dsde.workbench.model._
import org.broadinstitute.dsde.workbench.sam.db.PSQLStateExtensions
import org.broadinstitute.dsde.workbench.sam.db.SamParameterBinderFactory._
import org.broadinstitute.dsde.workbench.sam.db.SamTypeBinders._
import org.broadinstitute.dsde.workbench.sam.db.tables._
import org.broadinstitute.dsde.workbench.sam.errorReportSource
import org.broadinstitute.dsde.workbench.sam.model.FullyQualifiedPolicyId
import org.postgresql.util.PSQLException
import scalikejdbc.DBSession
import scalikejdbc.interpolation.SQLSyntax
import java.time.Instant
import scala.util.{Failure, Try}
/** The sam group model is implemented using 2 tables GroupTable and GroupMemberTable. GroupTable stores the name, email address and other top level information
* of the group. GroupMemberTable stores member groups and users. Note that this is a recursive structure; groups can contain groups and groups may be members
* of more than one group. Querying this structure is expensive because it requires recursive queries. The FlatGroupMemberTable is used to shift that burden
* from read time to write time. FlatGroupMemberTable stores all the members of a group, direct and inherited. This makes membership queries straightforward,
* less database intensive and fast. FlatGroupMemberTable also stores the path through which a user/group is a member. This information is required to update
* the flat membership structure without recalculating large swaths when top level groups change.
*
* Each FlatGroupMemberTable record contains:
*   - group_id: the id of the group containing the member
*   - member_group_id / member_user_id: the id of the member group or user
*   - group_membership_path: an array of group ids indicating the path from group_id to the member; it starts with group_id and is exclusive of the
*     member id
*   - last_group_membership_element: the last group in the path above, tracked separately so it can be indexed
*
* Example database records: group(7795) contains user(userid), group(7798) contains user(userid), group(7801) contains group(7798) and group(7799)
*
* {{{
* testdb=# select * from sam_group_member;
*   id   | group_id | member_group_id | member_user_id
* -------+----------+-----------------+----------------
*  15636 |     7795 |                 | userid
*  15637 |     7798 |                 | userid
*  15638 |     7801 |            7798 |
*  15639 |     7801 |            7799 |
*
* testdb=# select * from sam_group_member_flat;
*    id   | group_id | member_group_id | member_user_id | group_membership_path | last_group_membership_element
* --------+----------+-----------------+----------------+-----------------------+------------------------------
*  345985 |     7795 |                 | userid         | {7795}                |                          7795
*  345986 |     7798 |                 | userid         | {7798}                |                          7798
*  345987 |     7801 |            7798 |                | {7801}                |                          7801
*  345988 |     7801 |            7799 |                | {7801}                |                          7801
*  345989 |     7801 |                 | userid         | {7801,7798}           |                          7798
* }}}
*
* It is crucial that all group updates are in serializable transactions to avoid race conditions when concurrent modifications are made affecting the same
* group structure.
*/
trait PostgresGroupDAO {
/** Adds `members` to the group identified by `groupId`.
 *
 * Member groups and policies are first resolved to group primary keys via `queryForGroupPKs`, user ids are picked out via `collectUserIds`, and both lists
 * are handed to `insertGroupMemberPKs`.
 *
 * @param groupId
 *   group being added to
 * @param members
 *   subjects to add
 * @return
 *   the count returned by `insertGroupMemberPKs`
 */
protected def insertGroupMembers(groupId: GroupPK, members: Set[WorkbenchSubject])(implicit session: DBSession): Int =
  insertGroupMemberPKs(groupId, queryForGroupPKs(members), collectUserIds(members))
/** Picks the user ids out of a mixed set of subjects; every other kind of subject is ignored. */
protected def collectUserIds(members: Set[WorkbenchSubject]): List[WorkbenchUserId] = {
  val userIds = members.collect { case id: WorkbenchUserId => id }
  userIds.toList
}
/** Verifies that adding `memberGroupPKs` as members of `groupId` would not create a membership cycle.
 *
 * A cycle would result if `groupId` is already recorded in the flat membership table as a member of any of the prospective member groups, which is what the
 * query below looks for. Nothing is queried when `memberGroupPKs` is empty.
 *
 * NOTE(review): adding a group to itself is only detected here when the group is already recorded as a member of itself - confirm that case is rejected
 * elsewhere.
 *
 * @param groupId
 *   group being added to
 * @param memberGroupPKs
 *   prospective new member groups
 * @throws WorkbenchExceptionWithErrorReport
 *   with a 400 Bad Request report naming the emails of the member groups that would cause a cycle
 */
def verifyNoCycles(groupId: GroupPK, memberGroupPKs: List[GroupPK])(implicit session: DBSession): Unit =
  if (memberGroupPKs.nonEmpty) {
    val gmf = GroupMemberFlatTable.syntax("gmf")
    val g = GroupTable.syntax("g")
    // The flat table holds one row per membership path, so a group that reaches groupId through several
    // paths comes back several times; de-duplicate so each offending group is named once in the error.
    val groupsCausingCycle =
      samsql"""select ${g.result.email}
from ${GroupMemberFlatTable as gmf}
join ${GroupTable as g} on ${gmf.groupId} = ${g.id}
where ${gmf.groupId} in ($memberGroupPKs)
and ${gmf.memberGroupId} = $groupId""".map(_.get[WorkbenchEmail](g.resultName.email)).list().apply().distinct
    if (groupsCausingCycle.nonEmpty) {
      throw new WorkbenchExceptionWithErrorReport(
        ErrorReport(StatusCodes.BadRequest, s"Could not add member group(s) ${groupsCausingCycle.mkString("[", ",", "]")} because it would cause a cycle")
      )
    }
  }
/** Adds already-resolved member groups and users to `groupId`, updating both the hierarchical and the flat membership tables.
 *
 * Returns 0 without touching the database when there is nothing to add. Otherwise the member groups are checked for cycles, the members are inserted into
 * the hierarchical table and, if that inserted anything, into the flat table as well.
 *
 * NOTE(review): when only some of the requested members are new, the flat insert still runs for all of them - confirm duplicate flat records cannot result.
 *
 * @return
 *   the number of rows inserted into the hierarchical table
 */
protected def insertGroupMemberPKs(groupId: GroupPK, memberGroupPKs: List[GroupPK], memberUserIds: List[WorkbenchUserId])(implicit session: DBSession): Int =
  (memberGroupPKs, memberUserIds) match {
    case (Nil, Nil) => 0
    case _ =>
      verifyNoCycles(groupId, memberGroupPKs)
      val rowsInserted = insertGroupMembersIntoHierarchical(groupId, memberGroupPKs, memberUserIds)
      // when nothing was inserted the flat structure must be left alone, otherwise it would get dup records
      if (rowsInserted > 0) insertGroupMembersIntoFlat(groupId, memberGroupPKs, memberUserIds)
      rowsInserted
  }
/** Inserts direct membership rows for the given users and groups into GroupMemberTable.
 *
 * One (group id, member user id, member group id) value row is built per member, with the unused member column bound to `None`. Rows that conflict with
 * existing ones are skipped (`on conflict do nothing`).
 *
 * @return
 *   the update count reported for the insert statement
 */
private def insertGroupMembersIntoHierarchical(groupId: GroupPK, memberGroupPKs: List[GroupPK], memberUserIds: List[WorkbenchUserId])(implicit
    session: DBSession
) = {
  val gm = GroupMemberTable.column
  val memberUserValues: List[SQLSyntax] = memberUserIds.map { case uid: WorkbenchUserId =>
    samsqls"(${groupId}, ${uid}, ${None})"
  }
  val memberGroupValues: List[SQLSyntax] =
    for (subgroupPK <- memberGroupPKs) yield samsqls"(${groupId}, ${None}, ${subgroupPK})"
  val insertStatement =
    samsql"""insert into ${GroupMemberTable.table} (${gm.groupId}, ${gm.memberUserId}, ${gm.memberGroupId})
values ${memberUserValues ++ memberGroupValues}
on conflict do nothing"""
  insertStatement.update().apply()
}
/** Inserts the flat membership records that result from adding the given users/groups directly to `groupId`.
 *
 * Inserting a user/group into a group requires 2 inserts into the flat group model:
 *   1. insert direct membership - group_id is the given group, member_group_id/member_user_id is the given member, path contains only the given group
 *   1. insert indirect memberships - insert a record for every record where member_group_id is the given group: group_id is the same,
 *      member_group_id/member_user_id is the member, path is the same with the given group appended
 *
 * Inserting a subgroup into a group requires a third insert to connect the subgroup's lower hierarchy:
 *   - for all records from GroupMemberFlatTable where group_id is the subgroup id (all paths from the subgroup to a member), call these the tail records
 *   - join GroupMemberFlatTable where member_group_id is the subgroup id and the last element of the path is the parent group (all paths to the subgroup by
 *     way of the parent), call these the head records
 *   - insert a new record joining these pairs of paths, head to tail: group_id is the head group_id (first element of the path), member id is the tail
 *     member id, path is the head path + tail path
 *
 * Example: Insert group T into group H. H starts empty but is already a member of groups A and B. T already has member groups X and Y which are empty. The
 * flat group model starts containing:
 * {{{
 * Group | Member Group | Path
 * ------|--------------|------
 * A     | H            | {A}
 * B     | H            | {B}
 * T     | X            | {T}
 * T     | Y            | {T}
 * }}}
 *
 * step 1 inserts direct membership of T in H
 * {{{
 * Group | Member Group | Path
 * ------|--------------|------
 * H     | T            | {H}
 * }}}
 *
 * step 2 inserts indirect memberships T in A and B
 * {{{
 * Group | Member Group | Path
 * ------|--------------|------
 * A     | T            | {A,H}
 * B     | T            | {B,H}
 * }}}
 *
 * step 3 inserts T's lower group hierarchy so that X and Y are members of H, A and B. The tail records are all of the records above where Group is T:
 * ((T, X, {T}), (T, Y, {T})). The head records are all of the records above where Member Group is T and the last path element is H:
 * ((H, T, {H}), (A, T, {A,H}), (B, T, {B,H})).
 * {{{
 * Group | Member Group | Path
 * ------|--------------|------
 * H     | X            | {H,T}
 * H     | Y            | {H,T}
 * A     | X            | {A,H,T}
 * A     | Y            | {A,H,T}
 * B     | X            | {B,H,T}
 * B     | Y            | {B,H,T}
 * }}}
 *
 * @param groupId
 *   group being added to
 * @param memberGroupPKs
 *   new member group ids
 * @param memberUserIds
 *   new member user ids
 * @param session
 *   database session the inserts run in
 */
private def insertGroupMembersIntoFlat(groupId: GroupPK, memberGroupPKs: List[GroupPK], memberUserIds: List[WorkbenchUserId])(implicit
session: DBSession
): Unit = {
val fgmColumn = GroupMemberFlatTable.column
val fgm = GroupMemberFlatTable.syntax("fgm")
// Each value row is (member_user_id, member_group_id) with exactly one side populated. The null side is cast
// explicitly, presumably so the VALUES list gets a definite column type - confirm.
val directUserValues = memberUserIds.map(uid => samsqls"($uid, cast(null as BIGINT))")
val directGroupValues = memberGroupPKs.map(gpk => samsqls"(cast(null as varchar), $gpk)")
// step 1: insert direct memberships - the path is just [groupId] and groupId is also the last path element
samsql"""insert into ${GroupMemberFlatTable.table} (${fgmColumn.groupId}, ${fgmColumn.memberUserId}, ${fgmColumn.memberGroupId}, ${fgmColumn.groupMembershipPath}, ${fgmColumn.lastGroupMembershipElement})
select ${groupId}, insertValues.member_user_id, insertValues.member_group_id, array[$groupId], $groupId
from (values ${directUserValues ++ directGroupValues}) AS insertValues (member_user_id, member_group_id)""".update().apply()
// step 2: insert memberships where groupId is a subgroup - every existing record that has groupId as its member
// group is crossed with the new members, extending that record's path with groupId
samsql"""insert into ${GroupMemberFlatTable.table} (${fgmColumn.groupId}, ${fgmColumn.memberUserId}, ${fgmColumn.memberGroupId}, ${fgmColumn.groupMembershipPath}, ${fgmColumn.lastGroupMembershipElement})
select ${fgm.groupId}, insertValues.member_user_id, insertValues.member_group_id, array_append(${fgm.groupMembershipPath}, $groupId), $groupId
from (values ${directUserValues ++ directGroupValues}) AS insertValues (member_user_id, member_group_id),
${GroupMemberFlatTable as fgm}
where ${fgm.memberGroupId} = $groupId""".update().apply()
if (memberGroupPKs.nonEmpty) {
// step 3: insert subgroup memberships - join head records (paths reaching a new subgroup by way of groupId) to
// tail records (paths from that subgroup down to its members); the new last path element is the last element
// of the tail path
val tail = GroupMemberFlatTable.syntax("tail")
val head = GroupMemberFlatTable.syntax("head")
samsql"""insert into ${GroupMemberFlatTable.table} (${fgmColumn.groupId}, ${fgmColumn.memberUserId}, ${fgmColumn.memberGroupId}, ${fgmColumn.groupMembershipPath}, ${fgmColumn.lastGroupMembershipElement})
select ${head.groupId}, ${tail.memberUserId}, ${tail.memberGroupId}, array_cat(${head.groupMembershipPath}, ${tail.groupMembershipPath}), ${tail.groupMembershipPath}[array_upper(${tail.groupMembershipPath}, 1)]
from ${GroupMemberFlatTable as tail}
join ${GroupMemberFlatTable as head} on ${head.memberGroupId} = ${tail.groupId}
where ${tail.groupId} in ($memberGroupPKs)
and ${head.lastGroupMembershipElement} = ${groupId}""".update().apply()
}
}
/** Removes every member of the group, first from the flat membership table and then from the hierarchical one.
 *
 * @return
 *   the number of rows deleted from the hierarchical table
 */
def removeAllGroupMembers(groupPK: GroupPK)(implicit session: DBSession): Int = {
  removeAllMembersFromFlatGroup(groupPK)
  val hierarchyRowsDeleted = removeAllMembersFromHierarchy(groupPK)
  hierarchyRowsDeleted
}
/** Deletes every GroupMemberTable row whose group id is `groupPK` and returns the update count. */
private def removeAllMembersFromHierarchy(groupPK: GroupPK)(implicit session: DBSession) = {
  val gm = GroupMemberTable.syntax("gm")
  val deleteMembers = samsql"delete from ${GroupMemberTable as gm} where ${gm.groupId} = ${groupPK}"
  deleteMembers.update().apply()
}
/** Deletes every flat membership row that depends on `groupPK` having members.
 *
 * Two deletes are issued: rows where `groupPK` is the containing group, then rows where `groupPK` appears anywhere in the membership path (its descendants as
 * seen from other groups). The result is the update count of the second delete.
 */
private def removeAllMembersFromFlatGroup(groupPK: GroupPK)(implicit session: DBSession) = {
  val f = GroupMemberFlatTable.syntax("f")
  // rows where groupPK itself has the member
  val deleteOwnMembers = samsql"delete from ${GroupMemberFlatTable as f} where ${f.groupId} = ${groupPK}"
  // rows where the membership path passes through groupPK
  val deleteViaPath = samsql"delete from ${GroupMemberFlatTable as f} where array_position(${f.groupMembershipPath}, ${groupPK}) is not null"
  deleteOwnMembers.update().apply()
  deleteViaPath.update().apply()
}
/** Removes a single direct member (user or group) from `groupId`, cleaning up the flat table before the hierarchical one.
 *
 * @return
 *   true if a row was deleted from the hierarchical table
 * @throws WorkbenchException
 *   if `removeMember` is neither a user id nor a group identity
 */
def removeGroupMember(groupId: WorkbenchGroupIdentity, removeMember: WorkbenchSubject)(implicit session: DBSession): Boolean =
  removeMember match {
    case user: WorkbenchUserId =>
      removeMemberUserFromFlatGroup(groupId, user)
      val removed = removeMemberUserFromHierarchy(groupId, user)
      removed
    case subgroup: WorkbenchGroupIdentity =>
      removeMemberGroupFromFlatGroup(groupId, subgroup)
      val removed = removeMemberGroupFromHierarchy(groupId, subgroup)
      removed
    case _ =>
      throw new WorkbenchException(s"unexpected WorkbenchSubject $removeMember")
  }
/** Deletes the GroupMemberTable row linking `memberGroup` directly to `groupId`; returns true if any row was deleted. */
private def removeMemberGroupFromHierarchy(groupId: WorkbenchGroupIdentity, memberGroup: WorkbenchGroupIdentity)(implicit session: DBSession) = {
  val groupMemberColumn = GroupMemberTable.column
  // both identities are resolved to primary keys by sub-selects embedded in the statement
  val parentPKQuery = workbenchGroupIdentityToGroupPK(groupId)
  val memberPKQuery = workbenchGroupIdentityToGroupPK(memberGroup)
  val rowsDeleted =
    samsql"""delete from ${GroupMemberTable.table}
where ${groupMemberColumn.groupId} = (${parentPKQuery})
and ${groupMemberColumn.memberGroupId} = (${memberPKQuery})""".update().apply()
  rowsDeleted > 0
}
/** Removes from the flat membership table every record that relies on `memberGroup` being a direct member of `groupId`.
 *
 * Two deletes are issued: records for `memberGroup` itself whose last path element is `groupId`, then records whose path contains `groupId` immediately
 * followed by `memberGroup`. Both identities are resolved to primary keys by sub-selects built with `workbenchGroupIdentityToGroupPK`. The result is the
 * update count of the second delete.
 */
private def removeMemberGroupFromFlatGroup(groupId: WorkbenchGroupIdentity, memberGroup: WorkbenchGroupIdentity)(implicit session: DBSession) = {
val f = GroupMemberFlatTable.syntax("f")
// remove rows where memberGroup directly in groupId
samsql"""delete from ${GroupMemberFlatTable as f}
where ${f.memberGroupId} = (${workbenchGroupIdentityToGroupPK(memberGroup)})
and ${f.lastGroupMembershipElement} = (${workbenchGroupIdentityToGroupPK(groupId)})""".update().apply()
// remove rows where groupId is directly followed by memberGroup in membership path, these are indirect memberships
// The condition that uses @> is for performance, it allows the query to hit an index. It finds all rows where
// f.groupMembershipPath contains both groupId and memberGroup but in no particular order or placement.
// The condition using array_position is sufficient however @> uses an index and array_position does not.
// Think of it like an efficient pre filter so that the real condition has less to do.
samsql"""delete from ${GroupMemberFlatTable as f}
where array_position(${f.groupMembershipPath}, (${workbenchGroupIdentityToGroupPK(groupId)})) + 1 =
array_position(${f.groupMembershipPath}, (${workbenchGroupIdentityToGroupPK(memberGroup)}))
and ${f.groupMembershipPath} @> array[(${workbenchGroupIdentityToGroupPK(groupId)}), (${workbenchGroupIdentityToGroupPK(memberGroup)})]"""
.update()
.apply()
}
/** Deletes the GroupMemberTable row linking `memberUser` directly to `groupId`; returns true if any row was deleted. */
private def removeMemberUserFromHierarchy(groupId: WorkbenchGroupIdentity, memberUser: WorkbenchUserId)(implicit session: DBSession) = {
  val groupMemberColumn = GroupMemberTable.column
  // the group identity is resolved to its primary key by a sub-select embedded in the statement
  val groupPKQuery = workbenchGroupIdentityToGroupPK(groupId)
  val rowsDeleted =
    samsql"""delete from ${GroupMemberTable.table}
where ${groupMemberColumn.groupId} = (${groupPKQuery})
and ${groupMemberColumn.memberUserId} = ${memberUser}""".update().apply()
  rowsDeleted > 0
}
/** Deletes every flat membership record for `memberUser` whose last path element is `groupId`, i.e. every membership that goes through the user's direct
 * membership in that group. Returns the update count.
 */
private def removeMemberUserFromFlatGroup(groupId: WorkbenchGroupIdentity, memberUser: WorkbenchUserId)(implicit session: DBSession) = {
  val f = GroupMemberFlatTable.syntax("f")
  val deleteStatement =
    samsql"""delete from ${GroupMemberFlatTable as f}
where ${f.memberUserId} = $memberUser
and ${f.lastGroupMembershipElement} = (${workbenchGroupIdentityToGroupPK(groupId)})"""
  deleteStatement.update().apply()
}
/** Reports whether `member` has at least one record in the flat membership table under `groupId`.
 *
 * @return
 *   true when the count of matching flat records is greater than zero
 */
def isGroupMember(groupId: WorkbenchGroupIdentity, member: WorkbenchSubject)(implicit session: DBSession): Boolean = {
  val f = GroupMemberFlatTable.syntax("f")
  val matchingRowCount =
    samsql"""SELECT count(*) FROM ${GroupMemberFlatTable as f}
WHERE ${memberClause(member, f)} AND ${f.groupId} = (${workbenchGroupIdentityToGroupPK(groupId)})""".map(_.int(1)).single().apply()
  // a missing result row is treated the same as a count of zero
  matchingRowCount.exists(_ > 0)
}
/** Sets the group's updated date to the current instant.
 *
 * @return
 *   the update count of the statement
 */
def updateGroupUpdatedDate(groupId: WorkbenchGroupIdentity)(implicit session: DBSession): Int = {
  val g = GroupTable.column
  val touchGroup =
    samsql"update ${GroupTable.table} set ${g.updatedDate} = ${Instant.now()} where ${g.id} = (${workbenchGroupIdentityToGroupPK(groupId)})"
  touchGroup.update().apply()
}
/** Deletes the named group and then clears its records from the flat membership table.
 *
 * @return
 *   the number of groups deleted (0 or 1)
 * @throws WorkbenchExceptionWithErrorReport
 *   with a 409 Conflict report when the delete fails with a foreign key violation
 */
def deleteGroup(groupName: WorkbenchGroupName)(implicit session: DBSession): Int = {
  val g = GroupTable.syntax("g")
  // foreign keys in accessInstructions and groupMember tables are set to cascade delete
  // note: the delete does not remove this group from any parent groups; it fails with a
  // foreign key constraint violation if the group is still a member of any parent groups
  val deleteAttempt = Try {
    samsql"delete from ${GroupTable as g} where ${g.name} = ${groupName} returning ${g.result.id}".map(_.get[GroupPK](g.resultName.id)).single().apply()
  }
  // translate the foreign key violation into a conflict error report; any other failure is rethrown as is
  val translatedAttempt = deleteAttempt.recoverWith {
    case fkViolation: PSQLException if fkViolation.getSQLState == PSQLStateExtensions.FOREIGN_KEY_VIOLATION =>
      val conflict = new WorkbenchExceptionWithErrorReport(
        ErrorReport(StatusCodes.Conflict, s"group ${groupName.value} cannot be deleted because it is a member of at least 1 other group")
      )
      Failure(conflict)
  }
  val maybeGroupPK = translatedAttempt.get
  maybeGroupPK.foreach(removeAllMembersFromFlatGroup)
  maybeGroupPK.size // 0 when no group matched the name, 1 otherwise
}
/** Builds the SQL condition matching `member` against flat membership alias `f`: users are compared on the member user id column, groups on the member group
 * id column (via a primary key sub-select).
 *
 * @throws WorkbenchException
 *   if `member` is neither a user id nor a group identity
 */
private def memberClause(
    member: WorkbenchSubject,
    f: scalikejdbc.QuerySQLSyntaxProvider[scalikejdbc.SQLSyntaxSupport[GroupMemberFlatRecord], GroupMemberFlatRecord]
): SQLSyntax =
  member match {
    case WorkbenchUserId(userId) =>
      samsqls"${f.memberUserId} = $userId"
    case subGroupId: WorkbenchGroupIdentity =>
      samsqls"${f.memberGroupId} = (${workbenchGroupIdentityToGroupPK(subGroupId)})"
    case _ =>
      throw new WorkbenchException(s"illegal member $member")
  }
/** Resolves the group names and policy ids contained in `members` to group primary keys; other subjects are ignored.
 *
 * Group names are looked up in GroupTable, policy ids through the resource type / resource / policy tables. Only the lookups that are actually needed are
 * run: both combined with `union`, just one of them, or no query at all when `members` has neither groups nor policies.
 *
 * @return
 *   the primary keys found
 * @throws WorkbenchException
 *   if the number of keys found differs from the number of group names plus policy ids requested
 */
protected def queryForGroupPKs(members: Set[WorkbenchSubject])(implicit session: DBSession): List[GroupPK] = {
  // plain groups, looked up by name
  val memberGroupNames = members.collect { case name: WorkbenchGroupName => name }
  val gpk = GroupTable.syntax("g")
  val groupPKStatement = samsqls"""select ${gpk.id} as group_id from ${GroupTable as gpk} where ${gpk.name} in ($memberGroupNames)"""

  // policies, looked up by (resource type name, resource name, policy name)
  val memberPolicyIdTuples = members.collect { case policy: FullyQualifiedPolicyId =>
    samsqls"(${policy.resource.resourceTypeName}, ${policy.resource.resourceId}, ${policy.accessPolicyName})"
  }
  val rt = ResourceTypeTable.syntax("rt")
  val r = ResourceTable.syntax("r")
  val p = PolicyTable.syntax("p")
  val policyGroupPKStatement = samsqls"""select ${p.groupId} as group_id
from ${ResourceTypeTable as rt}
join ${ResourceTable as r} on ${rt.id} = ${r.resourceTypeId}
join ${PolicyTable as p} on ${r.id} = ${p.resourceId}
where (${rt.name}, ${r.name}, ${p.name}) in ($memberPolicyIdTuples)"""

  // keep only the statements that have something to look up and union them when both remain;
  // with neither groups nor policies there is nothing left and no query is made
  val applicableStatements = List(
    memberGroupNames.nonEmpty -> groupPKStatement,
    memberPolicyIdTuples.nonEmpty -> policyGroupPKStatement
  ).collect { case (true, statement) => statement }
  val subgroupPKQuery = applicableStatements.reduceOption((left, right) => samsqls"$left union $right")

  val memberGroupPKs =
    subgroupPKQuery.fold(List.empty[GroupPK])(query => samsql"$query".map(_.get[GroupPK]("group_id")).list().apply())

  val requestedCount = memberGroupNames.size + memberPolicyIdTuples.size
  if (memberGroupPKs.size != requestedCount) {
    throw new WorkbenchException("Some member groups not found.")
  }
  memberGroupPKs
}
/** Builds the sub-select that yields the group primary key for a group identity: by group name for a plain group, by resource type / resource / policy name
 * for a policy. Callers embed the result inside parentheses in their own statements.
 */
def workbenchGroupIdentityToGroupPK(groupId: WorkbenchGroupIdentity): SQLSyntax =
  groupId match {
    case policyId: FullyQualifiedPolicyId => groupPKQueryForPolicy(policyId)
    case groupName: WorkbenchGroupName => groupPKQueryForGroup(groupName)
  }
/** Builds a select of the GroupTable id of the group with the given name.
 *
 * @param groupTableAlias
 *   SQL alias to use for GroupTable in the generated select
 */
private def groupPKQueryForGroup(groupName: WorkbenchGroupName, groupTableAlias: String = "gpk"): SQLSyntax = {
  val groupAlias = GroupTable.syntax(groupTableAlias)
  samsqls"select ${groupAlias.id} from ${GroupTable as groupAlias} where ${groupAlias.name} = $groupName"
}
/** Builds a select of the group id of the policy identified by `policyId`, joining resource type, resource and policy tables and matching on the resource
 * type name, resource id and policy name.
 *
 * @param resourceTypeTableAlias
 *   SQL alias to use for ResourceTypeTable
 * @param resourceTableAlias
 *   SQL alias to use for ResourceTable
 * @param policyTableAlias
 *   SQL alias to use for PolicyTable
 */
private def groupPKQueryForPolicy(
    policyId: FullyQualifiedPolicyId,
    resourceTypeTableAlias: String = "rt",
    resourceTableAlias: String = "r",
    policyTableAlias: String = "p"
): SQLSyntax = {
  val p = PolicyTable.syntax(policyTableAlias)
  val r = ResourceTable.syntax(resourceTableAlias)
  val rt = ResourceTypeTable.syntax(resourceTypeTableAlias)
  val resource = policyId.resource
  samsqls"""select ${p.groupId}
from ${ResourceTypeTable as rt}
join ${ResourceTable as r} on ${rt.id} = ${r.resourceTypeId}
join ${PolicyTable as p} on ${r.id} = ${p.resourceId}
where ${rt.name} = ${resource.resourceTypeName}
and ${r.name} = ${resource.resourceId}
and ${p.name} = ${policyId.accessPolicyName}"""
}
}