From cea97e14bd361058dd12f52f334e16a3b7334c5f Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 14 Mar 2024 17:38:45 -0700 Subject: [PATCH] Implement rotation replacement in rotate-and-reduce --- .../RotationAnalysis/RotationAnalysis.h | 75 +++++++-- .../RotationAnalysis/RotationAnalysis.cpp | 2 - .../TensorExt/Transforms/RotateAndReduce.cpp | 153 +++++++++++++++++- tests/simd/simple_sum.mlir | 1 + 4 files changed, 212 insertions(+), 19 deletions(-) diff --git a/include/Analysis/RotationAnalysis/RotationAnalysis.h b/include/Analysis/RotationAnalysis/RotationAnalysis.h index 4a804bec7..3ddb0bb1d 100644 --- a/include/Analysis/RotationAnalysis/RotationAnalysis.h +++ b/include/Analysis/RotationAnalysis/RotationAnalysis.h @@ -19,22 +19,38 @@ namespace rotation_analysis { // A wrapper around a mapping from tensor SSA values to sets of access indices class RotationSets { public: - enum class Status { Normal, TooManyTensors }; + enum class Status { + // The tensor value has not been set + Uninitialized, + + // The rotation set is in a normal state. + Normal, + + // The rotation set has a property that makes it invalid for later + // optimizations: + // + // - It involves operations that touch more than one source tensor (not + // including value-semantic outputs) + Overdetermined + + }; public: - RotationSets() : status(Status::Normal){}; + RotationSets() = default; ~RotationSets() = default; // Clear the member data, i.e., set the value back to an uninitialized // state.
void clear() { accessedIndices.clear(); - status = Status::Normal; + status = Status::Uninitialized; } bool empty() const { return accessedIndices.empty(); } - bool isOverdetermined() const { return status == Status::TooManyTensors; } + bool isOverdetermined() const { return status == Status::Overdetermined; } + + bool isUninitialized() const { return status == Status::Uninitialized; } void addRotation(int64_t index) { accessedIndices.insert(index); } @@ -47,6 +63,8 @@ class RotationSets { return accessedIndices; } + Value getTensor() const { return tensor; } + void print(raw_ostream &os) const { os << tensor << ": ["; for (auto index : accessedIndices) { @@ -57,12 +75,17 @@ class RotationSets { static RotationSets overdetermined() { RotationSets sets; - sets.status = Status::TooManyTensors; + sets.status = Status::Overdetermined; return sets; } static RotationSets from(Value tensor) { RotationSets sets; + if (!tensor.getType().isa<RankedTensorType>()) { + sets.status = Status::Uninitialized; + return sets; + } + sets.status = Status::Normal; sets.tensor = tensor; if (auto blockArg = dyn_cast<BlockArgument>(tensor)) { @@ -75,7 +98,7 @@ class RotationSets { // where an IR repeatedly rotates by 1, to ensure that rotations accumulate // like {1, 2, 3, ...} rather than {1, 1, 1, ...} static RotationSets rotate(const RotationSets &lhs, const int64_t shift) { - if (lhs.status == Status::TooManyTensors) { + if (lhs.status == Status::Overdetermined) { return overdetermined(); } @@ -91,13 +114,15 @@ class RotationSets { } static RotationSets join(const RotationSets &lhs, const RotationSets &rhs) { - if (lhs.status == Status::TooManyTensors || - rhs.status == Status::TooManyTensors) { + if (lhs.status == Status::Overdetermined || + rhs.status == Status::Overdetermined) { return overdetermined(); } - if (rhs.accessedIndices.empty()) return lhs; - if (lhs.accessedIndices.empty()) return rhs; + if (rhs.status == Status::Uninitialized || rhs.accessedIndices.empty()) + return lhs; + if (lhs.status == 
Status::Uninitialized || lhs.accessedIndices.empty()) + return rhs; if (lhs.tensor != rhs.tensor) { LLVM_DEBUG({ @@ -107,6 +132,10 @@ class RotationSets { return overdetermined(); } + LLVM_DEBUG({ + llvm::dbgs() << "Joining :" << lhs.tensor << " and " << rhs.tensor + << "\n"; + }); RotationSets merged; merged.status = Status::Normal; merged.tensor = lhs.tensor; @@ -119,6 +148,30 @@ class RotationSets { return merged; } + // Assuming two not-overdetermined rotation sets, compute the overlap in + // their access indices. + static RotationSets overlap(const RotationSets &lhs, + const RotationSets &rhs) { + if (lhs.status == Status::Uninitialized || lhs.empty()) { + return lhs; + } + + if (rhs.status == Status::Uninitialized || rhs.empty()) { + return rhs; + } + + RotationSets merged; + merged.status = Status::Normal; + merged.tensor = lhs.tensor; + for (auto index : lhs.accessedIndices) { + if (rhs.accessedIndices.count(index)) merged.addRotation(index); + } + for (auto index : rhs.accessedIndices) { + if (lhs.accessedIndices.count(index)) merged.addRotation(index); + } + return merged; + } + private: /// The accessed indices of a single SSA value of tensor type. Value tensor; @@ -127,7 +180,7 @@ class RotationSets { // intervals as values are added. Java/Guava has RangeSet, and boost has // interval_set. Otherwise might want to roll our own. 
std::unordered_set<int64_t> accessedIndices; - Status status; + Status status = Status::Uninitialized; }; inline raw_ostream &operator<<(raw_ostream &os, const RotationSets &v) { diff --git a/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp index d8ed047f3..91377fb3d 100644 --- a/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp +++ b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp @@ -9,8 +9,6 @@ #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#define DEBUG_TYPE "rotation-analysis" - namespace mlir { namespace heir { namespace rotation_analysis { diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp index 095cfc4ba..20d320b65 100644 --- a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp +++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp @@ -23,8 +23,6 @@ #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project -#define DEBUG_TYPE "rotate-and-reduce" - namespace mlir { namespace heir { namespace tensor_ext { @@ -39,8 +37,121 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> { using RotateAndReduceBase::RotateAndReduceBase; template <typename ArithOp> - void tryReplace(ArithOp op, DenseSet<Operation *> &visited) { - LLVM_DEBUG(llvm::dbgs() << "Trying to replace " << *op << "\n"); + void tryReplaceRotations(ArithOp op, Value tensor, + DenseSet<Operation *> &visited, + DataFlowSolver &solver) { + // The dataflow analysis provides some guarantees, but not enough + // to prove that we can replace the op with the rotate-and-reduce trick + // while still maintaining program correctness. + // + // We need to do some more complicated checks to ensure that: the op tree + // all contains the same op type (all sum or all mul), and that the + // accessed rotations are included only once in the reduction.
+ // This cannot be done during the dataflow analysis itself due to the + // monotonicity requirements of the framework. + LLVM_DEBUG(llvm::dbgs() + << "Trying to replace rotations ending in " << *op << "\n"); + SetVector<Operation *> backwardSlice; + BackwardSliceOptions options; + // asserts that the parent op has a single region with a single block. + options.omitBlockArguments = false; + + DenseSet<Operation *> visitedReductionOps; + DenseMap<llvm::StringRef, int> opCounts; + opCounts[op->getName().getStringRef()]++; + + getBackwardSlice(op.getOperation(), &backwardSlice, options); + + for (Operation *upstreamOpPtr : backwardSlice) { + auto result = + llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr) + .Case<arith::ConstantOp, tensor_ext::RotateOp>( + [&](auto upstreamOp) { return success(); }) + // Ignore generic ops + .template Case<secret::GenericOp>( + [&](auto upstreamOp) { return success(); }) + .template Case<arith::AddIOp, arith::MulIOp>([&](auto + upstreamOp) { + opCounts[upstreamOp->getName().getStringRef()]++; + // More than one reduction op is mixed in the reduction. + if (opCounts.size() > 1) { + LLVM_DEBUG(llvm::dbgs() + << "Not replacing op because reduction " + "contains multiple incompatible ops " + << op->getName() << " and " + << upstreamOp->getName() << "\n"); + return failure(); + } + + // Inspect the lattice values at the join point, + // and fail if there is any overlap + auto *lhsLattice = + solver.lookupState<rotation_analysis::RotationLattice>( + upstreamOp.getLhs()); + auto *rhsLattice = + solver.lookupState<rotation_analysis::RotationLattice>( + upstreamOp.getRhs()); + LLVM_DEBUG(llvm::dbgs() + << "Computing overlap of " + << "lhs: " << lhsLattice->getValue() << "\n" + << "rhs: " << rhsLattice->getValue() << "\n"); + auto mergedLattice = rotation_analysis::RotationSets::overlap( + lhsLattice->getValue(), rhsLattice->getValue()); + LLVM_DEBUG(llvm::dbgs() + << "Overlap is: " << mergedLattice << "\n"); + if (!mergedLattice.empty()) { + LLVM_DEBUG( + llvm::dbgs() + << "Not replacing op because reduction " + "may not be a simple reduction of the input tensor\n" + << "lhs: " << lhsLattice->getValue() << "\n" + << "rhs: " << rhsLattice->getValue() << "\n"); + 
return failure(); + } + + visitedReductionOps.insert(upstreamOp); + return success(); + }) + .Default([&](Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "Not continuing because type switch " + "encountered unsupported op " + << op->getName() << "\n"); + return failure(); + }); + + if (failed(result)) { + return; + } + } + + // From here we know we will succeed. + auto b = ImplicitLocOpBuilder(op->getLoc(), op); + Operation *finalOp; + auto tensorShape = tensor.getType().cast<RankedTensorType>().getShape(); + for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0; + shiftSize /= 2) { + auto rotatedTensor = b.create<tensor_ext::RotateOp>( + tensor, b.create<arith::ConstantOp>(b.getIndexAttr(shiftSize))); + auto addOp = b.create<ArithOp>(tensor, rotatedTensor); + finalOp = addOp; + tensor = addOp->getResult(0); + } + + [[maybe_unused]] auto *parentOp = op->getParentOp(); + op->replaceAllUsesWith(finalOp); + LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n"); + + // Mark all ops in the reduction as visited so we don't try to replace them + // twice. + for (Operation *visitedOp : visitedReductionOps) { + visited.insert(visitedOp); + } + } + + template <typename ArithOp> + void tryReplaceExtractions(ArithOp op, DenseSet<Operation *> &visited) { + LLVM_DEBUG(llvm::dbgs() + << "Trying to replace extractions ending in " << *op << "\n"); SetVector<Operation *> backwardSlice; BackwardSliceOptions options; // asserts that the parent op has a single region with a single block. 
@@ -232,6 +343,36 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> { }); DenseSet<Operation *> visited; + + getOperation()->walk( + [&](Operation *op) { + if (op->getNumResults() == 0) return; + auto *targetSlotLattice = + solver.lookupState<rotation_analysis::RotationLattice>( + op->getResult(0)); + if (targetSlotLattice->getValue().isUninitialized() || + targetSlotLattice->getValue().isOverdetermined()) { + return; + } + + auto tensor = targetSlotLattice->getValue().getTensor(); + auto accessIndices = + targetSlotLattice->getValue().getAccessedIndices(); + int64_t tensorSize = + tensor.getType().cast<RankedTensorType>().getShape()[0]; + if (accessIndices.size() == tensorSize) { + llvm::TypeSwitch<Operation &>(*op) + .Case<arith::AddIOp>([&](auto arithOp) { + tryReplaceRotations(arithOp, tensor, visited, + solver); + }) + .Case<arith::MulIOp>([&](auto arithOp) { + tryReplaceRotations(arithOp, tensor, visited, + solver); + }); + } + }); + // Traverse the IR in reverse order so that we can eagerly compute backward // slices for each operation. getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>( @@ -241,10 +382,10 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> { } llvm::TypeSwitch<Operation &>(*op) .Case<arith::AddIOp>([&](auto arithOp) { - tryReplace(arithOp, visited); + tryReplaceExtractions(arithOp, visited); }) .Case<arith::MulIOp>([&](auto arithOp) { - tryReplace(arithOp, visited); + tryReplaceExtractions(arithOp, visited); }); }); } diff --git a/tests/simd/simple_sum.mlir b/tests/simd/simple_sum.mlir index a7c35c927..2af7833dc 100644 --- a/tests/simd/simple_sum.mlir +++ b/tests/simd/simple_sum.mlir @@ -3,6 +3,7 @@ // RUN: --rotate-and-reduce --canonicalize \ // RUN: %s | FileCheck %s +// Sum all entries of a tensor into a single scalar // CHECK-LABEL: @simple_sum // CHECK: secret.generic // CHECK-COUNT-5: tensor_ext.rotate