Implement rotation replacement in rotate-and-reduce
j2kun committed Mar 15, 2024
1 parent 82bf9c7 commit cea97e1
Showing 4 changed files with 212 additions and 19 deletions.
75 changes: 64 additions & 11 deletions include/Analysis/RotationAnalysis/RotationAnalysis.h
@@ -19,22 +19,38 @@ namespace rotation_analysis {
// A wrapper around a mapping from tensor SSA values to sets of access indices
class RotationSets {
public:
enum class Status { Normal, TooManyTensors };
enum class Status {
// The tensor value has not been set
Uninitialized,

// The rotation set is in a normal state.
Normal,

// The rotation set has a property that makes it invalid for later
// optimizations:
//
// - It involves operations that touch more than one source tensor (not
// including value-semantic outputs)
Overdetermined

};

public:
RotationSets() : status(Status::Normal){};
RotationSets() = default;
~RotationSets() = default;

// Clear the member data, i.e., set the value back to an uninitialized
// state.
void clear() {
accessedIndices.clear();
status = Status::Normal;
status = Status::Uninitialized;
}

bool empty() const { return accessedIndices.empty(); }

bool isOverdetermined() const { return status == Status::TooManyTensors; }
bool isOverdetermined() const { return status == Status::Overdetermined; }

bool isUninitialized() const { return status == Status::Uninitialized; }

void addRotation(int64_t index) { accessedIndices.insert(index); }

@@ -47,6 +63,8 @@ class RotationSets {
return accessedIndices;
}

Value getTensor() const { return tensor; }

void print(raw_ostream &os) const {
os << tensor << ": [";
for (auto index : accessedIndices) {
@@ -57,12 +75,17 @@

static RotationSets overdetermined() {
RotationSets sets;
sets.status = Status::TooManyTensors;
sets.status = Status::Overdetermined;
return sets;
}

static RotationSets from(Value tensor) {
RotationSets sets;
if (!tensor.getType().isa<RankedTensorType>()) {
sets.status = Status::Uninitialized;
return sets;
}

sets.status = Status::Normal;
sets.tensor = tensor;
if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
@@ -75,7 +98,7 @@
// where an IR repeatedly rotates by 1, to ensure that rotations accumulate
// like {1, 2, 3, ...} rather than {1, 1, 1, ...}
static RotationSets rotate(const RotationSets &lhs, const int64_t shift) {
if (lhs.status == Status::TooManyTensors) {
if (lhs.status == Status::Overdetermined) {
return overdetermined();
}

@@ -91,13 +114,15 @@
}

static RotationSets join(const RotationSets &lhs, const RotationSets &rhs) {
if (lhs.status == Status::TooManyTensors ||
rhs.status == Status::TooManyTensors) {
if (lhs.status == Status::Overdetermined ||
rhs.status == Status::Overdetermined) {
return overdetermined();
}

if (rhs.accessedIndices.empty()) return lhs;
if (lhs.accessedIndices.empty()) return rhs;
if (rhs.status == Status::Uninitialized || rhs.accessedIndices.empty())
return lhs;
if (lhs.status == Status::Uninitialized || lhs.accessedIndices.empty())
return rhs;

if (lhs.tensor != rhs.tensor) {
LLVM_DEBUG({
@@ -107,6 +132,10 @@
return overdetermined();
}

LLVM_DEBUG({
llvm::dbgs() << "Joining :" << lhs.tensor << " and " << rhs.tensor
<< "\n";
});
RotationSets merged;
merged.status = Status::Normal;
merged.tensor = lhs.tensor;
@@ -119,6 +148,30 @@
return merged;
}

// Assuming neither rotation set is overdetermined, compute the overlap in
// their access indices.
static RotationSets overlap(const RotationSets &lhs,
const RotationSets &rhs) {
if (lhs.status == Status::Uninitialized || lhs.empty()) {
return lhs;
}

if (rhs.status == Status::Uninitialized || rhs.empty()) {
return rhs;
}

RotationSets merged;
merged.status = Status::Normal;
merged.tensor = lhs.tensor;
for (auto index : lhs.accessedIndices) {
if (rhs.accessedIndices.count(index)) merged.addRotation(index);
}
for (auto index : rhs.accessedIndices) {
if (lhs.accessedIndices.count(index)) merged.addRotation(index);
}
return merged;
}

private:
/// The accessed indices of a single SSA value of tensor type.
Value tensor;
@@ -127,7 +180,7 @@
// intervals as values are added. Java/Guava has RangeSet, and boost has
// interval_set. Otherwise might want to roll our own.
std::unordered_set<int64_t> accessedIndices;
Status status;
Status status = Status::Uninitialized;
};

inline raw_ostream &operator<<(raw_ostream &os, const RotationSets &v) {
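A minimal standalone sketch of the index-set arithmetic the new lattice value performs, written against plain std::unordered_set rather than the MLIR types above. The modular wrap-around of the shift and the join-as-union behavior are illustrative assumptions here; the actual handling lives in RotationSets::rotate and RotationSets::join in the header diff.

#include <cstdint>
#include <iostream>
#include <unordered_set>

using IndexSet = std::unordered_set<int64_t>;

// rotate: every index already accessed is shifted by the new rotation amount,
// so repeated rotations by 1 accumulate as {1}, {2}, {3}, ... rather than
// {1}, {1}, {1}, ... (wrap-around modulo the tensor size is assumed).
IndexSet rotate(const IndexSet &in, int64_t shift, int64_t tensorSize) {
  IndexSet out;
  for (int64_t i : in) out.insert((i + shift) % tensorSize);
  return out;
}

// join: union of the two operands' access indices (both reading the same
// source tensor).
IndexSet join(const IndexSet &lhs, const IndexSet &rhs) {
  IndexSet out = lhs;
  out.insert(rhs.begin(), rhs.end());
  return out;
}

// overlap: indices accessed by both operands; a non-empty overlap means a
// reduction built from the two would double-count those rotations.
IndexSet overlap(const IndexSet &lhs, const IndexSet &rhs) {
  IndexSet out;
  for (int64_t i : lhs)
    if (rhs.count(i)) out.insert(i);
  return out;
}

int main() {
  IndexSet v = {0};                  // the source tensor itself
  IndexSet r1 = rotate(v, 1, 8);     // {1}
  IndexSet sum1 = join(v, r1);       // {0, 1}: v + rotate(v, 1)
  IndexSet r2 = rotate(sum1, 1, 8);  // {1, 2}
  std::cout << overlap(sum1, r2).size() << "\n";  // 1: index 1 is shared
  std::cout << join(sum1, r2).size() << "\n";     // 3: {0, 1, 2}
  return 0;
}

Joining grows the set toward covering every slot of the tensor, while a non-empty overlap is what makes the new tryReplaceRotations check below bail out.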
2 changes: 0 additions & 2 deletions lib/Analysis/RotationAnalysis/RotationAnalysis.cpp
@@ -9,8 +9,6 @@
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

#define DEBUG_TYPE "rotation-analysis"

namespace mlir {
namespace heir {
namespace rotation_analysis {
153 changes: 147 additions & 6 deletions lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
@@ -23,8 +23,6 @@
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

#define DEBUG_TYPE "rotate-and-reduce"

namespace mlir {
namespace heir {
namespace tensor_ext {
@@ -39,8 +37,121 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
using RotateAndReduceBase::RotateAndReduceBase;

template <typename ArithOp>
void tryReplace(ArithOp op, DenseSet<Operation *> &visited) {
LLVM_DEBUG(llvm::dbgs() << "Trying to replace " << *op << "\n");
void tryReplaceRotations(ArithOp op, Value tensor,
DenseSet<Operation *> &visited,
DataFlowSolver &solver) {
// The dataflow analysis provides some guarantees, but not enough
// to prove that we can replace the op with the rotate-and-reduce trick
// while still maintaining program correctness.
//
// We need to do some more complicated checks to ensure that the op tree
// contains a single op type (all sums or all muls), and that each
// accessed rotation is included in the reduction only once.
// This cannot be done during the dataflow analysis itself due to the
// monotonicity requirements of the framework.
LLVM_DEBUG(llvm::dbgs()
<< "Trying to replace rotations ending in " << *op << "\n");
SetVector<Operation *> backwardSlice;
BackwardSliceOptions options;
// asserts that the parent op has a single region with a single block.
options.omitBlockArguments = false;

DenseSet<Operation *> visitedReductionOps;
DenseMap<llvm::StringRef, int> opCounts;
opCounts[op->getName().getStringRef()]++;

getBackwardSlice(op.getOperation(), &backwardSlice, options);

for (Operation *upstreamOpPtr : backwardSlice) {
auto result =
llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
.Case<arith::ConstantOp, tensor_ext::RotateOp>(
[&](auto upstreamOp) { return success(); })
// Ignore generic ops
.template Case<secret::GenericOp>(
[&](auto upstreamOp) { return success(); })
.template Case<arith::AddIOp, arith::MulIOp>([&](auto
upstreamOp) {
opCounts[upstreamOp->getName().getStringRef()]++;
// More than one kind of reduction op is mixed into the reduction.
if (opCounts.size() > 1) {
LLVM_DEBUG(llvm::dbgs()
<< "Not replacing op because reduction "
"contains multiple incompatible ops "
<< op->getName() << " and "
<< upstreamOp->getName() << "\n");
return failure();
}

// Inspect the lattice values at the join point,
// and fail if there is any overlap
auto *lhsLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
upstreamOp.getLhs());
auto *rhsLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
upstreamOp.getRhs());
LLVM_DEBUG(llvm::dbgs()
<< "Computing overlap of "
<< "lhs: " << lhsLattice->getValue() << "\n"
<< "rhs: " << rhsLattice->getValue() << "\n");
auto mergedLattice = rotation_analysis::RotationSets::overlap(
lhsLattice->getValue(), rhsLattice->getValue());
LLVM_DEBUG(llvm::dbgs()
<< "Overlap is: " << mergedLattice << "\n");
if (!mergedLattice.empty()) {
LLVM_DEBUG(
llvm::dbgs()
<< "Not replacing op because reduction "
"may not be a simple reduction of the input tensor\n"
<< "lhs: " << lhsLattice->getValue() << "\n"
<< "rhs: " << rhsLattice->getValue() << "\n");
return failure();
}

visitedReductionOps.insert(upstreamOp);
return success();
})
.Default([&](Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "Not continuing because type switch "
"encountered unsupported op "
<< op->getName() << "\n");
return failure();
});

if (failed(result)) {
return;
}
}

// From here we know we will succeed.
auto b = ImplicitLocOpBuilder(op->getLoc(), op);
Operation *finalOp;
auto tensorShape = tensor.getType().cast<RankedTensorType>().getShape();
for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0;
shiftSize /= 2) {
auto rotatedTensor = b.create<tensor_ext::RotateOp>(
tensor, b.create<arith::ConstantOp>(b.getIndexAttr(shiftSize)));
auto addOp = b.create<ArithOp>(tensor, rotatedTensor);
finalOp = addOp;
tensor = addOp->getResult(0);
}

[[maybe_unused]] auto *parentOp = op->getParentOp();
op->replaceAllUsesWith(finalOp);
LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");

// Mark all ops in the reduction as visited so we don't try to replace them
// twice.
for (Operation *visitedOp : visitedReductionOps) {
visited.insert(visitedOp);
}
}
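The replacement loop above is the standard log-depth rotate-and-reduce trick: starting at shift = n/2 and halving each round, every slot ends up holding the reduction of all n entries after log2(n) rotate-then-combine steps. A small simulation on a plain std::vector (assuming the tensor length is a power of two, and picking a left rotation for concreteness; the direction does not affect the totals) illustrates why:

#include <cstdint>
#include <iostream>
#include <vector>

// Cyclic left rotation: out[i] = v[(i + shift) mod n].
std::vector<int64_t> rotateLeft(const std::vector<int64_t> &v, int64_t shift) {
  std::vector<int64_t> out(v.size());
  for (size_t i = 0; i < v.size(); ++i)
    out[i] = v[(i + shift) % v.size()];
  return out;
}

int main() {
  std::vector<int64_t> t = {1, 2, 3, 4, 5, 6, 7, 8};
  // Mirrors the rewrite: for (shift = n/2; shift > 0; shift /= 2)
  //   t = t + rotate(t, shift)
  for (int64_t shift = static_cast<int64_t>(t.size()) / 2; shift > 0;
       shift /= 2) {
    std::vector<int64_t> rotated = rotateLeft(t, shift);
    for (size_t i = 0; i < t.size(); ++i) t[i] += rotated[i];
  }
  std::cout << t[0] << "\n";  // 36: the sum of all eight entries, in 3 rounds
  return 0;
}

For the multiplicative case the same structure applies with *= in place of +=, which is why the pass handles arith::AddIOp and arith::MulIOp symmetrically.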

template <typename ArithOp>
void tryReplaceExtractions(ArithOp op, DenseSet<Operation *> &visited) {
LLVM_DEBUG(llvm::dbgs()
<< "Trying to replace extractions ending in " << *op << "\n");
SetVector<Operation *> backwardSlice;
BackwardSliceOptions options;
// asserts that the parent op has a single region with a single block.
Expand Down Expand Up @@ -232,6 +343,36 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
});

DenseSet<Operation *> visited;

getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
[&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
op->getResult(0));
if (targetSlotLattice->getValue().isUninitialized() ||
targetSlotLattice->getValue().isOverdetermined()) {
return;
}

auto tensor = targetSlotLattice->getValue().getTensor();
auto accessIndices =
targetSlotLattice->getValue().getAccessedIndices();
int64_t tensorSize =
tensor.getType().cast<RankedTensorType>().getShape()[0];
if (accessIndices.size() == tensorSize) {
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::AddIOp>([&](auto arithOp) {
tryReplaceRotations<arith::AddIOp>(arithOp, tensor, visited,
solver);
})
.Case<arith::MulIOp>([&](auto arithOp) {
tryReplaceRotations<arith::MulIOp>(arithOp, tensor, visited,
solver);
});
}
});
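The guard in the walk above fires the rewrite only when the lattice for the op's result records as many distinct rotation indices as the tensor has slots. A hypothetical helper expressing the same check (the name is illustrative, not part of the pass):

#include <cstdint>
#include <unordered_set>

// True when a reduction has seen every slot of a tensor with `tensorSize`
// entries: assuming each accessed index lies in [0, tensorSize), a set of
// distinct indices whose size matches the tensor size must cover every slot.
bool coversAllSlots(const std::unordered_set<int64_t> &accessedIndices,
                    int64_t tensorSize) {
  return static_cast<int64_t>(accessedIndices.size()) == tensorSize;
}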

// Traverse the IR in reverse order so that we can eagerly compute backward
// slices for each operation.
getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
@@ -241,10 +382,10 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
}
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::AddIOp>([&](auto arithOp) {
tryReplace<arith::AddIOp>(arithOp, visited);
tryReplaceExtractions<arith::AddIOp>(arithOp, visited);
})
.Case<arith::MulIOp>([&](auto arithOp) {
tryReplace<arith::MulIOp>(arithOp, visited);
tryReplaceExtractions<arith::MulIOp>(arithOp, visited);
});
});
}
1 change: 1 addition & 0 deletions tests/simd/simple_sum.mlir
@@ -3,6 +3,7 @@
// RUN: --rotate-and-reduce --canonicalize \
// RUN: %s | FileCheck %s

// Sum all entries of a tensor into a single scalar
// CHECK-LABEL: @simple_sum
// CHECK: secret.generic
// CHECK-COUNT-5: tensor_ext.rotate