From b1c0edcd829e678d4828a7281808cd581a10d02b Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 14 Mar 2024 17:38:45 -0700 Subject: [PATCH] start to implement actual rotation replacement --- .../RotationAnalysis/RotationAnalysis.h | 2 + .../TensorExt/Transforms/RotateAndReduce.cpp | 63 +++++++++++++++++-- tests/simd/simple_sum.mlir | 8 ++- 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/include/Analysis/RotationAnalysis/RotationAnalysis.h b/include/Analysis/RotationAnalysis/RotationAnalysis.h index 4a804bec7..030f3a452 100644 --- a/include/Analysis/RotationAnalysis/RotationAnalysis.h +++ b/include/Analysis/RotationAnalysis/RotationAnalysis.h @@ -47,6 +47,8 @@ class RotationSets { return accessedIndices; } + Value getTensor() const { return tensor; } + void print(raw_ostream &os) const { os << tensor << ": ["; for (auto index : accessedIndices) { diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp index 095cfc4ba..c15097385 100644 --- a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp +++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp @@ -39,8 +39,36 @@ struct RotateAndReduce : impl::RotateAndReduceBase { using RotateAndReduceBase::RotateAndReduceBase; template - void tryReplace(ArithOp op, DenseSet &visited) { - LLVM_DEBUG(llvm::dbgs() << "Trying to replace " << *op << "\n"); + void tryReplaceRotations(ArithOp op, Value tensor, + DenseSet &visited) { + // The dataflow analysis provides some guarantees, but not enough + // to prove that we can replace the op with the rotate-and-reduce trick + // while still maintaining program correctness. 
+ // + // We need to do some more complicated checks to ensure that: + // (a) every op in the op tree has the same op type (all sum or all mul) + // (b) each tensor index ultimately contributes exactly once to the overall + // reduction + + LLVM_DEBUG(llvm::dbgs() + << "Trying to replace rotations ending in " << *op << "\n"); + SetVector backwardSlice; + BackwardSliceOptions options; + // When false, asserts that the parent op has one region with one block. + options.omitBlockArguments = false; + + DenseSet visitedReductionOps; + DenseSet accessIndices; + DenseMap opCounts; + opCounts[op->getName().getStringRef()]++; + + getBackwardSlice(op.getOperation(), &backwardSlice, options); + } + + template + void tryReplaceExtractions(ArithOp op, DenseSet &visited) { + LLVM_DEBUG(llvm::dbgs() + << "Trying to replace extractions ending in " << *op << "\n"); SetVector backwardSlice; BackwardSliceOptions options; // asserts that the parent op has a single region with a single block. @@ -232,6 +260,33 @@ struct RotateAndReduce : impl::RotateAndReduceBase { }); DenseSet visited; + + getOperation()->walk( + [&](Operation *op) { + if (op->getNumResults() == 0) return; + auto *targetSlotLattice = + solver.lookupState( + op->getResult(0)); + if (targetSlotLattice->getValue().isOverdetermined()) { + return; + } + + auto tensor = targetSlotLattice->getValue().getTensor(); + auto accessIndices = + targetSlotLattice->getValue().getAccessedIndices(); + int64_t tensorSize = + tensor.getType().cast().getShape()[0]; + if (accessIndices.size() == tensorSize) { + llvm::TypeSwitch(*op) + .Case([&](auto arithOp) { + tryReplaceRotations(arithOp, tensor, visited); + }) + .Case([&](auto arithOp) { + tryReplaceRotations(arithOp, tensor, visited); + }); + } + }); + // Traverse the IR in reverse order so that we can eagerly compute backward // slices for each operation. 
getOperation()->walk( @@ -241,10 +296,10 @@ struct RotateAndReduce : impl::RotateAndReduceBase { } llvm::TypeSwitch(*op) .Case([&](auto arithOp) { - tryReplace(arithOp, visited); + tryReplaceExtractions(arithOp, visited); }) .Case([&](auto arithOp) { - tryReplace(arithOp, visited); + tryReplaceExtractions(arithOp, visited); }); }); } diff --git a/tests/simd/simple_sum.mlir b/tests/simd/simple_sum.mlir index c61c146a6..62282ba9a 100644 --- a/tests/simd/simple_sum.mlir +++ b/tests/simd/simple_sum.mlir @@ -1,8 +1,12 @@ // RUN: heir-opt --secretize=entry-function=simple_sum --wrap-generic --canonicalize --cse \ -// RUN: --full-loop-unroll --cse --canonicalize --insert-rotate \ -// RUN: %s | FileCheck %s +// RUN: --full-loop-unroll --insert-rotate --cse --canonicalize --rotate-and-reduce \ +// RUN: %s | FileCheck %s +// Sum all entries of a tensor into a single scalar // CHECK-LABEL: @simple_sum +// CHECK: secret.generic +// CHECK-COUNT-3: tensor_ext.rotate +// CHECK-NOT: tensor_ext.rotate func.func @simple_sum(%arg0: tensor<32xi16> {secret.secret}) -> i16 { %c0 = arith.constant 0 : index %c0_si16 = arith.constant 0 : i16