Skip to content

Commit

Permalink
[IR][K/N] Extracted explicit var spilling phase from coroutines lowering
Browse files Browse the repository at this point in the history
Not only does it give some flexibility but also it could be turned off
in the future (by providing trivial replacement) as a workaround

 #KT-65153
  • Loading branch information
homuroll authored and Space Team committed Feb 7, 2024
1 parent 954eade commit 6c05219
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ object LivenessAnalysis {
private val loopEndsLV = mutableMapOf<IrLoop, BitSet>()
private val loopStartsLV = mutableMapOf<IrLoop, BitSet>()
private var catchesLV = BitSet()
private val suspensionPointIdParameters = BitSet()

fun run(body: IrBody): Map<IrElement, List<IrVariable>> {
body.accept(this, BitSet() /* No variable is live at the end */)
Expand Down Expand Up @@ -81,6 +82,8 @@ object LivenessAnalysis {
val elementLV = filteredElementEndsLV.getOrPut(element) { BitSet() }
elementLV.or(liveVariables)
elementLV.or(catchesLV)
// Suspension points id parameters are never live since they are provided at codegen.
elementLV.andNot(suspensionPointIdParameters)
}
return compute()
}
Expand Down Expand Up @@ -184,6 +187,14 @@ object LivenessAnalysis {
currentCatchesLV
}

override fun visitSuspensionPoint(expression: IrSuspensionPoint, data: BitSet): BitSet {
val variableId = getVariableId(expression.suspensionPointIdParameter)
suspensionPointIdParameters.set(variableId)
return super.visitSuspensionPoint(expression, data).also {
suspensionPointIdParameters.clear(variableId)
}
}

override fun visitBreak(jump: IrBreak, data: BitSet) = saveAndCompute(jump, data) {
loopEndsLV[jump.loop] ?: error("Break from an unknown loop ${jump.loop.render()}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ private val coroutinesPhase = createFileLoweringPhase(
prerequisite = setOf(localFunctionsPhase, finallyBlocksPhase, kotlinNothingValueExceptionPhase)
)

private val coroutinesVarSpillingPhase = createFileLoweringPhase(
::CoroutinesVarSpillingLowering,
name = "CoroutinesVarSpilling",
description = "Save/restore coroutines variables before/after suspension",
prerequisite = setOf(coroutinesPhase)
)

private val typeOperatorPhase = createFileLoweringPhase(
::TypeOperatorLowering,
name = "TypeOperators",
Expand Down Expand Up @@ -547,6 +554,7 @@ private fun PhaseEngine<NativeGenerationState>.getAllLowerings() = listOfNotNull
varargPhase,
kotlinNothingValueExceptionPhase,
coroutinesPhase,
coroutinesVarSpillingPhase,
typeOperatorPhase,
expressionBodyTransformPhase,
objectClassesPhase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ internal class KonanSymbols(
override val coroutineSuspendedGetter =
findTopLevelPropertyGetter(StandardNames.COROUTINES_INTRINSICS_PACKAGE_FQ_NAME, COROUTINE_SUSPENDED_NAME, null)

val saveCoroutineState = internalFunction("saveCoroutineState")
val restoreCoroutineState = internalFunction("restoreCoroutineState")

val cancellationException = topLevelClass(KonanFqNames.cancellationException)

val kotlinResult = irBuiltIns.findClass(Name.identifier("Result"))!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ internal enum class IntrinsicType {
// Coroutines
GET_CONTINUATION,
RETURN_IF_SUSPENDED,
SAVE_COROUTINE_STATE,
RESTORE_COROUTINE_STATE,
// Interop
INTEROP_READ_BITS,
INTEROP_WRITE_BITS,
Expand Down Expand Up @@ -279,6 +281,8 @@ internal class IntrinsicGenerator(private val environment: IntrinsicGeneratorEnv
IntrinsicType.GET_AND_ADD_ARRAY_ELEMENT -> emitGetAndAddArrayElement(callSite, args)
IntrinsicType.GET_CONTINUATION,
IntrinsicType.RETURN_IF_SUSPENDED,
IntrinsicType.SAVE_COROUTINE_STATE,
IntrinsicType.RESTORE_COROUTINE_STATE,
IntrinsicType.INTEROP_BITS_TO_FLOAT,
IntrinsicType.INTEROP_BITS_TO_DOUBLE,
IntrinsicType.INTEROP_SIGN_EXTEND,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package org.jetbrains.kotlin.backend.konan.lower

import org.jetbrains.kotlin.backend.common.BodyLoweringPass
import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
import org.jetbrains.kotlin.backend.common.lower.irBlock
import org.jetbrains.kotlin.backend.common.lower.optimizations.LivenessAnalysis
import org.jetbrains.kotlin.backend.konan.Context
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.ir.builders.declarations.buildField
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irGetField
import org.jetbrains.kotlin.ir.builders.irSet
import org.jetbrains.kotlin.ir.builders.irSetField
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrBody
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrSuspensionPoint
import org.jetbrains.kotlin.ir.symbols.IrVariableSymbol
import org.jetbrains.kotlin.ir.util.addChild
import org.jetbrains.kotlin.ir.util.overrides
import org.jetbrains.kotlin.ir.util.parentAsClass
import org.jetbrains.kotlin.ir.visitors.IrElementTransformer

internal val DECLARATION_ORIGIN_COROUTINE_VAR_SPILLING = IrDeclarationOriginImpl("COROUTINE_VAR_SPILLING")

internal class CoroutinesVarSpillingLowering(val context: Context) : BodyLoweringPass {
private val irFactory = context.irFactory
private val symbols = context.ir.symbols
private val saveCoroutineState = symbols.saveCoroutineState
private val restoreCoroutineState = symbols.restoreCoroutineState

override fun lower(irBody: IrBody, container: IrDeclaration) {
val thisReceiver = (container as? IrSimpleFunction)?.dispatchReceiverParameter
if (thisReceiver == null || !container.overrides(context.ir.symbols.invokeSuspendFunction.owner))
return

val coroutineClass = container.parentAsClass
val liveLocals = LivenessAnalysis.run(irBody) { it is IrSuspensionPoint }

// TODO: optimize by using the same property for different locals.
val localToPropertyMap = mutableMapOf<IrVariableSymbol, IrField>()
fun getFieldForSpilling(variable: IrVariable) = localToPropertyMap.getOrPut(variable.symbol) {
variable.isVar = true // Make variables mutable in order to save/restore them.
irFactory.buildField {
startOffset = coroutineClass.startOffset
endOffset = coroutineClass.endOffset
origin = DECLARATION_ORIGIN_COROUTINE_VAR_SPILLING
name = variable.name
type = variable.type
visibility = DescriptorVisibilities.PRIVATE
isFinal = false
}.apply {
coroutineClass.addChild(this)
}
}

// Save/restore state at suspension points.
val irBuilder = context.createIrBuilder(container.symbol, container.startOffset, container.endOffset)
irBody.transformChildren(object : IrElementTransformer<List<IrVariable>> {
override fun visitSuspensionPoint(expression: IrSuspensionPoint, data: List<IrVariable>): IrExpression {
expression.transformChildren(this, liveLocals[expression]!!)

return expression
}

override fun visitCall(expression: IrCall, data: List<IrVariable>): IrExpression {
expression.transformChildren(this, data)

return when (expression.symbol) {
saveCoroutineState -> irBuilder.run {
irBlock(expression) {
for (variable in data) {
val field = getFieldForSpilling(variable)
+irSetField(irGet(thisReceiver), field, irGet(variable))
}
}
}
restoreCoroutineState -> irBuilder.run {
irBlock(expression) {
for (variable in data) {
val field = getFieldForSpilling(variable)
+irSet(variable, irGetField(irGet(thisReceiver), field))
}
}
}
else -> expression
}
}
}, data = emptyList())
}
}

0 comments on commit 6c05219

Please sign in to comment.