Skip to content

Commit

Permalink
speed up HoistPlaintextOps
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 5, 2024
1 parent 1939b79 commit ec62a43
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 131 deletions.
49 changes: 46 additions & 3 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Expand Up @@ -388,12 +388,55 @@ GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
newResultTypes.push_back(SecretType::get(ty));
}

// The inputs to the new single-op generic are the subset of the current
// generic's inputs that correspond to the opToExtract's operands, and any
// operands among ops in opToExtract's nested regions.
SmallVector<Value> newGenericOperands;
SmallVector<Value> oldBlockArgs;
DenseSet<Value> processedValues;
newGenericOperands.reserve(opToExtract->getNumOperands());
oldBlockArgs.reserve(opToExtract->getNumOperands());
processedValues.reserve(opToExtract->getNumOperands());
for (auto operand : opToExtract->getOperands()) {
if (processedValues.count(operand)) continue;
// If the yielded value is ambient, skip it and it continues to be ambient.
auto *correspondingOperand = getOpOperandForBlockArgument(operand);
if (!correspondingOperand) {
// The operand must be ambient
continue;
}
newGenericOperands.push_back(correspondingOperand->get());
oldBlockArgs.push_back(operand);
processedValues.insert(operand);
}
opToExtract->walk([&](Operation *nestedOp) {
for (Value operand : nestedOp->getOperands()) {
if (processedValues.count(operand)) continue;
auto *correspondingOperand = getOpOperandForBlockArgument(operand);
if (!correspondingOperand) {
// Assume the operand is ambient, or else a block argument of
// opToExtract or an op within a nested region of opToExtract.
continue;
}
newGenericOperands.push_back(correspondingOperand->get());
oldBlockArgs.push_back(operand);
processedValues.insert(operand);
}
});

LLVM_DEBUG(llvm::dbgs() << "New single-op generic will have "
<< newGenericOperands.size() << " operands\n");

auto newGeneric = rewriter.create<GenericOp>(
getLoc(), getInputs(), newResultTypes,
getLoc(), newGenericOperands, newResultTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
for (BlockArgument blockArg : getBody()->getArguments()) {
mp.map(blockArg, blockArguments[blockArg.getArgNumber()]);
// the newly-created blockArguments have the same index order as
// newGenericOperands, which in turn shares the index ordering of
// oldBlockArgs (they were constructed this way specifically to enable
// this IR Mapping).
for (auto [oldArg, newArg] : llvm::zip(oldBlockArgs, blockArguments)) {
mp.map(oldArg, newArg);
}
auto *newOp = b.clone(*opToExtract, mp);
b.create<YieldOp>(loc, newOp->getResults());
Expand Down
32 changes: 23 additions & 9 deletions lib/Dialect/Secret/IR/SecretPatterns.cpp
Expand Up @@ -567,16 +567,30 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
return true;
};

auto it = std::find_if(opRange.begin(), opRange.end(),
[&](Operation &op) { return canHoist(op); });
if (it == opRange.end()) {
return failure();
LLVM_DEBUG(
llvm::dbgs() << "Scanning generic body looking for ops to hoist...\n");

// We can't hoist them as they are detected because the process of hoisting
// alters the context generic op.
llvm::SmallVector<Operation *> opsToHoist;
bool hoistedAny = false;
for (Operation &op : opRange) {
if (canHoist(op)) {
opsToHoist.push_back(&op);
hoistedAny = true;
}
}

Operation *opToHoist = &*it;
LLVM_DEBUG(llvm::dbgs() << "Hoisting " << *opToHoist << "\n");
genericOp.extractOpBeforeGeneric(opToHoist, rewriter);
return success();
LLVM_DEBUG(llvm::dbgs() << "Found " << opsToHoist.size()
<< " ops to hoist\n");

for (Operation *op : opsToHoist) {
genericOp.extractOpBeforeGeneric(op, rewriter);
}

LLVM_DEBUG(llvm::dbgs() << "Done hoisting\n");

return hoistedAny ? success() : failure();
}

void genericAbsorbConstants(secret::GenericOp genericOp,
Expand All @@ -596,7 +610,7 @@ void genericAbsorbConstants(secret::GenericOp genericOp,
// inside the region.
Region *operandRegion = definingOp->getParentRegion();
if (operandRegion && !genericOp.getRegion().isAncestor(operandRegion)) {
auto copiedOp = rewriter.clone(*definingOp);
auto *copiedOp = rewriter.clone(*definingOp);
rewriter.replaceAllUsesWith(operand, copiedOp->getResults());
// If this was a block argument, additionally remove the block
// argument.
Expand Down
21 changes: 21 additions & 0 deletions tests/secret/canonicalize_perf.mlir
@@ -0,0 +1,21 @@
// RUN: heir-opt --affine-loop-unroll=unroll-factor=1024 --canonicalize %s | FileCheck %s

// A regression test ensuring that the canonicalize pass is not slow for large secret generic bodies.
// Cf. https://github.com/google/heir/issues/482

// FIXME: find a way to ensure this test runs in < 1s (with the bug it takes 15s).
// CHECK-LABEL: func @fast_unrolled_loop
// Adds a plaintext constant to every element of a secret memref. The RUN
// line fully unrolls the loop (unroll-factor=1024), so the secret.generic
// body ends up with on the order of 1024 hoistable plaintext ops — a
// stress fixture for hoisting during --canonicalize (per the commit, this
// exercises HoistPlaintextOps).
func.func @fast_unrolled_loop(
%arg1 : !secret.secret<memref<1024xi32>>) -> !secret.secret<memref<1024xi32>> {
// Plaintext constant defined outside the generic; after unrolling, each
// iteration's arith.addi refers to it from inside the generic body.
%c5 = arith.constant 5 : i32
%out = secret.generic ins(%arg1 : !secret.secret<memref<1024xi32>>) {
^bb0(%pt_arg: memref<1024xi32>):
// In-place update of the secret buffer: load, add %c5, store back.
affine.for %i = 0 to 1024 {
%x = memref.load %pt_arg[%i] : memref<1024xi32>
%y = arith.addi %x, %c5 : i32
memref.store %y, %pt_arg[%i] : memref<1024xi32>
}
// Yield the mutated buffer as the generic's (secret) result.
secret.yield %pt_arg : memref<1024xi32>
} -> !secret.secret<memref<1024xi32>>
return %out : !secret.secret<memref<1024xi32>>
}
122 changes: 3 additions & 119 deletions tests/yosys_optimizer/unroll_and_optimize.mlir
Expand Up @@ -3,123 +3,6 @@
!in_ty = !secret.secret<memref<10xi8>>
!out_ty = !secret.secret<memref<10xi8>>

// Doubles each element of the secret input %arg0 into a freshly allocated
// secret output buffer, one element per loop iteration, with each
// iteration's body wrapped in its own secret.generic.
func.func @basic_example(%arg0: !in_ty) -> (!out_ty) {
// Allocate the output memref inside a generic so the allocation itself
// is a secret value (%0).
%0 = secret.generic {
^bb0:
%memref = memref.alloc() : memref<10xi8>
secret.yield %memref : memref<10xi8>
} -> !out_ty

affine.for %i = 0 to 10 {
// Per-iteration generic: unwraps both secrets, computes x + x, and
// stores into the output buffer. Yields nothing — the store is the
// observable effect.
secret.generic ins(%arg0, %0 : !in_ty, !out_ty) {
^bb0(%clean_memref: memref<10xi8>, %clean_outmemref: memref<10xi8>):
%1 = memref.load %clean_memref[%i] : memref<10xi8>
// This is actually such a simple computation that yosys will optimize it
// to be purely assignments (doubling a number is shifting the bits and
// assigning 0 to the lowest bit).
%2 = arith.addi %1, %1 : i8
memref.store %2, %clean_outmemref[%i] : memref<10xi8>
secret.yield
}
}

// The secret output buffer, now fully populated by the loop above.
return %0 : !out_ty
}

// CHECK-LABEL: func.func @basic_example(
// CHECK-SAME: %[[arg0:.*]]: [[secret_ty:!secret.secret<memref<10xi8>>]]
// CHECK-DAG: %[[c0:.*]] = arith.constant 0
// CHECK-DAG: %[[c1:.*]] = arith.constant 1
// CHECK-DAG: %[[c2:.*]] = arith.constant 2
// CHECK-DAG: %[[c3:.*]] = arith.constant 3
// CHECK-DAG: %[[c4:.*]] = arith.constant 4
// CHECK-DAG: %[[c5:.*]] = arith.constant 5
// CHECK-DAG: %[[c6:.*]] = arith.constant 6
// CHECK-DAG: %[[c7:.*]] = arith.constant 7
// CHECK-DAG: %[[false:.*]] = arith.constant false
//
// CHECK: secret.generic
// CHECK-NEXT: memref.alloc
// CHECK-NEXT: secret.yield
//
// CHECK: affine.for %[[index:.*]] = 0 to 10 step 2 {
// CHECK-NEXT: %[[index_plus_one:.*]] = affine.apply
//
// The loads are hoisted out of the generic
// CHECK-NEXT: secret.generic
// CHECK-NEXT: ^bb
// CHECK-NEXT: memref.load
// CHECK-SAME: %[[index]]
// CHECK-NEXT: secret.yield
// CHECK-NEXT: } -> !secret.secret<i8>
// CHECK-NEXT: secret.generic
// CHECK-NEXT: ^bb
// CHECK-NEXT: memref.load
// CHECK-SAME: %[[index_plus_one]]
// CHECK-NEXT: secret.yield
// CHECK-NEXT: } -> !secret.secret<i8>
//
// CHECK-NEXT: secret.cast
// CHECK-NEXT: secret.cast
//
// The main computation
// CHECK-NEXT: secret.generic
// CHECK-NEXT: ^bb{{.*}}(%[[arg2:.*]]: memref<8xi1>, %[[arg3:.*]]: memref<8xi1>):
// Note bit 7 is never loaded because it is shifted out
// CHECK-DAG: %[[arg2bit0:.*]] = memref.load %arg2[%[[c0]]] : memref<8xi1>
// CHECK-DAG: %[[arg2bit1:.*]] = memref.load %arg2[%[[c1]]] : memref<8xi1>
// CHECK-DAG: %[[arg2bit2:.*]] = memref.load %arg2[%[[c2]]] : memref<8xi1>
// CHECK-DAG: %[[arg2bit3:.*]] = memref.load %arg2[%[[c3]]] : memref<8xi1>
// CHECK-DAG: %[[arg2bit4:.*]] = memref.load %arg2[%[[c4]]] : memref<8xi1>
// CHECK-DAG: %[[arg2bit5:.*]] = memref.load %arg2[%[[c5]]] : memref<8xi1>
// CHECK-DAG: %[[arg2bit6:.*]] = memref.load %arg2[%[[c6]]] : memref<8xi1>
//
// CHECK-DAG: %[[arg3bit0:.*]] = memref.load %arg3[%[[c0]]] : memref<8xi1>
// CHECK-DAG: %[[arg3bit1:.*]] = memref.load %arg3[%[[c1]]] : memref<8xi1>
// CHECK-DAG: %[[arg3bit2:.*]] = memref.load %arg3[%[[c2]]] : memref<8xi1>
// CHECK-DAG: %[[arg3bit3:.*]] = memref.load %arg3[%[[c3]]] : memref<8xi1>
// CHECK-DAG: %[[arg3bit4:.*]] = memref.load %arg3[%[[c4]]] : memref<8xi1>
// CHECK-DAG: %[[arg3bit5:.*]] = memref.load %arg3[%[[c5]]] : memref<8xi1>
// CHECK-DAG: %[[arg3bit6:.*]] = memref.load %arg3[%[[c6]]] : memref<8xi1>
//
// The order of use of the two allocs seem arbitrary and nondeterministic,
// so check the stores without the memref names
// CHECK-DAG: memref.alloc() : memref<8xi1>
// CHECK-DAG: memref.alloc() : memref<8xi1>
// CHECK-DAG: memref.store %[[false]], %{{.*}}[%[[c0]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg3bit0]], %{{.*}}[%[[c1]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg3bit1]], %{{.*}}[%[[c2]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg3bit2]], %{{.*}}[%[[c3]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg3bit3]], %{{.*}}[%[[c4]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg3bit4]], %{{.*}}[%[[c5]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg3bit5]], %{{.*}}[%[[c6]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg3bit6]], %{{.*}}[%[[c7]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[false]], %{{.*}}[%[[c0]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg2bit0]], %{{.*}}[%[[c1]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg2bit1]], %{{.*}}[%[[c2]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg2bit2]], %{{.*}}[%[[c3]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg2bit3]], %{{.*}}[%[[c4]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg2bit4]], %{{.*}}[%[[c5]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg2bit5]], %{{.*}}[%[[c6]]] : memref<8xi1>
// CHECK-DAG: memref.store %[[arg2bit6]], %{{.*}}[%[[c7]]] : memref<8xi1>
//
// CHECK-NEXT: secret.yield {{.*}}, {{.*}}
// CHECK-NEXT: }
// CHECK-NEXT: secret.cast
// CHECK-NEXT: secret.cast
// CHECK-NEXT: secret.generic
// CHECK-NEXT: ^bb
// CHECK-NEXT: memref.store
// CHECK-NEXT: secret.yield
// CHECK-NEXT: }
// CHECK-NEXT: secret.generic
// CHECK-NEXT: ^bb
// CHECK-NEXT: memref.store
// CHECK-NEXT: secret.yield
// CHECK-NEXT: }
// CHECK-NEXT:}
// CHECK-NEXT: return

// Computes the set of partial cumulative sums of the input array
func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) {
%0 = secret.generic {
Expand Down Expand Up @@ -198,6 +81,8 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) {
//
// Extracted plaintext arith op
// CHECK-NEXT: %[[index_minus_one:.*]] = arith.subi %[[index]], %[[c1]]
// Same deal, but for second unwrapped loop iteration marked by SECOND_SUB
// CHECK-NEXT: arith.subi
//
// Extracted load that can only be extracted because the previous
// arith op was extracted.
Expand All @@ -208,8 +93,7 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) {
// CHECK-NEXT: secret.yield
// CHECK-NEXT: }
//
// Same deal, but for second unwrapped loop iteration
// CHECK-NEXT: arith.subi
// mark: SECOND_SUB
// CHECK-NEXT: secret.generic
// CHECK-NEXT: bb
// CHECK-NEXT: memref.load
Expand Down

0 comments on commit ec62a43

Please sign in to comment.