Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up HoistPlaintextOps #485

Merged
merged 1 commit into from Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 46 additions & 3 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Expand Up @@ -388,12 +388,55 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
newResultTypes.push_back(SecretType::get(ty));
}

// The inputs to the new single-op generic are the subset of the current
// generic's inputs that correspond to the opToExtract's operands, and any
// operands among ops in opToExtract's nested regions.
SmallVector<Value> newGenericOperands;
SmallVector<Value> oldBlockArgs;
DenseSet<Value> processedValues;
newGenericOperands.reserve(opToExtract->getNumOperands());
oldBlockArgs.reserve(opToExtract->getNumOperands());
processedValues.reserve(opToExtract->getNumOperands());
for (auto operand : opToExtract->getOperands()) {
if (processedValues.count(operand)) continue;
// If the yielded value is ambient, skip it and it continues to be ambient.
auto *correspondingOperand = getOpOperandForBlockArgument(operand);
if (!correspondingOperand) {
// The operand must be ambient
continue;
}
newGenericOperands.push_back(correspondingOperand->get());
oldBlockArgs.push_back(operand);
processedValues.insert(operand);
}
opToExtract->walk([&](Operation *nestedOp) {
for (Value operand : nestedOp->getOperands()) {
if (processedValues.count(operand)) continue;
auto *correspondingOperand = getOpOperandForBlockArgument(operand);
if (!correspondingOperand) {
// Assume the operand is ambient, or else a block argument of
// opToExtract or an op within a nested region of opToExtract.
continue;
}
newGenericOperands.push_back(correspondingOperand->get());
oldBlockArgs.push_back(operand);
processedValues.insert(operand);
}
});

LLVM_DEBUG(llvm::dbgs() << "New single-op generic will have "
<< newGenericOperands.size() << " operands\n");

auto newGeneric = rewriter.create<GenericOp>(
getLoc(), getInputs(), newResultTypes,
getLoc(), newGenericOperands, newResultTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
for (BlockArgument blockArg : getBody()->getArguments()) {
mp.map(blockArg, blockArguments[blockArg.getArgNumber()]);
// the newly-created blockArguments have the same index order as
// newGenericOperands, which in turn shares the index ordering of
// oldBlockArgs (they were constructed this way specifically to enable
// this IR Mapping).
for (auto [oldArg, newArg] : llvm::zip(oldBlockArgs, blockArguments)) {
mp.map(oldArg, newArg);
}
auto *newOp = b.clone(*opToExtract, mp);
b.create<YieldOp>(loc, newOp->getResults());
Expand Down
32 changes: 23 additions & 9 deletions lib/Dialect/Secret/IR/SecretPatterns.cpp
Expand Up @@ -567,16 +567,30 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
return true;
};

auto it = std::find_if(opRange.begin(), opRange.end(),
[&](Operation &op) { return canHoist(op); });
if (it == opRange.end()) {
return failure();
LLVM_DEBUG(
llvm::dbgs() << "Scanning generic body looking for ops to hoist...\n");

// We can't hoist them as they are detected because the process of hoisting
// alters the context generic op.
llvm::SmallVector<Operation *> opsToHoist;
bool hoistedAny = false;
for (Operation &op : opRange) {
if (canHoist(op)) {
opsToHoist.push_back(&op);
hoistedAny = true;
}
}
Comment on lines +577 to 582
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you could consider llvm::make_filter_range here


Operation *opToHoist = &*it;
LLVM_DEBUG(llvm::dbgs() << "Hoisting " << *opToHoist << "\n");
genericOp.extractOpBeforeGeneric(opToHoist, rewriter);
return success();
LLVM_DEBUG(llvm::dbgs() << "Found " << opsToHoist.size()
<< " ops to hoist\n");

for (Operation *op : opsToHoist) {
genericOp.extractOpBeforeGeneric(op, rewriter);
}

LLVM_DEBUG(llvm::dbgs() << "Done hoisting\n");

return hoistedAny ? success() : failure();
}

void genericAbsorbConstants(secret::GenericOp genericOp,
Expand All @@ -596,7 +610,7 @@ void genericAbsorbConstants(secret::GenericOp genericOp,
// inside the region.
Region *operandRegion = definingOp->getParentRegion();
if (operandRegion && !genericOp.getRegion().isAncestor(operandRegion)) {
auto copiedOp = rewriter.clone(*definingOp);
auto *copiedOp = rewriter.clone(*definingOp);
rewriter.replaceAllUsesWith(operand, copiedOp->getResults());
// If this was a block argument, additionally remove the block
// argument.
Expand Down
25 changes: 25 additions & 0 deletions tests/secret/canonicalize_perf.mlir
@@ -0,0 +1,25 @@
// RUN: heir-opt --affine-loop-unroll=unroll-factor=1024 --canonicalize %s | FileCheck %s

// A test to ensure that the canonicalize pass is not slow for large secret generic bodies
// Cf. https://github.com/google/heir/issues/482

// This test should only take ~0.5s to run. Before the bug above, it took ~15s.
// It will take a bit of extra work to put a strict time limit on the bazel test runner
// via lit, so I will just leave this note here and if the test starts running slow we will
// hopefully notice it.

// CHECK-LABEL: func @fast_unrolled_loop
func.func @fast_unrolled_loop(
%arg1 : !secret.secret<memref<1024xi32>>) -> !secret.secret<memref<1024xi32>> {
%c5 = arith.constant 5 : i32
%out = secret.generic ins(%arg1 : !secret.secret<memref<1024xi32>>) {
^bb0(%pt_arg: memref<1024xi32>):
affine.for %i = 0 to 1024 {
%x = memref.load %pt_arg[%i] : memref<1024xi32>
%y = arith.addi %x, %c5 : i32
memref.store %y, %pt_arg[%i] : memref<1024xi32>
}
secret.yield %pt_arg : memref<1024xi32>
} -> !secret.secret<memref<1024xi32>>
return %out : !secret.secret<memref<1024xi32>>
}
5 changes: 3 additions & 2 deletions tests/yosys_optimizer/unroll_and_optimize.mlir
Expand Up @@ -198,6 +198,8 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) {
//
// Extracted plaintext arith op
// CHECK-NEXT: %[[index_minus_one:.*]] = arith.subi %[[index]], %[[c1]]
// Same deal, but for second unwrapped loop iteration marked by SECOND_SUB
// CHECK-NEXT: arith.subi
//
// Extracted load that can only be extracted because the previous
// arith op was extracted.
Expand All @@ -208,8 +210,7 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) {
// CHECK-NEXT: secret.yield
// CHECK-NEXT: }
//
// Same deal, but for second unwrapped loop iteration
// CHECK-NEXT: arith.subi
// mark: SECOND_SUB
// CHECK-NEXT: secret.generic
// CHECK-NEXT: bb
// CHECK-NEXT: memref.load
Expand Down