Improve load balancing of parallel sum reduction #319

Open

mratsim opened this issue Dec 13, 2023 · 0 comments

mratsim (Owner) commented Dec 13, 2023

Currently, parallel sum reduction dispatches between two parallel strategies (plus a serial fallback for small inputs) depending on the number of points to be summed.

  1. Size checks

proc sum_reduce_vartime_parallel*[F; G: static Subgroup](
       tp: Threadpool,
       r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
       points: openArray[ECP_ShortW_Aff[F, G]]) {.inline.} =
  ## Parallel Batch addition of `points` into `r`
  ## `r` is overwritten
  if points.len < 256:
    r.setInf()
    r.accumSum_chunk_vartime(points.asUnchecked(), points.len)
  elif points.len < 8192:
    tp.sum_reduce_vartime_parallelAccums(r, points)
  else:
    tp.sum_reduce_vartime_parallelChunks(r, points)

  2. Hard split for large inputs

proc sum_reduce_vartime_parallelChunks[F; G: static Subgroup](
       tp: Threadpool,
       r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
       points: openArray[ECP_ShortW_Aff[F, G]]) {.noInline.} =
  ## Batch addition of `points` into `r`
  ## `r` is overwritten
  ## Scales better for large number of points

  # Chunking constants in ec_shortweierstrass_batch_ops.nim
  const maxTempMem = 262144 # 2¹⁸ = 262144
  const maxChunkSize = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
  const minChunkSize = (maxChunkSize * 60) div 100 # We want 60%~100% full chunks

  let chunkDesc = balancedChunksPrioSize(
    start = 0, stopEx = points.len,
    minChunkSize, maxChunkSize,
    numChunksHint = tp.numThreads.int)

  let partialResults = allocStackArray(r.typeof(), chunkDesc.numChunks)

  syncScope:
    for iter in items(chunkDesc):
      proc sum_reduce_chunk_vartime_wrapper(res: ptr, p: ptr, pLen: int) {.nimcall.} =
        # The borrow checker prevents capturing `var` and `openArray`
        # so we capture pointers instead.
        res[].setInf()
        res[].accumSum_chunk_vartime(p, pLen)

      tp.spawn partialResults[iter.chunkID].addr.sum_reduce_chunk_vartime_wrapper(
        points.asUnchecked() +% iter.start,
        iter.size)

  const minChunkSizeSerial = 32
  if chunkDesc.numChunks < minChunkSizeSerial:
    r.setInf()
    for i in 0 ..< chunkDesc.numChunks:
      r.sum_vartime(r, partialResults[i])
  else:
    let partialResultsAffine = allocStackArray(ECP_ShortW_Aff[F, G], chunkDesc.numChunks)
    partialResultsAffine.batchAffine(partialResults, chunkDesc.numChunks)
    r.sum_reduce_vartime(partialResultsAffine, chunkDesc.numChunks)

  3. Automated split for medium inputs

proc sum_reduce_vartime_parallelAccums[F; G: static Subgroup](
       tp: Threadpool,
       r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
       points: openArray[ECP_ShortW_Aff[F, G]]) =
  ## Batch addition of `points` into `r`
  ## `r` is overwritten
  ## 2x faster for low number of points

  const maxTempMem = 1 shl 18 # 2¹⁸ = 262144
  const maxChunkSize = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
  type Acc = EcAddAccumulator_vartime[typeof(r), F, G, maxChunkSize]

  let ps = points.asUnchecked()
  let N = points.len

  mixin globalAcc
  const chunkSize = 32

  tp.parallelFor i in 0 ..< N:
    stride: chunkSize
    captures: {ps, N}
    reduceInto(globalAcc: Flowvar[ptr Acc]):
      prologue:
        var workerAcc = allocHeap(Acc)
        workerAcc[].init()
      forLoop:
        for j in i ..< min(i+chunkSize, N):
          workerAcc[].update(ps[j])
      merge(remoteAccFut: Flowvar[ptr Acc]):
        let remoteAcc = sync(remoteAccFut)
        workerAcc[].merge(remoteAcc[])
        freeHeap(remoteAcc)
      epilogue:
        workerAcc[].handover()
        return workerAcc

  let ctx = sync(globalAcc)
  ctx[].finish(r)
  freeHeap(ctx)

The automated split relies on the threadpool's implementation of Lazy Binary Splitting (Tzannes et al).

This is about 2x faster for a medium number of points.
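
For intuition, here is a minimal conceptual sketch of the splitting decision in Lazy Binary Splitting. It is not the threadpool's actual code; `thievesWaiting` stands in for whatever idleness signal the scheduler exposes (in practice the backoff-gated idle-thread check shown further below).

type LoopRange = object
  ## Remaining iterations owned by one worker, [start, stopEx)
  start, stopEx: int

proc trySplit(own: var LoopRange, thievesWaiting: bool): (bool, LoopRange) =
  ## If other workers are starving, split the *remaining* iterations in half:
  ## keep the first half, expose the second half for stealing.
  let remaining = own.stopEx - own.start
  if thievesWaiting and remaining > 1:
    let mid = own.start + remaining div 2
    result = (true, LoopRange(start: mid, stopEx: own.stopEx)) # half to be stolen
    own.stopEx = mid                                           # half we keep
  else:
    result = (false, LoopRange())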

Improving load balancing

Due to the complex chunking:

  • a minimum chunk threshold of 32 points to make the batch inversion worthwhile
  • ideally chunks as large as possible while staying within 2¹⁸ = 262144 bytes of memory to fit in most L1 cache

there is a lot of overhead just to iterate in strides of 32 indices and copy into temporary memory.

We may be able to improve load balancing by using a shared atomic for the current point index, keeping one accumulator per thread, and having each thread accumulate as many points as it can grab (a sketch is given after the BLS excerpt below).

There will be more cache-line contention on this atomic, but LBS has contention as well when the work is mostly just copying data (until we reach the accumulator threshold):

type BalancerBackoff = object
  ## We want to dynamically split parallel loops depending on the number of idle threads.
  ## However checking an atomic variable requires synchronization, which at the very least means
  ## reloading its value in all caches, a guaranteed cache miss. In a tight loop,
  ## this might be a significant cost, especially given that memory is often the bottleneck.
  ##
  ## There is no synchronization possible with thieves, unlike in Prell's PhD thesis.
  ## We want to avoid the worst-case scenario in Tzannes' paper: a tight loop with too many available cores,
  ## so that the producer's deque is always empty and it spends all its CPU time splitting loops.
  ## For this we split depending on the number of idle CPUs. This also prevents splitting unnecessarily.
  ##
  ## Tzannes et al mention that checking the thread's own deque emptiness is a good approximation of system load,
  ## with low overhead except in very fine-grained parallelism.
  ## With a better approximation, by checking the number of idle threads, we can instead
  ## directly do the correct number of splits or avoid splitting. But this check is costly.
  ##
  ## To minimize checking cost while keeping latency low, even in bursty cases,
  ## we use log-log iterated backoff.
  ## - Adversarial Contention Resolution for Simple Channels
  ##   Bender, Farach-Colton, He, Kuszmaul, Leiserson, 2005
  ##   https://people.csail.mit.edu/bradley/papers/BenderFaHe05.pdf
  nextCheck: int
  windowLogSize: uint32 # while loopIndex < lastCheck + 2^windowLogSize, don't recheck.
  round: uint32         # windowSize += 1 after log(windowLogSize) rounds
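
As a rough illustration of how those fields gate the check — this is assumed behaviour inferred from the field comments above, not the threadpool's actual code:

import std/bitops # for fastLog2

proc shouldCheckLoad(b: var BalancerBackoff, loopIndex: int): bool =
  ## Returns true when it is time to pay for the "how many threads are idle?" check.
  if loopIndex < b.nextCheck:
    return false # stay inside the current window: no atomic read, no cache miss
  # Pay for a check now, then widen the window: with log-log iterated backoff
  # the window only grows every ~log2(windowLogSize) rounds.
  inc b.round
  if b.round.int >= max(1, fastLog2(max(2, b.windowLogSize.int))):
    b.round = 0
    inc b.windowLogSize
  b.nextCheck = loopIndex + (1 shl b.windowLogSize.int)
  return true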

That strategy is already used for parallel BLS signature batch verification:

# Stage 0a: Setup per-thread accumulators
debug: doAssert pubkeys.len <= 1 shl 32
let N = pubkeys.len.uint32
let numAccums = min(N, tp.numThreads.uint32)
let accums = allocHeapArray(BLSBatchSigAccumulator[H, FF1, FF2, Fpk, ECP_ShortW_Jac[Sig.F, Sig.G], k], numAccums)

# Stage 0b: Setup synchronization
var currentItem {.noInit.}: Atomic[uint32]
var terminateSignal {.noInit.}: Atomic[bool]
currentItem.store(0, moRelaxed)
terminateSignal.store(false, moRelaxed)

# Stage 1: Accumulate partial pairings (Miller Loops)
# ---------------------------------------------------
proc accumulate(
       ctx: ptr BLSBatchSigAccumulator,
       pubkeys: ptr UncheckedArray[Pubkey],
       messages: ptr UncheckedArray[Msg],
       signatures: ptr UncheckedArray[Sig],
       N: uint32,
       domainSepTag: View[byte],
       secureRandomBytes: ptr array[32, byte],
       accumSepTag: array[sizeof(int), byte],
       terminateSignal: ptr Atomic[bool],
       currentItem: ptr Atomic[uint32]): bool {.nimcall, gcsafe.} =
  ctx[].init(
    domainSepTag.toOpenArray(),
    secureRandomBytes[],
    accumSepTag)

  while not terminateSignal[].load(moRelaxed):
    let i = currentItem[].fetchAdd(1, moRelaxed)
    if i >= N:
      break

    if not ctx[].update(pubkeys[i], messages[i], signatures[i]):
      terminateSignal[].store(true, moRelaxed)
      return false

  ctx[].handover()
  return true

# Stage 2: Schedule work
# ---------------------------------------------------
let partialStates = allocStackArray(Flowvar[bool], numAccums)
for id in 0 ..< numAccums:
  partialStates[id] = tp.spawn accumulate(
    accums[id].addr,
    pubkeys.asUnchecked(),
    messages.asUnchecked(),
    signatures.asUnchecked(),
    N,
    domainSepTag.toView(),
    secureRandomBytes.unsafeAddr,
    id.uint.toBytes(bigEndian),
    terminateSignal.addr,
    currentItem.addr)

# Stage 3: Reduce partial pairings
# --------------------------------
# Linear merge with latency hiding, we could consider a parallel logarithmic merge via a binary tree merge / divide-and-conquer
block HappyPath: # sync must be called even if result is false in the middle to avoid tasks leaking
  result = sync partialStates[0]
  for i in 1 ..< numAccums:
    result = result and sync partialStates[i]
    if result: # As long as no error is returned, accumulate
      result = result and accums[0].merge(accums[i])

  if not result: # Don't proceed to final exponentiation if there is already an error
    break HappyPath

  result = accums[0].finalVerify()
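
Applied to the parallel sum reduction, the proposed scheme could look like the following sketch. It only reuses helpers that already appear in the excerpts above (`EcAddAccumulator_vartime`, `allocHeapArray`, `syncScope`, `asUnchecked`, ...); the proc name `sum_reduce_vartime_parallelAtomic` is illustrative, and the grain of one point per `fetchAdd` would likely need to become a small batch (e.g. 8~32 points) to keep cache-line contention on `currentItem` acceptable.

proc sum_reduce_vartime_parallelAtomic[F; G: static Subgroup](
       tp: Threadpool,
       r: var (ECP_ShortW_Jac[F, G] or ECP_ShortW_Prj[F, G]),
       points: openArray[ECP_ShortW_Aff[F, G]]) =
  ## Sketch: one accumulator per thread, work distributed through a shared atomic index.
  const maxTempMem = 1 shl 18 # 2¹⁸ = 262144
  const accumSize = maxTempMem div sizeof(ECP_ShortW_Aff[F, G])
  type Acc = EcAddAccumulator_vartime[typeof(r), F, G, accumSize]

  let ps = points.asUnchecked()
  let N = points.len.uint32
  let numAccums = min(N, tp.numThreads.uint32)
  let accums = allocHeapArray(Acc, numAccums)

  var currentItem {.noInit.}: Atomic[uint32]
  currentItem.store(0, moRelaxed)

  proc accumulate(acc: ptr Acc,
                  ps: ptr UncheckedArray[ECP_ShortW_Aff[F, G]],
                  N: uint32,
                  currentItem: ptr Atomic[uint32]) {.nimcall, gcsafe.} =
    acc[].init()
    while true:
      let i = currentItem[].fetchAdd(1, moRelaxed) # grab the next point (or a small batch)
      if i >= N:
        break
      acc[].update(ps[i])
    acc[].handover()

  syncScope:
    for id in 0 ..< numAccums:
      tp.spawn accumulate(accums[id].addr, ps, N, currentItem.addr)

  # Linear merge of the per-thread accumulators, as in the BLS code above.
  for id in 1 ..< numAccums:
    accums[0].merge(accums[id])
  accums[0].finish(r)
  freeHeap(accums) # assuming the matching free for allocHeapArray

The merge stays linear in the number of threads, which is cheap compared to the accumulation itself; the BLS code's latency-hiding comment about a logarithmic binary-tree merge would apply here too if thread counts grow large.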
