Skip to content

Commit

Permalink
speed up HoistPlaintextOps
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 5, 2024
1 parent 1939b79 commit 61a93ff
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 23 deletions.
27 changes: 24 additions & 3 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Expand Up @@ -388,12 +388,33 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
newResultTypes.push_back(SecretType::get(ty));
}

// The inputs to the new single-op generic are the subset of the current
// generic's inputs that correspond to the opToExtract's operands.
SmallVector<Value> newGenericOperands;
SmallVector<Value> oldBlockArgs;
newGenericOperands.reserve(opToExtract->getNumOperands());
oldBlockArgs.reserve(opToExtract->getNumOperands());
for (auto operand : opToExtract->getOperands()) {
// If the yielded value is ambient, skip it and it continues to be ambient.
auto *correspondingOperand = getOpOperandForBlockArgument(operand);
if (!correspondingOperand) {
// Assume the operand is ambient. This may fail for some cases where the
// operand is a block argument of some region-holding op within the
// generic. Cross that bridge when we get to it.
continue;
}
newGenericOperands.push_back(correspondingOperand->get());
oldBlockArgs.push_back(operand);
}
LLVM_DEBUG(llvm::dbgs() << "New single-op generic will have "
<< newGenericOperands.size() << " operands\n");

auto newGeneric = rewriter.create<GenericOp>(
getLoc(), getInputs(), newResultTypes,
getLoc(), newGenericOperands, newResultTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
for (BlockArgument blockArg : getBody()->getArguments()) {
mp.map(blockArg, blockArguments[blockArg.getArgNumber()]);
for (auto [oldArg, newArg] : llvm::zip(oldBlockArgs, blockArguments)) {
mp.map(oldArg, newArg);
}
auto *newOp = b.clone(*opToExtract, mp);
b.create<YieldOp>(loc, newOp->getResults());
Expand Down
54 changes: 34 additions & 20 deletions lib/Dialect/Secret/IR/SecretPatterns.cpp
Expand Up @@ -526,10 +526,10 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
if (isa<YieldOp>(op)) {
return false;
}
LLVM_DEBUG(llvm::dbgs()
<< "Considering whether " << op << " can be hoisted\n");
// LLVM_DEBUG(llvm::dbgs()
// << "Considering whether " << op << " can be hoisted\n");
if (!isSpeculatable(&op)) {
LLVM_DEBUG(llvm::dbgs() << "Op is not speculatable\n");
// LLVM_DEBUG(llvm::dbgs() << "Op is not speculatable\n");
return false;
}
for (Value operand : op.getOperands()) {
Expand All @@ -542,10 +542,10 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
owningGeneric &&
isa<SecretType>(
owningGeneric.getOperand(blockArg.getArgNumber()).getType());
LLVM_DEBUG(llvm::dbgs()
<< "operand " << operand << " is a "
<< (isEncryptedBlockArg ? "encrypted" : "plaintext")
<< " block arg\n");
// LLVM_DEBUG(llvm::dbgs()
// << "operand " << operand << " is a "
// << (isEncryptedBlockArg ? "encrypted" : "plaintext")
// << " block arg\n");
if (isEncryptedBlockArg) {
return false;
}
Expand All @@ -554,10 +554,10 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
operand.getDefiningOp()->getBlock() != op.getBlock() &&
!operand.getType().isa<SecretType>();

LLVM_DEBUG(llvm::dbgs()
<< "operand " << operand << " is a "
<< (isPlaintextAmbient ? "plaintext" : "encrypted")
<< " ambient SSA value\n");
// LLVM_DEBUG(llvm::dbgs()
// << "operand " << operand << " is a "
// << (isPlaintextAmbient ? "plaintext" : "encrypted")
// << " ambient SSA value\n");
if (!isPlaintextAmbient) {
return false;
}
Expand All @@ -567,16 +567,30 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
return true;
};

auto it = std::find_if(opRange.begin(), opRange.end(),
[&](Operation &op) { return canHoist(op); });
if (it == opRange.end()) {
return failure();
LLVM_DEBUG(
llvm::dbgs() << "Scanning generic body looking for ops to hoist...\n");

// We can't hoist them as they are detected because the process of hoisting
// alters the context generic op.
llvm::SmallVector<Operation *> opsToHoist;
bool hoistedAny = false;
for (Operation &op : opRange) {
if (canHoist(op)) {
opsToHoist.push_back(&op);
hoistedAny = true;
}
}

Operation *opToHoist = &*it;
LLVM_DEBUG(llvm::dbgs() << "Hoisting " << *opToHoist << "\n");
genericOp.extractOpBeforeGeneric(opToHoist, rewriter);
return success();
LLVM_DEBUG(llvm::dbgs() << "Found " << opsToHoist.size()
<< " ops to hoist\n");

for (Operation *op : opsToHoist) {
genericOp.extractOpBeforeGeneric(op, rewriter);
}

LLVM_DEBUG(llvm::dbgs() << "Done hoisting\n");

return hoistedAny ? success() : failure();
}

void genericAbsorbConstants(secret::GenericOp genericOp,
Expand All @@ -596,7 +610,7 @@ void genericAbsorbConstants(secret::GenericOp genericOp,
// inside the region.
Region *operandRegion = definingOp->getParentRegion();
if (operandRegion && !genericOp.getRegion().isAncestor(operandRegion)) {
auto copiedOp = rewriter.clone(*definingOp);
auto *copiedOp = rewriter.clone(*definingOp);
rewriter.replaceAllUsesWith(operand, copiedOp->getResults());
// If this was a block argument, additionally remove the block
// argument.
Expand Down
21 changes: 21 additions & 0 deletions tests/secret/canonicalize_perf.mlir
@@ -0,0 +1,21 @@
// RUN: heir-opt --affine-loop-unroll=unroll-factor=1024 --canonicalize %s | FileCheck %s

// A test to ensure that the canonicalize pass is not slow for large secret generic bodies
// Cf. https://github.com/google/heir/issues/482

// FIXME: find a way to ensure this test runs in < 1s (with the bug it takes 15s).
// CHECK-LABEL: func @fast_unrolled_loop
func.func @fast_unrolled_loop(
%arg1 : !secret.secret<memref<1024xi32>>) -> !secret.secret<memref<1024xi32>> {
%c5 = arith.constant 5 : i32
%out = secret.generic ins(%arg1 : !secret.secret<memref<1024xi32>>) {
^bb0(%pt_arg: memref<1024xi32>):
affine.for %i = 0 to 1024 {
%x = memref.load %pt_arg[%i] : memref<1024xi32>
%y = arith.addi %x, %c5 : i32
memref.store %y, %pt_arg[%i] : memref<1024xi32>
}
secret.yield %pt_arg : memref<1024xi32>
} -> !secret.secret<memref<1024xi32>>
return %out : !secret.secret<memref<1024xi32>>
}

0 comments on commit 61a93ff

Please sign in to comment.