Skip to content

Commit

Permalink
Merge pull request #485 from j2kun:hoist-plaintext-speed
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613350156
  • Loading branch information
Copybara-Service committed Mar 6, 2024
2 parents d859950 + 977eac6 commit 04f6106
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 14 deletions.
49 changes: 46 additions & 3 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Expand Up @@ -388,12 +388,55 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
newResultTypes.push_back(SecretType::get(ty));
}

// The inputs to the new single-op generic are the subset of the current
// generic's inputs that correspond to the opToExtract's operands, and any
// operands among ops in opToExtract's nested regions.
SmallVector<Value> newGenericOperands;
SmallVector<Value> oldBlockArgs;
DenseSet<Value> processedValues;
newGenericOperands.reserve(opToExtract->getNumOperands());
oldBlockArgs.reserve(opToExtract->getNumOperands());
processedValues.reserve(opToExtract->getNumOperands());
for (auto operand : opToExtract->getOperands()) {
if (processedValues.count(operand)) continue;
// If the yielded value is ambient, skip it and it continues to be ambient.
auto *correspondingOperand = getOpOperandForBlockArgument(operand);
if (!correspondingOperand) {
// The operand must be ambient
continue;
}
newGenericOperands.push_back(correspondingOperand->get());
oldBlockArgs.push_back(operand);
processedValues.insert(operand);
}
opToExtract->walk([&](Operation *nestedOp) {
for (Value operand : nestedOp->getOperands()) {
if (processedValues.count(operand)) continue;
auto *correspondingOperand = getOpOperandForBlockArgument(operand);
if (!correspondingOperand) {
// Assume the operand is ambient, or else a block argument of
// opToExtract or an op within a nested region of opToExtract.
continue;
}
newGenericOperands.push_back(correspondingOperand->get());
oldBlockArgs.push_back(operand);
processedValues.insert(operand);
}
});

LLVM_DEBUG(llvm::dbgs() << "New single-op generic will have "
<< newGenericOperands.size() << " operands\n");

auto newGeneric = rewriter.create<GenericOp>(
getLoc(), getInputs(), newResultTypes,
getLoc(), newGenericOperands, newResultTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
for (BlockArgument blockArg : getBody()->getArguments()) {
mp.map(blockArg, blockArguments[blockArg.getArgNumber()]);
// the newly-created blockArguments have the same index order as
// newGenericOperands, which in turn shares the index ordering of
// oldBlockArgs (they were constructed this way specifically to enable
// this IR Mapping).
for (auto [oldArg, newArg] : llvm::zip(oldBlockArgs, blockArguments)) {
mp.map(oldArg, newArg);
}
auto *newOp = b.clone(*opToExtract, mp);
b.create<YieldOp>(loc, newOp->getResults());
Expand Down
32 changes: 23 additions & 9 deletions lib/Dialect/Secret/IR/SecretPatterns.cpp
Expand Up @@ -567,16 +567,30 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
return true;
};

auto it = std::find_if(opRange.begin(), opRange.end(),
[&](Operation &op) { return canHoist(op); });
if (it == opRange.end()) {
return failure();
LLVM_DEBUG(
llvm::dbgs() << "Scanning generic body looking for ops to hoist...\n");

// We can't hoist them as they are detected because the process of hoisting
// alters the context generic op.
llvm::SmallVector<Operation *> opsToHoist;
bool hoistedAny = false;
for (Operation &op : opRange) {
if (canHoist(op)) {
opsToHoist.push_back(&op);
hoistedAny = true;
}
}

Operation *opToHoist = &*it;
LLVM_DEBUG(llvm::dbgs() << "Hoisting " << *opToHoist << "\n");
genericOp.extractOpBeforeGeneric(opToHoist, rewriter);
return success();
LLVM_DEBUG(llvm::dbgs() << "Found " << opsToHoist.size()
<< " ops to hoist\n");

for (Operation *op : opsToHoist) {
genericOp.extractOpBeforeGeneric(op, rewriter);
}

LLVM_DEBUG(llvm::dbgs() << "Done hoisting\n");

return hoistedAny ? success() : failure();
}

void genericAbsorbConstants(secret::GenericOp genericOp,
Expand All @@ -596,7 +610,7 @@ void genericAbsorbConstants(secret::GenericOp genericOp,
// inside the region.
Region *operandRegion = definingOp->getParentRegion();
if (operandRegion && !genericOp.getRegion().isAncestor(operandRegion)) {
auto copiedOp = rewriter.clone(*definingOp);
auto *copiedOp = rewriter.clone(*definingOp);
rewriter.replaceAllUsesWith(operand, copiedOp->getResults());
// If this was a block argument, additionally remove the block
// argument.
Expand Down
25 changes: 25 additions & 0 deletions tests/secret/canonicalize_perf.mlir
@@ -0,0 +1,25 @@
// RUN: heir-opt --affine-loop-unroll=unroll-factor=1024 --canonicalize %s | FileCheck %s

// A test to ensure that the canonicalize pass is not slow for large secret generic bodies
// Cf. https://github.com/google/heir/issues/482

// This test should only take ~0.5s to run. Before the bug above, it took ~15s.
// It will take a bit of extra work to put a strict time limit on the bazel test runner
// via lit, so I will just leave this note here and if the test starts running slow we will
// hopefully notice it.

// CHECK-LABEL: func @fast_unrolled_loop
func.func @fast_unrolled_loop(
%arg1 : !secret.secret<memref<1024xi32>>) -> !secret.secret<memref<1024xi32>> {
%c5 = arith.constant 5 : i32
%out = secret.generic ins(%arg1 : !secret.secret<memref<1024xi32>>) {
^bb0(%pt_arg: memref<1024xi32>):
affine.for %i = 0 to 1024 {
%x = memref.load %pt_arg[%i] : memref<1024xi32>
%y = arith.addi %x, %c5 : i32
memref.store %y, %pt_arg[%i] : memref<1024xi32>
}
secret.yield %pt_arg : memref<1024xi32>
} -> !secret.secret<memref<1024xi32>>
return %out : !secret.secret<memref<1024xi32>>
}
5 changes: 3 additions & 2 deletions tests/yosys_optimizer/unroll_and_optimize.mlir
Expand Up @@ -198,6 +198,8 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) {
//
// Extracted plaintext arith op
// CHECK-NEXT: %[[index_minus_one:.*]] = arith.subi %[[index]], %[[c1]]
// Same deal, but for second unwrapped loop iteration marked by SECOND_SUB
// CHECK-NEXT: arith.subi
//
// Extracted load that can only be extracted because the previous
// arith op was extracted.
Expand All @@ -208,8 +210,7 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) {
// CHECK-NEXT: secret.yield
// CHECK-NEXT: }
//
// Same deal, but for second unwrapped loop iteration
// CHECK-NEXT: arith.subi
// mark: SECOND_SUB
// CHECK-NEXT: secret.generic
// CHECK-NEXT: bb
// CHECK-NEXT: memref.load
Expand Down

0 comments on commit 04f6106

Please sign in to comment.