Skip to content

Commit

Permalink
start to implement actual rotation replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 15, 2024
1 parent f0571fe commit b1c0edc
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 6 deletions.
2 changes: 2 additions & 0 deletions include/Analysis/RotationAnalysis/RotationAnalysis.h
Expand Up @@ -47,6 +47,8 @@ class RotationSets {
return accessedIndices;
}

// Accessor: returns the `tensor` member stored on this object.
Value getTensor() const {
  return tensor;
}

void print(raw_ostream &os) const {
os << tensor << ": [";
for (auto index : accessedIndices) {
Expand Down
63 changes: 59 additions & 4 deletions lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
Expand Up @@ -39,8 +39,36 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
using RotateAndReduceBase::RotateAndReduceBase;

template <typename ArithOp>
void tryReplace(ArithOp op, DenseSet<Operation *> &visited) {
LLVM_DEBUG(llvm::dbgs() << "Trying to replace " << *op << "\n");
// Entry point for the rotation-replacement rewrite rooted at `op` (the debug
// message below describes it as replacing "rotations ending in" `op`).
//
// NOTE(review): this is unfinished scaffolding. As written, the body only
// computes the backward slice of `op`; no IR is rewritten, and the `tensor`
// and `visited` parameters are not read anywhere in this function yet.
void tryReplaceRotations(ArithOp op, Value tensor,
DenseSet<Operation *> &visited) {
// The dataflow analysis provides some guarantees, but not enough
// to prove that we can replace the op with the rotate-and-reduce trick
// while still maintaining program correctness.
//
// We need to do some more complicated checks to ensure that:
// (a) every op in the op tree has the same op type (all sum or all mul)
// (b) each tensor index ultimately contributes exactly once to the overall
// reduction
//
// None of these checks are implemented in this function yet.

LLVM_DEBUG(llvm::dbgs()
<< "Trying to replace rotations ending in " << *op << "\n");
SetVector<Operation *> backwardSlice;
BackwardSliceOptions options;
// Block arguments are not omitted from the slice.
// NOTE(review): the original comment says this "asserts that the parent op
// has a single region with a single block" — presumably an assertion inside
// getBackwardSlice when block arguments are followed; confirm against the
// MLIR slice analysis implementation.
options.omitBlockArguments = false;

// Bookkeeping intended for checks (a) and (b) above. `visitedReductionOps`
// and `accessIndices` are declared but never consulted; `opCounts` only
// records a single count for the name of the root op.
DenseSet<Operation *> visitedReductionOps;
DenseSet<unsigned> accessIndices;
DenseMap<llvm::StringRef, int> opCounts;
opCounts[op->getName().getStringRef()]++;

// Populate `backwardSlice` for `op`; the result is not yet used.
getBackwardSlice(op.getOperation(), &backwardSlice, options);
}

template <typename ArithOp>
void tryReplaceExtractions(ArithOp op, DenseSet<Operation *> &visited) {
LLVM_DEBUG(llvm::dbgs()
<< "Trying to replace extractions ending in " << *op << "\n");
SetVector<Operation *> backwardSlice;
BackwardSliceOptions options;
// asserts that the parent op has a single region with a single block.
Expand Down Expand Up @@ -232,6 +260,33 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
});

DenseSet<Operation *> visited;

getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
[&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
op->getResult(0));
if (targetSlotLattice->getValue().isOverdetermined()) {
return;
}

auto tensor = targetSlotLattice->getValue().getTensor();
auto accessIndices =
targetSlotLattice->getValue().getAccessedIndices();
int64_t tensorSize =
tensor.getType().cast<RankedTensorType>().getShape()[0];
if (accessIndices.size() == tensorSize) {
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::AddIOp>([&](auto arithOp) {
tryReplaceRotations<arith::AddIOp>(arithOp, tensor, visited);
})
.Case<arith::MulIOp>([&](auto arithOp) {
tryReplaceRotations<arith::MulIOp>(arithOp, tensor, visited);
});
}
});

// Traverse the IR in reverse order so that we can eagerly compute backward
// slices for each operation.
getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
Expand All @@ -241,10 +296,10 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
}
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::AddIOp>([&](auto arithOp) {
tryReplace<arith::AddIOp>(arithOp, visited);
tryReplaceExtractions<arith::AddIOp>(arithOp, visited);
})
.Case<arith::MulIOp>([&](auto arithOp) {
tryReplace<arith::MulIOp>(arithOp, visited);
tryReplaceExtractions<arith::MulIOp>(arithOp, visited);
});
});
}
Expand Down
8 changes: 6 additions & 2 deletions tests/simd/simple_sum.mlir
@@ -1,8 +1,12 @@
// RUN: heir-opt --secretize=entry-function=simple_sum --wrap-generic --canonicalize --cse \
// RUN: --full-loop-unroll --cse --canonicalize --insert-rotate \
// RUN: %s | FileCheck %s
// RUN: --full-loop-unroll --insert-rotate --cse --canonicalize --rotate-and-reduce \
// RUN: %s | FileCheck %s

// Sum all entries of a tensor into a single scalar
// CHECK-LABEL: @simple_sum
// CHECK: secret.generic
// CHECK-COUNT-3: tensor_ext.rotate
// CHECK-NOT: tensor_ext.rotate
func.func @simple_sum(%arg0: tensor<32xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
Expand Down

0 comments on commit b1c0edc

Please sign in to comment.