diff --git a/include/Analysis/RotationAnalysis/BUILD b/include/Analysis/RotationAnalysis/BUILD
new file mode 100644
index 000000000..01eca6ba4
--- /dev/null
+++ b/include/Analysis/RotationAnalysis/BUILD
@@ -0,0 +1,9 @@
+# RotationAnalysis analysis pass
+package(
+    default_applicable_licenses = ["@heir//:license"],
+    default_visibility = ["//visibility:public"],
+)
+
+exports_files(
+    ["RotationAnalysis.h"],
+)
diff --git a/include/Analysis/RotationAnalysis/RotationAnalysis.h b/include/Analysis/RotationAnalysis/RotationAnalysis.h
new file mode 100644
index 000000000..a1b2f2313
--- /dev/null
+++ b/include/Analysis/RotationAnalysis/RotationAnalysis.h
@@ -0,0 +1,219 @@
+#ifndef INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_
+#define INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_
+
+#include <unordered_set>
+
+#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+
+#define DEBUG_TYPE "rotation-analysis"
+
+namespace mlir {
+namespace heir {
+namespace rotation_analysis {
+
+// A wrapper around a mapping from a single tensor SSA value to a set of its
+// access indices.
+class RotationSets {
+ public:
+  enum class Status {
+    // The tensor value has not been set
+    Uninitialized,
+
+    // The rotation set is in a normal state.
+    Normal,
+
+    // The rotation set has a property that makes it invalid for later
+    // optimizations:
+    //
+    //  - It involves operations that touch more than one source tensor (not
+    //    including value-semantic outputs)
+    Overdetermined
+  };
+
+ public:
+  RotationSets() = default;
+  ~RotationSets() = default;
+
+  // Clear the member data, i.e., set the value back to an uninitialized
+  // state.
+  void clear() {
+    accessedIndices.clear();
+    status = Status::Uninitialized;
+  }
+
+  bool empty() const { return accessedIndices.empty(); }
+
+  bool isOverdetermined() const { return status == Status::Overdetermined; }
+
+  bool isUninitialized() const { return status == Status::Uninitialized; }
+
+  void addRotation(int64_t index) { accessedIndices.insert(index); }
+
+  bool operator==(const RotationSets &rhs) const {
+    return tensor == rhs.tensor && status == rhs.status &&
+           accessedIndices == rhs.accessedIndices;
+  }
+
+  const std::unordered_set<int64_t> &getAccessedIndices() const {
+    return accessedIndices;
+  }
+
+  Value getTensor() const { return tensor; }
+
+  void print(raw_ostream &os) const {
+    os << tensor << ": [";
+    for (auto index : accessedIndices) {
+      os << index << ", ";
+    }
+    os << "]";
+  }
+
+  static RotationSets overdetermined() {
+    RotationSets sets;
+    sets.status = Status::Overdetermined;
+    return sets;
+  }
+
+  static RotationSets from(Value tensor) {
+    RotationSets sets;
+    if (!tensor.getType().isa<RankedTensorType>()) {
+      sets.status = Status::Uninitialized;
+      return sets;
+    }
+
+    sets.status = Status::Normal;
+    sets.tensor = tensor;
+    if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
+      sets.addRotation(0);
+    }
+    return sets;
+  }
+
+  // Shift the rotation indices by the given amount. This helps in a situation
+  // where an IR repeatedly rotates by 1, to ensure that rotations accumulate
+  // like {1, 2, 3, ...} rather than {1, 1, 1, ...}
+  static RotationSets rotate(const RotationSets &lhs, const int64_t shift) {
+    if (lhs.status == Status::Overdetermined) {
+      return overdetermined();
+    }
+
+    RotationSets shifted;
+    shifted.status = Status::Normal;
+    shifted.tensor = lhs.tensor;
+    int64_t size =
+        llvm::cast<RankedTensorType>(lhs.tensor.getType()).getShape()[0];
+    for (auto index : lhs.accessedIndices) {
+      shifted.addRotation((index + shift) % size);
+    }
+    return shifted;
+  }
+
+  static RotationSets join(const RotationSets &lhs, const RotationSets &rhs) {
+    if (lhs.status == Status::Overdetermined ||
+        rhs.status == Status::Overdetermined) {
+      return overdetermined();
+    }
+
+    if (rhs.status == Status::Uninitialized || rhs.accessedIndices.empty())
+      return lhs;
+    if (lhs.status == Status::Uninitialized || lhs.accessedIndices.empty())
+      return rhs;
+
+    if (lhs.tensor != rhs.tensor) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Joining rotations of different tensors: " << lhs.tensor
+                     << " and " << rhs.tensor << "\n";
+      });
+      return overdetermined();
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Joining: " << lhs.tensor << " and " << rhs.tensor
+                   << "\n";
+    });
+    RotationSets merged;
+    merged.status = Status::Normal;
+    merged.tensor = lhs.tensor;
+    for (auto index : lhs.accessedIndices) {
+      merged.addRotation(index);
+    }
+    for (auto index : rhs.accessedIndices) {
+      merged.addRotation(index);
+    }
+    return merged;
+  }
+
+  // Assuming two not-overdetermined rotation sets, compute the overlap in
+  // their access indices.
+  static RotationSets overlap(const RotationSets &lhs,
+                              const RotationSets &rhs) {
+    assert(!lhs.isOverdetermined() && !rhs.isOverdetermined() &&
+           "Expected inputs to RotationSets::overlap to be not overdetermined");
+    if (lhs.status == Status::Uninitialized || lhs.empty()) {
+      return lhs;
+    }
+
+    if (rhs.status == Status::Uninitialized || rhs.empty()) {
+      return rhs;
+    }
+
+    RotationSets merged;
+    merged.status = Status::Normal;
+    merged.tensor = lhs.tensor;
+    for (auto index : lhs.accessedIndices) {
+      if (rhs.accessedIndices.count(index)) merged.addRotation(index);
+    }
+    return merged;
+  }
+
+ private:
+  /// The accessed indices of a single SSA value of tensor type.
+  Value tensor;
+
+  // There is likely a data structure that can more efficiently represent a set
+  // of intervals of integers, which properly merges adjacent intervals as
+  // values are added. Java/Guava has RangeSet, and boost has interval_set.
+  std::unordered_set<int64_t> accessedIndices;
+  Status status = Status::Uninitialized;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const RotationSets &v) {
+  v.print(os);
+  return os;
+}
+
+class RotationLattice : public dataflow::Lattice<RotationSets> {
+ public:
+  using Lattice::Lattice;
+};
+
+/// An analysis that identifies, for each SSA value, the set of underlying
+/// tensors and rotations of those tensors, provided constant rotation shifts
+/// can be determined.
+class RotationAnalysis
+    : public dataflow::SparseForwardDataFlowAnalysis<RotationLattice> {
+ public:
+  explicit RotationAnalysis(DataFlowSolver &solver)
+      : SparseForwardDataFlowAnalysis(solver) {}
+  ~RotationAnalysis() override = default;
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  // Given the lattice values of an operation's operands, update the lattice
+  // values of its results.
+  void visitOperation(Operation *op,
+                      ArrayRef<const RotationLattice *> operands,
+                      ArrayRef<RotationLattice *> results) override;
+
+  void setToEntryState(RotationLattice *lattice) override;
+};
+
+}  // namespace rotation_analysis
+}  // namespace heir
+}  // namespace mlir
+
+#endif  // INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_
diff --git a/lib/Analysis/RotationAnalysis/BUILD b/lib/Analysis/RotationAnalysis/BUILD
new file mode 100644
index 000000000..6a3ba8ff9
--- /dev/null
+++ b/lib/Analysis/RotationAnalysis/BUILD
@@ -0,0 +1,20 @@
+# RotationAnalysis analysis pass
+package(
+    default_applicable_licenses = ["@heir//:license"],
+    default_visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "RotationAnalysis",
+    srcs = ["RotationAnalysis.cpp"],
+    hdrs = ["@heir//include/Analysis/RotationAnalysis:RotationAnalysis.h"],
+    deps = [
+        "@heir//lib/Dialect:Utils",
+        "@heir//lib/Dialect/TensorExt/IR:Dialect",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:TensorDialect",
+    ],
+)
diff --git a/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp
new file mode 100644
index 000000000..91377fb3d
--- /dev/null
+++ b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp
@@ -0,0 +1,75 @@
+#include "include/Analysis/RotationAnalysis/RotationAnalysis.h"
+
+#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
+#include "llvm/include/llvm/ADT/TypeSwitch.h"  // from @llvm-project
+#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
+
+namespace mlir {
+namespace heir {
+namespace rotation_analysis {
+
+void RotationAnalysis::visitOperation(
+    Operation *op, ArrayRef<const RotationLattice *> operands,
+    ArrayRef<RotationLattice *> results) {
+  llvm::TypeSwitch<Operation &>(*op)
+      .Case<tensor_ext::RotateOp>([&](auto rotateOp) {
+        LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
+        auto shiftConstantOp =
+            rotateOp.getShift().template getDefiningOp<arith::ConstantOp>();
+        // If the rotation shift can't be statically determined, we can't
+        // propagate anything through the IR.
+        if (!shiftConstantOp) return;
+
+        int64_t shiftValue =
+            dyn_cast<IntegerAttr>(shiftConstantOp.getValue()).getInt();
+
+        // The target slot propagates from the tensor argument to the result;
+        // the tensor argument is first in the tablegen definition.
+        const RotationLattice *lattice = operands[0];
+        RotationSets latticeRotations = lattice->getValue();
+
+        // If it's a block argument, then there is no initialized lattice value
+        // and we can override it with a "zero rotation"
+        auto blockArg = dyn_cast<BlockArgument>(rotateOp.getTensor());
+        if (blockArg) {
+          latticeRotations = RotationSets::from(blockArg);
+        }
+        RotationSets rotated =
+            RotationSets::rotate(latticeRotations, shiftValue);
+
+        for (RotationLattice *r : results) {
+          ChangeResult result = r->join(rotated);
+          propagateIfChanged(r, result);
+        }
+      })
+      .Default([&](Operation &op) {
+        // By default, each result of an op joins the rotation sets of all of
+        // the op's operands.
+        for (OpOperand &operand : op.getOpOperands()) {
+          auto *latticeOperand = operands[operand.getOperandNumber()];
+
+          for (RotationLattice *r : results) {
+            ChangeResult result = r->join(*latticeOperand);
+            // If the operand is a block arg, this additionally treats this as
+            // a zero rotation. If the underlying tensor differs across
+            // operands, this will also cause the result to become
+            // Status::Overdetermined. Otherwise, the join is a no-op.
+            result |= r->join(RotationSets::from(operand.get()));
+            propagateIfChanged(r, result);
+          }
+        }
+      });
+}
+
+void RotationAnalysis::setToEntryState(RotationLattice *lattice) {
+  lattice->getValue().clear();
+}
+
+}  // namespace rotation_analysis
+}  // namespace heir
+}  // namespace mlir
diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD
index 59f01d8b3..3ad8b8f8a 100644
--- a/lib/Dialect/TensorExt/Transforms/BUILD
+++ b/lib/Dialect/TensorExt/Transforms/BUILD
@@ -69,6 +69,7 @@ cc_library(
     ],
     deps = [
         "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
+        "@heir//lib/Analysis/RotationAnalysis",
         "@heir//lib/Dialect/Secret/IR:Dialect",
         "@heir//lib/Dialect/TensorExt/IR:Dialect",
         "@llvm-project//llvm:Support",
diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
index 7c14c4187..20d320b65 100644
--- a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
+++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
@@ -2,24 +2,26 @@
 
 #include <cstdint>
 
+#include "include/Analysis/RotationAnalysis/RotationAnalysis.h"
 #include "include/Dialect/Secret/IR/SecretOps.h"
 #include "include/Dialect/TensorExt/IR/TensorExtOps.h"
-#include "llvm/include/llvm/ADT/DenseSet.h"  // from @llvm-project
-#include "llvm/include/llvm/ADT/StringRef.h"  // from @llvm-project
-#include "llvm/include/llvm/ADT/TypeSwitch.h"  // from @llvm-project
-#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
-#include "mlir/include/mlir/Analysis/SliceAnalysis.h"  // from @llvm-project
-#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
-#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/BuiltinAttributes.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/Iterators.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
-#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
-#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
-
-#define DEBUG_TYPE "rotate-and-reduce"
+#include "llvm/include/llvm/ADT/DenseSet.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/StringRef.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/TypeSwitch.h"  // from @llvm-project
+#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlowFramework.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/SliceAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Iterators.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
 
 namespace mlir {
 namespace heir {
@@ -35,8 +37,121 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
   using RotateAndReduceBase::RotateAndReduceBase;
 
   template <typename ArithOp>
-  void tryReplace(ArithOp op, DenseSet<Operation *> &visited) {
-    LLVM_DEBUG(llvm::dbgs() << "Trying to replace " << *op << "\n");
+  void tryReplaceRotations(ArithOp op, Value tensor,
+                           DenseSet<Operation *> &visited,
+                           DataFlowSolver &solver) {
+    // The dataflow analysis provides some guarantees, but not enough
+    // to prove that we can replace the op with the rotate-and-reduce trick
+    // while still maintaining program correctness.
+    //
+    // We need to do some more complicated checks to ensure that the op tree
+    // contains a single op type (all sums or all products), and that each
+    // accessed rotation is included only once in the reduction.
+    // This cannot be done during the dataflow analysis itself due to the
+    // monotonicity requirements of the framework.
+    LLVM_DEBUG(llvm::dbgs()
+               << "Trying to replace rotations ending in " << *op << "\n");
+    SetVector<Operation *> backwardSlice;
+    BackwardSliceOptions options;
+    // asserts that the parent op has a single region with a single block.
+    options.omitBlockArguments = false;
+
+    DenseSet<Operation *> visitedReductionOps;
+    DenseMap<llvm::StringRef, int> opCounts;
+    opCounts[op->getName().getStringRef()]++;
+
+    getBackwardSlice(op.getOperation(), &backwardSlice, options);
+
+    for (Operation *upstreamOpPtr : backwardSlice) {
+      auto result =
+          llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
+              .Case<arith::ConstantOp, tensor_ext::RotateOp>(
+                  [&](auto upstreamOp) { return success(); })
+              // Ignore generic ops
+              .template Case<secret::GenericOp>(
+                  [&](auto upstreamOp) { return success(); })
+              .template Case<arith::AddIOp, arith::MulIOp>(
+                  [&](auto upstreamOp) {
+                    opCounts[upstreamOp->getName().getStringRef()]++;
+                    // More than one reduction op is mixed in the reduction.
+                    if (opCounts.size() > 1) {
+                      LLVM_DEBUG(llvm::dbgs()
+                                 << "Not replacing op because reduction "
+                                    "contains multiple incompatible ops "
+                                 << op->getName() << " and "
+                                 << upstreamOp->getName() << "\n");
+                      return failure();
+                    }
+
+                    // Inspect the lattice values at the join point,
+                    // and fail if there is any overlap
+                    auto *lhsLattice =
+                        solver.lookupState<rotation_analysis::RotationLattice>(
+                            upstreamOp.getLhs());
+                    auto *rhsLattice =
+                        solver.lookupState<rotation_analysis::RotationLattice>(
+                            upstreamOp.getRhs());
+                    LLVM_DEBUG(llvm::dbgs()
+                               << "Computing overlap of "
+                               << "lhs: " << lhsLattice->getValue() << "\n"
+                               << "rhs: " << rhsLattice->getValue() << "\n");
+                    auto mergedLattice =
+                        rotation_analysis::RotationSets::overlap(
+                            lhsLattice->getValue(), rhsLattice->getValue());
+                    LLVM_DEBUG(llvm::dbgs()
+                               << "Overlap is: " << mergedLattice << "\n");
+                    if (!mergedLattice.empty()) {
+                      LLVM_DEBUG(
+                          llvm::dbgs()
+                          << "Not replacing op because reduction "
+                             "may not be a simple reduction of the input tensor\n"
+                          << "lhs: " << lhsLattice->getValue() << "\n"
+                          << "rhs: " << rhsLattice->getValue() << "\n");
+                      return failure();
+                    }
+
+                    visitedReductionOps.insert(upstreamOp);
+                    return success();
+                  })
+              .Default([&](Operation *op) {
+                LLVM_DEBUG(llvm::dbgs()
+                           << "Not continuing because type switch "
+                              "encountered unsupported op "
+                           << op->getName() << "\n");
+                return failure();
+              });
+
+      if (failed(result)) {
+        return;
+      }
+    }
+
+    // From here we know we will succeed.
+    auto b = ImplicitLocOpBuilder(op->getLoc(), op);
+    Operation *finalOp;
+    auto tensorShape = tensor.getType().cast<RankedTensorType>().getShape();
+    for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0;
+         shiftSize /= 2) {
+      auto rotatedTensor = b.create<tensor_ext::RotateOp>(
+          tensor, b.create<arith::ConstantOp>(b.getIndexAttr(shiftSize)));
+      auto addOp = b.create<ArithOp>(tensor, rotatedTensor);
+      finalOp = addOp;
+      tensor = addOp->getResult(0);
+    }
+
+    [[maybe_unused]] auto *parentOp = op->getParentOp();
+    op->replaceAllUsesWith(finalOp);
+    LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");
+
+    // Mark all ops in the reduction as visited so we don't try to replace them
+    // twice.
+    for (Operation *visitedOp : visitedReductionOps) {
+      visited.insert(visitedOp);
+    }
+  }
+
+  template <typename ArithOp>
+  void tryReplaceExtractions(ArithOp op, DenseSet<Operation *> &visited) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Trying to replace extractions ending in " << *op << "\n");
     SetVector<Operation *> backwardSlice;
     BackwardSliceOptions options;
     // asserts that the parent op has a single region with a single block.
@@ -187,7 +302,77 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
   }
 
   void runOnOperation() override {
+    DataFlowSolver solver;
+    // These two upstream analyses are required dependencies for any sparse
+    // dataflow analysis, or else the analysis will be a no-op. Cf.
+    // https://github.com/llvm/llvm-project/issues/58922
+    solver.load<dataflow::DeadCodeAnalysis>();
+    solver.load<dataflow::SparseConstantPropagation>();
+    solver.load<rotation_analysis::RotationAnalysis>();
+
+    if (failed(solver.initializeAndRun(getOperation()))) {
+      getOperation()->emitOpError() << "Failed to run dataflow analysis.\n";
+      signalPassFailure();
+      return;
+    }
+
+    LLVM_DEBUG({
+      getOperation()->walk([&](Operation *op) {
+        if (op->getNumResults() == 0) return;
+        auto *targetSlotLattice =
+            solver.lookupState<rotation_analysis::RotationLattice>(
+                op->getResult(0));
+        if (targetSlotLattice->getValue().isOverdetermined()) {
+          llvm::dbgs() << "Rotation lattice for " << *op
+                       << " is overdetermined\n";
+        } else if (targetSlotLattice->getValue().empty()) {
+          llvm::dbgs() << "Rotation lattice for " << *op << " is empty\n";
+        } else {
+          SmallVector<int64_t> sortedRotations(
+              targetSlotLattice->getValue().getAccessedIndices().begin(),
+              targetSlotLattice->getValue().getAccessedIndices().end());
+          llvm::sort(sortedRotations);
+          std::string stringified = llvm::join(
+              llvm::map_range(sortedRotations,
+                              [](int64_t i) { return std::to_string(i); }),
+              ",");
+          llvm::dbgs() << "Rotation lattice for " << *op << ": " << stringified
+                       << "\n";
+        }
+      });
+    });
+
     DenseSet<Operation *> visited;
 
+    getOperation()->walk(
+        [&](Operation *op) {
+          if (op->getNumResults() == 0) return;
+          auto *targetSlotLattice =
+              solver.lookupState<rotation_analysis::RotationLattice>(
+                  op->getResult(0));
+          if (targetSlotLattice->getValue().isUninitialized() ||
+              targetSlotLattice->getValue().isOverdetermined()) {
+            return;
+          }
+
+          auto tensor = targetSlotLattice->getValue().getTensor();
+          auto accessIndices =
+              targetSlotLattice->getValue().getAccessedIndices();
+          int64_t tensorSize =
+              tensor.getType().cast<RankedTensorType>().getShape()[0];
+          if (accessIndices.size() == tensorSize) {
+            llvm::TypeSwitch<Operation &>(*op)
+                .Case<arith::AddIOp>([&](auto arithOp) {
+                  tryReplaceRotations(arithOp, tensor, visited, solver);
+                })
+                .Case<arith::MulIOp>([&](auto arithOp) {
+                  tryReplaceRotations(arithOp, tensor, visited, solver);
+                });
+          }
+        });
+
     // Traverse the IR in reverse order so that we can eagerly compute backward
     // slices for each operation.
     getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
@@ -197,10 +382,10 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
           }
           llvm::TypeSwitch<Operation &>(*op)
               .Case<arith::AddIOp>([&](auto arithOp) {
-                tryReplace(arithOp, visited);
+                tryReplaceExtractions(arithOp, visited);
               })
               .Case<arith::MulIOp>([&](auto arithOp) {
-                tryReplace(arithOp, visited);
+                tryReplaceExtractions(arithOp, visited);
              });
        });
  }
diff --git a/tests/simd/simple_sum.mlir b/tests/simd/simple_sum.mlir
new file mode 100644
index 000000000..2af7833dc
--- /dev/null
+++ b/tests/simd/simple_sum.mlir
@@ -0,0 +1,20 @@
+// RUN: heir-opt --secretize=entry-function=simple_sum --wrap-generic --canonicalize --cse \
+// RUN:   --full-loop-unroll --insert-rotate --cse --canonicalize \
+// RUN:   --rotate-and-reduce --canonicalize \
+// RUN:   %s | FileCheck %s
+
+// Sum all entries of a tensor into a single scalar
+// CHECK-LABEL: @simple_sum
+// CHECK: secret.generic
+// CHECK-COUNT-5: tensor_ext.rotate
+// CHECK-NOT: tensor_ext.rotate
+func.func @simple_sum(%arg0: tensor<32xi16> {secret.secret}) -> i16 {
+  %c0 = arith.constant 0 : index
+  %c0_si16 = arith.constant 0 : i16
+  %0 = affine.for %i = 0 to 32 iter_args(%sum_iter = %c0_si16) -> i16 {
+    %1 = tensor.extract %arg0[%i] : tensor<32xi16>
+    %2 = arith.addi %1, %sum_iter : i16
+    affine.yield %2 : i16
+  }
+  return %0 : i16
+}
diff --git a/tests/tensor_ext/rotate_and_reduce.mlir b/tests/tensor_ext/rotate_and_reduce.mlir
index 126f70577..fa1c20596 100644
--- a/tests/tensor_ext/rotate_and_reduce.mlir
+++ b/tests/tensor_ext/rotate_and_reduce.mlir
@@ -1,4 +1,4 @@
-// RUN: heir-opt --rotate-and-reduce --canonicalize %s | FileCheck %s
+// RUN: heir-opt --rotate-and-reduce --cse --canonicalize %s | FileCheck %s
 
 
 // Sum all entries of a tensor into a single scalar
@@ -305,3 +305,261 @@ func.func @not_supported_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 {
   %15 = arith.addi %14, %2 : i32
   return %15 : i32
 }
+
+// CHECK-LABEL: @sum_of_linear_rotates
+// CHECK-COUNT-5: tensor_ext.rotate
+// CHECK-NOT: tensor_ext.rotate
+func.func @sum_of_linear_rotates(%arg0: !secret.secret<tensor<32xi16>>) -> !secret.secret<i16> {
+  %c30 = arith.constant 30 : index
+  %c29 = arith.constant 29 : index
+  %c31 = arith.constant 31 : index
+  %c1 = arith.constant 1 : index
+  %0 = secret.generic ins(%arg0 : !secret.secret<tensor<32xi16>>) {
+  ^bb0(%arg1: tensor<32xi16>):
+    %1 = tensor_ext.rotate %arg1, %c1 : tensor<32xi16>, index
+    %2 = arith.addi %1, %arg1 : tensor<32xi16>
+    %3 = tensor_ext.rotate %arg1, %c31 : tensor<32xi16>, index
+    %4 = tensor_ext.rotate %2, %c29 : tensor<32xi16>, index
+    %5 = arith.addi %3, %4 : tensor<32xi16>
+    %6 = arith.addi %5, %arg1 : tensor<32xi16>
+    %7 = tensor_ext.rotate %6, %c30 : tensor<32xi16>, index
+    %8 = arith.addi %3, %7 : tensor<32xi16>
+    %9 = arith.addi %8, %arg1 : tensor<32xi16>
+    %10 = tensor_ext.rotate %9, %c30 : tensor<32xi16>, index
+    %11 = arith.addi %3, %10 : tensor<32xi16>
+    %12 = arith.addi %11, %arg1 : tensor<32xi16>
+    %13 = tensor_ext.rotate %12, %c30 : tensor<32xi16>, index
+    %14 = arith.addi %3, %13 : tensor<32xi16>
+    %15 = arith.addi %14, %arg1 : tensor<32xi16>
+    %16 = tensor_ext.rotate %15, %c30 : tensor<32xi16>, index
+    %17 = arith.addi %3, %16 : tensor<32xi16>
+    %18 = arith.addi %17, %arg1 : tensor<32xi16>
+    %19 = tensor_ext.rotate %18, %c30 : tensor<32xi16>, index
+    %20 = arith.addi %3, %19 : tensor<32xi16>
+    %21 = arith.addi %20, %arg1 : tensor<32xi16>
+    %22 = tensor_ext.rotate %21, %c30 : tensor<32xi16>, index
+    %23 = arith.addi %3, %22 : tensor<32xi16>
+    %24 = arith.addi %23, %arg1 : tensor<32xi16>
+    %25 = tensor_ext.rotate %24, %c30 : tensor<32xi16>, index
+    %26 = arith.addi %3, %25 : tensor<32xi16>
+    %27 = arith.addi %26, %arg1 : tensor<32xi16>
+    %28 = tensor_ext.rotate %27, %c30 : tensor<32xi16>, index
+    %29 = arith.addi %3, %28 : tensor<32xi16>
+    %30 = arith.addi %29, %arg1 : tensor<32xi16>
+    %31 = tensor_ext.rotate %30, %c30 : tensor<32xi16>, index
+    %32 = arith.addi %3, %31 : tensor<32xi16>
+    %33 = arith.addi %32, %arg1 : tensor<32xi16>
+    %34 = tensor_ext.rotate %33, %c30 : tensor<32xi16>, index
+    %35 = arith.addi %3, %34 : tensor<32xi16>
+    %36 = arith.addi %35, %arg1 : tensor<32xi16>
+    %37 = tensor_ext.rotate %36, %c30 : tensor<32xi16>, index
+    %38 = arith.addi %3, %37 : tensor<32xi16>
+    %39 = arith.addi %38, %arg1 : tensor<32xi16>
+    %40 = tensor_ext.rotate %39, %c30 : tensor<32xi16>, index
+    %41 = arith.addi %3, %40 : tensor<32xi16>
+    %42 = arith.addi %41, %arg1 : tensor<32xi16>
+    %43 = tensor_ext.rotate %42, %c30 : tensor<32xi16>, index
+    %44 = arith.addi %3, %43 : tensor<32xi16>
+    %45 = arith.addi %44, %arg1 : tensor<32xi16>
+    %46 = tensor_ext.rotate %45, %c30 : tensor<32xi16>, index
+    %47 = arith.addi %3, %46 : tensor<32xi16>
+    %48 = arith.addi %47, %arg1 : tensor<32xi16>
+    %extracted = tensor.extract %48[%c31] : tensor<32xi16>
+    secret.yield %extracted : i16
+  } -> !secret.secret<i16>
+  return %0 : !secret.secret<i16>
+}
+
+// CHECK-LABEL: @rotate_not_applied_because_rotation_missing
+// CHECK-COUNT-3: tensor_ext.rotate
+func.func @rotate_not_applied_because_rotation_missing(%arg0: !secret.secret<tensor<4xi16>>) -> !secret.secret<i16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = secret.generic ins(%arg0 : !secret.secret<tensor<4xi16>>) {
+  ^bb0(%arg1: tensor<4xi16>):
+    %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
+    %2 = arith.addi %1, %arg1 : tensor<4xi16>
+    %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
+    %4 = arith.addi %2, %3 : tensor<4xi16>
+    // To make the rotation apply, replace %5 with this line
+    // %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
+    %5 = tensor_ext.rotate %3, %c2 : tensor<4xi16>, index
+    %6 = arith.addi %4, %5 : tensor<4xi16>
+    %extracted = tensor.extract %6[%c0] : tensor<4xi16>
+    secret.yield %extracted : i16
+  } -> !secret.secret<i16>
+  return %0 : !secret.secret<i16>
+}
+
+// CHECK-LABEL: @rotate_not_applied_because_rotation_duplicated
+// CHECK-COUNT-3: tensor_ext.rotate
+func.func @rotate_not_applied_because_rotation_duplicated(%arg0: !secret.secret<tensor<4xi16>>) -> !secret.secret<i16> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = secret.generic ins(%arg0 : !secret.secret<tensor<4xi16>>) {
+  ^bb0(%arg1: tensor<4xi16>):
+    %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
+    %2 = arith.addi %1, %arg1 : tensor<4xi16>
+    %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
+    %4 = arith.addi %2, %3 : tensor<4xi16>
+    // To return to normal, replace %v4_2 with %4
+    %v4_2 = arith.addi %4, %3 : tensor<4xi16>
+    %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
+    %6 = arith.addi %v4_2, %5 : tensor<4xi16>
+    %extracted = tensor.extract %6[%c1] : tensor<4xi16>
+    secret.yield %extracted : i16
+  } -> !secret.secret<i16>
+  return %0 : !secret.secret<i16>
+}
+
+// CHECK-LABEL: @rotate_not_applied_because_multiple_tensors
+// CHECK-COUNT-3: tensor_ext.rotate
+func.func @rotate_not_applied_because_multiple_tensors(
+    %arg0 : tensor<4xi16>, %arg1 : tensor<4xi16>) -> i16 {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
+  %2 = arith.addi %1, %arg1 : tensor<4xi16>
+  %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
+  %4 = arith.addi %2, %3 : tensor<4xi16>
+  // To return to normal, replace %v4_2 with %4
+  %v4_2 = arith.addi %4, %arg0 : tensor<4xi16>
+  %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
+  %6 = arith.addi %v4_2, %5 : tensor<4xi16>
+  %extracted = tensor.extract %6[%c1] : tensor<4xi16>
+  return %extracted : i16
+}
+
+// CHECK-LABEL: @rotate_not_applied_because_mixed_ops
+// CHECK-COUNT-3: tensor_ext.rotate
+func.func @rotate_not_applied_because_mixed_ops(%arg1 : tensor<4xi16>) -> i16 {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
+  %2 = arith.addi %1, %arg1 : tensor<4xi16>
+  %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
+  // To return to normal, replace muli with addi
+  %4 = arith.muli %2, %3 : tensor<4xi16>
+  %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
+  %6 = arith.addi %4, %5 : tensor<4xi16>
+  %extracted = tensor.extract %6[%c1] : tensor<4xi16>
+  return %extracted : i16
+}
+
+// CHECK-LABEL: @reduce_add_and_mul
+// 9 rotations because the first rotation can be re-used between the two
+// reductions
+// CHECK-COUNT-9: tensor_ext.rotate
+// CHECK-NOT: tensor_ext.rotate
+func.func @reduce_add_and_mul(%arg1: tensor<32xi16>) -> i16 {
+  %c30 = arith.constant 30 : index
+  %c29 = arith.constant 29 : index
+  %c31 = arith.constant 31 : index
+  %c1 = arith.constant 1 : index
+
+  // the add reduction
+  %1 = tensor_ext.rotate %arg1, %c1 : tensor<32xi16>, index
+  %2 = arith.addi %1, %arg1 : tensor<32xi16>
+  %3 = tensor_ext.rotate %arg1, %c31 : tensor<32xi16>, index
+  %4 = tensor_ext.rotate %2, %c29 : tensor<32xi16>, index
+  %5 = arith.addi %3, %4 : tensor<32xi16>
+  %6 = arith.addi %5, %arg1 : tensor<32xi16>
+  %7 = tensor_ext.rotate %6, %c30 : tensor<32xi16>, index
+  %8 = arith.addi %3, %7 : tensor<32xi16>
+  %9 = arith.addi %8, %arg1 : tensor<32xi16>
+  %10 = tensor_ext.rotate %9, %c30 : tensor<32xi16>, index
+  %11 = arith.addi %3, %10 : tensor<32xi16>
+  %12 = arith.addi %11, %arg1 : tensor<32xi16>
+  %13 = tensor_ext.rotate %12, %c30 : tensor<32xi16>, index
+  %14 = arith.addi %3, %13 : tensor<32xi16>
+  %15 = arith.addi %14, %arg1 : tensor<32xi16>
+  %16 = tensor_ext.rotate %15, %c30 : tensor<32xi16>, index
+  %17 = arith.addi %3, %16 : tensor<32xi16>
+  %18 = arith.addi %17, %arg1 : tensor<32xi16>
+  %19 = tensor_ext.rotate %18, %c30 : tensor<32xi16>, index
+  %20 = arith.addi %3, %19 : tensor<32xi16>
+  %21 = arith.addi %20, %arg1 : tensor<32xi16>
+  %22 = tensor_ext.rotate %21, %c30 : tensor<32xi16>, index
+  %23 = arith.addi %3, %22 : tensor<32xi16>
+  %24 = arith.addi %23, %arg1 : tensor<32xi16>
+  %25 = tensor_ext.rotate %24, %c30 : tensor<32xi16>, index
+  %26 = arith.addi %3, %25 : tensor<32xi16>
+  %27 = arith.addi %26, %arg1 : tensor<32xi16>
+  %28 = tensor_ext.rotate %27, %c30 : tensor<32xi16>, index
+  %29 = arith.addi %3, %28 : tensor<32xi16>
+  %30 = arith.addi %29, %arg1 : tensor<32xi16>
+  %31 = tensor_ext.rotate %30, %c30 : tensor<32xi16>, index
+  %32 = arith.addi %3, %31 : tensor<32xi16>
+  %33 = arith.addi %32, %arg1 : tensor<32xi16>
+  %34 = tensor_ext.rotate %33, %c30 : tensor<32xi16>, index
+  %35 = arith.addi %3, %34 : tensor<32xi16>
+  %36 = arith.addi %35, %arg1 : tensor<32xi16>
+  %37 = tensor_ext.rotate %36, %c30 : tensor<32xi16>, index
+  %38 = arith.addi %3, %37 : tensor<32xi16>
+  %39 = arith.addi %38, %arg1 : tensor<32xi16>
+  %40 = tensor_ext.rotate %39, %c30 : tensor<32xi16>, index
+  %41 = arith.addi %3, %40 : tensor<32xi16>
+  %42 = arith.addi %41, %arg1 : tensor<32xi16>
+  %43 = tensor_ext.rotate %42, %c30 : tensor<32xi16>, index
+  %44 = arith.addi %3, %43 : tensor<32xi16>
+  %45 = arith.addi %44, %arg1 : tensor<32xi16>
+  %46 = tensor_ext.rotate %45, %c30 : tensor<32xi16>, index
+  %47 = arith.addi %3, %46 : tensor<32xi16>
+  %48 = arith.addi %47, %arg1 : tensor<32xi16>
+  %extracted = tensor.extract %48[%c31] : tensor<32xi16>
+
+  // the mul reduction
+  %v1_2 = tensor_ext.rotate %arg1, %c1 : tensor<32xi16>, index
+  %v2_2 = arith.muli %v1_2, %arg1 : tensor<32xi16>
+  %v3_2 = tensor_ext.rotate %arg1, %c31 : tensor<32xi16>, index
+  %v4_2 = tensor_ext.rotate %v2_2, %c29 : tensor<32xi16>, index
+  %v5_2 = arith.muli %v3_2, %v4_2 : tensor<32xi16>
+  %v6_2 = arith.muli %v5_2, %arg1 : tensor<32xi16>
+  %v7_2 = tensor_ext.rotate %v6_2, %c30 : tensor<32xi16>, index
+  %v8_2 = arith.muli %v3_2, %v7_2 : tensor<32xi16>
+  %v9_2 = arith.muli %v8_2, %arg1 : tensor<32xi16>
+  %v10_2 = tensor_ext.rotate %v9_2, %c30 : tensor<32xi16>, index
+  %v11_2 = arith.muli %v3_2, %v10_2 : tensor<32xi16>
+  %v12_2 = arith.muli %v11_2, %arg1 : tensor<32xi16>
+  %v13_2 = tensor_ext.rotate %v12_2, %c30 : tensor<32xi16>, index
+  %v14_2 = arith.muli %v3_2, %v13_2 : tensor<32xi16>
+  %v15_2 = arith.muli %v14_2, %arg1 : tensor<32xi16>
+  %v16_2 = tensor_ext.rotate %v15_2, %c30 : tensor<32xi16>, index
+  %v17_2 = arith.muli %v3_2, %v16_2 : tensor<32xi16>
+  %v18_2 = arith.muli %v17_2, %arg1 : tensor<32xi16>
+  %v19_2 = tensor_ext.rotate %v18_2, %c30 : tensor<32xi16>, index
+  %v20_2 = arith.muli %v3_2, %v19_2 : tensor<32xi16>
+  %v21_2 = arith.muli %v20_2, %arg1 : tensor<32xi16>
+  %v22_2 = tensor_ext.rotate %v21_2, %c30 : tensor<32xi16>, index
+  %v23_2 = arith.muli %v3_2, %v22_2 : tensor<32xi16>
+  %v24_2 = arith.muli %v23_2, %arg1 : tensor<32xi16>
+  %v25_2 = tensor_ext.rotate %v24_2, %c30 : tensor<32xi16>, index
+  %v26_2 = arith.muli %v3_2, %v25_2 : tensor<32xi16>
+  %v27_2 = arith.muli %v26_2, %arg1 : tensor<32xi16>
+  %v28_2 = tensor_ext.rotate %v27_2, %c30 : tensor<32xi16>, index
+  %v29_2 = arith.muli %v3_2, %v28_2 : tensor<32xi16>
+  %v30_2 = arith.muli %v29_2, %arg1 : tensor<32xi16>
+  %v31_2 = tensor_ext.rotate %v30_2, %c30 : tensor<32xi16>, index
+  %v32_2 = arith.muli %v3_2, %v31_2 : tensor<32xi16>
+  %v33_2 = arith.muli %v32_2, %arg1 : tensor<32xi16>
+  %v34_2 = tensor_ext.rotate %v33_2, %c30 : tensor<32xi16>, index
+  %v35_2 = arith.muli %v3_2, %v34_2 : tensor<32xi16>
+  %v36_2 = arith.muli %v35_2, %arg1 : tensor<32xi16>
+  %v37_2 = tensor_ext.rotate %v36_2, %c30 : tensor<32xi16>, index
+  %v38_2 = arith.muli %v3_2, %v37_2 : tensor<32xi16>
+  %v39_2 = arith.muli %v38_2, %arg1 : tensor<32xi16>
+  %v40_2 = tensor_ext.rotate %v39_2, %c30 : tensor<32xi16>, index
+  %v41_2 = arith.muli %v3_2, %v40_2 : tensor<32xi16>
+  %v42_2 = arith.muli %v41_2, %arg1 : tensor<32xi16>
+  %v43_2 = tensor_ext.rotate %v42_2, %c30 : tensor<32xi16>, index
+  %v44_2 = arith.muli %v3_2, %v43_2 : tensor<32xi16>
+  %v45_2 = arith.muli %v44_2, %arg1 : tensor<32xi16>
+  %v46_2 = tensor_ext.rotate %v45_2, %c30 : tensor<32xi16>, index
+  %v47_2 = arith.muli %v3_2, %v46_2 : tensor<32xi16>
+  %v48_2 = arith.muli %v47_2, %arg1 : tensor<32xi16>
+  %extracted_2 = tensor.extract %v48_2[%c31] : tensor<32xi16>
+
+  %out = arith.addi %extracted, %extracted_2 : i16
+  return %out : i16
+}
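+
+// Illustrative sketch (not part of the original patch, and the SSA names %t,
+// %r*, %s* below are hypothetical): when the rotation analysis proves a
+// reduction covers every slot exactly once, tryReplaceRotations rewrites the
+// whole op tree into log2(n) rotate-and-add steps, halving the shift each
+// time. For a tensor<8xi16> value %t the emitted IR looks roughly like:
+//
+//   %c4 = arith.constant 4 : index
+//   %r4 = tensor_ext.rotate %t, %c4 : tensor<8xi16>, index
+//   %s4 = arith.addi %t, %r4 : tensor<8xi16>
+//   %c2 = arith.constant 2 : index
+//   %r2 = tensor_ext.rotate %s4, %c2 : tensor<8xi16>, index
+//   %s2 = arith.addi %s4, %r2 : tensor<8xi16>
+//   %c1 = arith.constant 1 : index
+//   %r1 = tensor_ext.rotate %s2, %c1 : tensor<8xi16>, index
+//   %s1 = arith.addi %s2, %r1 : tensor<8xi16>
+//
+// after which every slot of %s1 holds the sum of all eight input entries, so
+// a single tensor.extract yields the reduced value.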