-
Notifications
You must be signed in to change notification settings - Fork 1
/
MultiMergeSorted.scala
145 lines (116 loc) · 3.86 KB
/
MultiMergeSorted.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package dev.chopsticks.stream
import org.apache.pekko.stream._
import org.apache.pekko.stream.scaladsl.{GraphDSL, Source}
import org.apache.pekko.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import scala.collection.immutable
object MultiMergeSorted {

  /** Builds a [[Source]] that interleaves several individually sorted sources into one stream ordered by `ordering`.
    *
    * The i-th element of `sources` is connected to input port `i` of a [[MultiMergeSorted]] stage, whose output
    * becomes the output of the returned source.
    *
    * @param sources                 sources to merge; each is expected to already be sorted by `ordering`. Must be
    *                                non-empty, otherwise the stage's `require(inputsCount >= 1, ...)` throws
    * @param untilLastSourceComplete forwarded to the stage; when `true`, the merged stream is driven by the last
    *                                element of `sources` and stops once that source is exhausted
    * @param ordering                ordering shared by all sources, used to pick the next element to emit
    * @return a source emitting the merged elements
    */
  def merge[T](sources: Seq[Source[T, Any]], untilLastSourceComplete: Boolean = false)(implicit
    ordering: Ordering[T]
  ): Source[T, Any] = {
    val stage = new MultiMergeSorted[T](sources.size, ordering, untilLastSourceComplete)
    val graph = GraphDSL.createGraph(stage) { implicit builder => mergeShape =>
      import GraphDSL.Implicits._
      // Wire source number `port` to inlet number `port`, preserving the order of `sources`.
      sources.zipWithIndex.foreach { case (source, port) =>
        source ~> mergeShape.in(port)
      }
      SourceShape(mergeShape.out)
    }
    Source.fromGraph(graph)
  }
}
/** Fan-in stage that merges `inputsCount` inputs, each assumed to be sorted by `ordering`, into a single output
  * sorted by `ordering`.
  *
  * At most one element per input is buffered. An element is emitted only when downstream has demand and every input
  * that has not yet completed has an element buffered. The smallest buffered element is therefore the next element in
  * sorted order, provided the inputs really are sorted (this is not checked). When several buffered elements compare
  * equal, the one from the highest-indexed input is emitted first.
  *
  * Completion:
  *  - by default, once every input has completed and every buffered element has been emitted;
  *  - if `untilLastSourceComplete` is `true`, as soon as the last input (index `inputsCount - 1`) has completed and
  *    its buffered element (if any) has been emitted. Anything still buffered for other inputs is dropped.
  *
  * @param inputsCount             number of input ports; must be at least 1
  * @param ordering                ordering used to select the next element to emit
  * @param untilLastSourceComplete whether completion of the last input completes the whole stage
  */
final class MultiMergeSorted[T] private (
  val inputsCount: Int,
  val ordering: Ordering[T],
  untilLastSourceComplete: Boolean
) extends GraphStage[UniformFanInShape[T, T]] {
  require(inputsCount >= 1, "A Merge must have one or more input ports")
  val ins: immutable.IndexedSeq[Inlet[T]] = Vector.tabulate(inputsCount)(i => Inlet[T]("Merge.in" + i))
  val out: Outlet[T] = Outlet[T]("Merge.out")
  override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, ins: _*)

  // scalastyle:off method.length
  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
    // One slot per input; `null` means nothing is currently buffered for that input.
    private val buffer = new Array[AnyRef](inputsCount)
    // `true` until the corresponding input signals completion (set for every input in the handler loop below).
    private val activeUpstreamsMap = new Array[Boolean](inputsCount)

    private def itemAt(index: Int): T = buffer(index).asInstanceOf[T]

    /** Returns the index and value of the smallest buffered element, or `(0, null)` when every slot is empty.
      * Because a later slot replaces the current minimum when it compares `<= 0`, ties go to the highest index.
      */
    private def nextItemToGo(): (Int, T) = {
      var minValue = itemAt(0)
      var minIndex = 0
      for (i <- 1 until buffer.length) {
        val value = itemAt(i)
        // scalastyle:off null
        if (value != null && (minValue == null || ordering.compare(value, minValue) <= 0)) {
          minValue = value
          minIndex = i
        }
        // scalastyle:on null
      }
      (minIndex, minValue)
    }

    /** `true` when `untilLastSourceComplete` is set, the last input has completed, and its slot has been drained. */
    private def lastSourceCompleted: Boolean = {
      val index = inputsCount - 1
      // scalastyle:off null
      untilLastSourceComplete && !activeUpstreamsMap(index) && buffer(index) == null
      // scalastyle:on null
    }

    /** `true` when every input that has not completed yet has an element buffered. */
    private def bufferReady: Boolean = {
      var ret = true
      var i = 0
      while (ret && i < inputsCount) {
        // scalastyle:off null
        if (activeUpstreamsMap(i) && buffer(i) == null) ret = false
        // scalastyle:on null
        i += 1
      }
      ret
    }

    /** Completes the stage if a completion condition holds. Otherwise, when downstream has demand and every active
      * input has an element buffered, emits the smallest buffered element and requests a replacement for it.
      */
    private def possiblyPushOrComplete(): Unit = {
      // The `untilLastSourceComplete` condition is checked before pushing. Checking it only when nothing could be
      // pushed made it depend on downstream demand whether one extra element from another input was emitted after
      // the last input had completed.
      if (lastSourceCompleted) completeStage()
      else if (isAvailable(out) && bufferReady) {
        val (index, item) = nextItemToGo()
        // scalastyle:off null
        if (item == null) {
          // Every slot is empty and, since `bufferReady` holds, no input is still active: nothing left to emit.
          completeStage()
        } else {
          buffer.update(index, null)
          // scalastyle:on null
          push(out, item)
          // `tryPull` rather than `pull`: this input may already have completed while its last element was buffered.
          tryPull(ins(index))
          // If that was the last input's final element, complete now instead of waiting for the next downstream pull.
          if (lastSourceCompleted) completeStage()
        }
      }
    }

    private def onItem(index: Int, item: T): Unit = {
      buffer.update(index, item.asInstanceOf[AnyRef])
      possiblyPushOrComplete()
    }

    private def onInletComplete(i: Int): Unit = {
      // The slot is left untouched: an element buffered before completion stays eligible for emission.
      activeUpstreamsMap.update(i, false)
      possiblyPushOrComplete()
    }

    override def preStart(): Unit = {
      ins.foreach(tryPull)
    }

    for (i <- ins.indices) {
      val in = ins(i)
      activeUpstreamsMap.update(i, true)
      setHandler(
        in,
        new InHandler {
          override def onPush(): Unit = {
            val item = grab(in)
            onItem(i, item)
          }
          override def onUpstreamFinish(): Unit = {
            onInletComplete(i)
          }
        }
      )
    }

    setHandler(
      out,
      new OutHandler {
        override def onPull(): Unit = {
          possiblyPushOrComplete()
        }
      }
    )
  }
  // scalastyle:on method.length

  override def toString: String = "MultiMergeSorted"
}