Skip to content

Commit

Permalink
Add target slot analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 15, 2024
1 parent 1fccbc9 commit 989363f
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 24 deletions.
9 changes: 9 additions & 0 deletions include/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,9 @@
# TargetSlotAnalysis analysis pass
package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

# Expose the header so the implementation target under lib/ can list it in
# its hdrs (see lib/Analysis/TargetSlotAnalysis/BUILD).
exports_files(
    ["TargetSlotAnalysis.h"],
)
134 changes: 134 additions & 0 deletions include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
@@ -0,0 +1,134 @@
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_

#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace target_slot_analysis {

/// A target slot is an identification of a downstream tensor index at which an
/// SSA value will be used. To make the previous sentence even mildly
/// comprehensible, consider it in the following example.
///
/// %c3 = arith.constant 3 : index
/// %c4 = arith.constant 4 : index
/// %c11 = arith.constant 11 : index
/// %c15 = arith.constant 15 : index
/// %v11 = tensor.extract %arg1[%c11] : tensor<16xi32>
/// %v15 = tensor.extract %arg1[%c15] : tensor<16xi32>
/// %1 = arith.addi %v11, %v15 : i32
/// %v3 = tensor.extract %arg1[%c3] : tensor<16xi32>
/// %2 = arith.addi %v3, %1 : i32
/// %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32>
///
/// In vectorized FHE schemes like BGV, the computation model does not
/// efficiently support extracting values at particular indices; instead, it
/// supports SIMD additions of entire vectors, and cyclic rotations of vectors
/// by constant shifts. To optimize the above computation, we want to convert
/// the extractions to rotations, and minimize rotations as much as possible.
///
/// A naive conversion converts tensor.extract %arg1[Z] to arith.rotate %arg1,
/// Z, always placing the needed values in the zero-th slot. However, the last
/// line above indicates that the downstream dependencies of these computations
/// are ultimately needed in slot 4 of the %output tensor. So one could reduce
/// the number of rotations by rotating instead to slot 4, so that the final
/// rotation is not needed.
///
/// This analysis identifies that downstream insertion index, and propagates it
/// backward through the IR to attach it to each SSA value, enabling later
/// optimization passes to access it easily.
///
/// As it turns out, if the IR is well-structured, such as an unrolled affine
/// for loop with simple iteration strides, then aligning to target slots in
/// this way leads to many common sub-expressions that can be eliminated. Cf.
/// the insert-rotate pass for more on that.

class TargetSlot {
public:
TargetSlot() : value(std::nullopt) {}
TargetSlot(int64_t value) : value(value) {}
~TargetSlot() = default;

/// Whether the slot target is initialized. It can be uninitialized when the
/// state hasn't been set during the analysis.
bool isInitialized() const { return value.has_value(); }

/// Get a known slot target.
const int64_t &getValue() const {
assert(isInitialized());
return *value;
}

bool operator==(const TargetSlot &rhs) const { return value == rhs.value; }

/// Join two target slots.
static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) {
if (!lhs.isInitialized()) return rhs;
if (!rhs.isInitialized()) return lhs;
// If they are both initialized, use an arbitrary deterministic rule to
// select one. A more sophisticated analysis could try to determine which
// slot is more likely to lead to beneficial optimizations.
return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue()
: rhs.getValue()};
}

void print(raw_ostream &os) const { os << value; }

private:
/// The target slot, if known.
std::optional<int64_t> value;

friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic,
const TargetSlot &foo) {
if (foo.isInitialized()) {
return diagnostic << foo.getValue();
}
return diagnostic << "uninitialized";
}
};

/// Stream a TargetSlot to an LLVM raw_ostream (e.g., for LLVM_DEBUG output);
/// delegates to TargetSlot::print.
inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) {
  v.print(os);
  return os;
}

/// The dataflow lattice holding a TargetSlot state for each SSA value;
/// inherits all behavior (join, print, equality) from the generic Lattice.
class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
 public:
  using Lattice::Lattice;
};

/// An analysis that identifies a target slot for an SSA value in a program.
/// This is used by downstream passes to determine how to align rotations in
/// vectorized FHE schemes.
///
/// We use a backward dataflow analysis because the target slot propagates
/// backward from its final use to the arithmetic operations at which rotations
/// can be optimized.
class TargetSlotAnalysis
    : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
 public:
  explicit TargetSlotAnalysis(DataFlowSolver &solver,
                              SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  ~TargetSlotAnalysis() override = default;

  // Given the computed results of the operation, update its operand lattice
  // values.
  void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
                      ArrayRef<const TargetSlotLattice *> results) override;

  // Branch/call operands and exit states impose no target-slot constraint,
  // so these required hooks are intentionally no-ops.
  void visitBranchOperand(OpOperand &operand) override {}
  void visitCallOperand(OpOperand &operand) override {}
  void setToExitState(TargetSlotLattice *lattice) override {}
};

} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
16 changes: 16 additions & 0 deletions include/Dialect/BUILD
Expand Up @@ -10,6 +10,7 @@ package(
exports_files(
[
"HEIRInterfaces.h",
"Utils.h",
],
)

Expand Down Expand Up @@ -43,3 +44,18 @@ gentbl_cc_library(
":td_files",
],
)

# Header-only utility library. The header belongs in hdrs (exported to
# dependents); listing it in srcs as well is redundant, since srcs headers are
# private to the target.
cc_library(
    name = "Utils",
    hdrs = [
        "Utils.h",
    ],
    deps = [
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)
18 changes: 18 additions & 0 deletions lib/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,18 @@
# TargetSlotAnalysis analysis pass
package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

# Implementation of the target slot backward dataflow analysis; the public
# header lives under include/ and is re-exported here via hdrs.
cc_library(
    name = "TargetSlotAnalysis",
    srcs = ["TargetSlotAnalysis.cpp"],
    hdrs = ["@heir//include/Analysis/TargetSlotAnalysis:TargetSlotAnalysis.h"],
    deps = [
        "@heir//lib/Dialect:Utils",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TensorDialect",
    ],
)
55 changes: 55 additions & 0 deletions lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
@@ -0,0 +1,55 @@
#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"

#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

#define DEBUG_TYPE "target-slot-analysis"

namespace mlir {
namespace heir {
namespace target_slot_analysis {

// Propagate target-slot information backward from an op's results to its
// operands. tensor.insert ops are the *source* of target slots (the static
// insertion index is the slot where the inserted value is ultimately needed);
// every other op forwards the slots of its results to all of its operands.
void TargetSlotAnalysis::visitOperation(
    Operation *op, ArrayRef<TargetSlotLattice *> operands,
    ArrayRef<const TargetSlotLattice *> results) {
  llvm::TypeSwitch<Operation &>(*op)
      .Case<tensor::InsertOp>([&](auto insertOp) {
        LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
        auto insertIndexRes = get1DExtractionIndex<tensor::InsertOp>(insertOp);
        // If the target slot can't be statically determined, we can't
        // propagate anything through the IR.
        if (failed(insertIndexRes)) return;

        // The target slot propagates to the value inserted, which is the first
        // positional argument
        TargetSlotLattice *lattice = operands[0];
        TargetSlot newSlot = TargetSlot{insertIndexRes.value()};
        LLVM_DEBUG({
          llvm::dbgs() << "Joining " << lattice->getValue() << " and "
                       << newSlot << " --> "
                       << TargetSlot::join(lattice->getValue(), newSlot)
                       << "\n";
        });
        // Only notify the solver when the join actually changed the state, so
        // the fixed-point iteration terminates.
        ChangeResult changed = lattice->join(newSlot);
        propagateIfChanged(lattice, changed);
      })
      .Default([&](Operation &op) {
        // By default, an op propagates its result target slots to all its
        // operands.
        for (const TargetSlotLattice *r : results) {
          for (TargetSlotLattice *operand : operands) {
            ChangeResult result = operand->join(*r);
            propagateIfChanged(operand, result);
          }
        }
      });
}

} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir
12 changes: 12 additions & 0 deletions lib/Dialect/BUILD
Expand Up @@ -18,3 +18,15 @@ cc_library(
"@llvm-project//mlir:IR",
],
)

# Header-only utility library. Utils.h must be in hdrs, not srcs: headers in
# srcs are private to the target, but dependents such as
# lib/Analysis/TargetSlotAnalysis include "lib/Dialect/Utils.h" directly.
cc_library(
    name = "Utils",
    hdrs = [
        "Utils.h",
    ],
    deps = [
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)
2 changes: 2 additions & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
Expand Up @@ -28,6 +28,7 @@ cc_library(
"@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen",
"@heir//include/Dialect/TensorExt/Transforms:insert_rotate_inc_gen",
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Analysis/TargetSlotAnalysis",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
Expand All @@ -45,6 +46,7 @@ cc_library(
],
deps = [
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Dialect:Utils",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down
20 changes: 1 addition & 19 deletions lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp
Expand Up @@ -4,6 +4,7 @@
#include <utility>

#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
Expand All @@ -27,25 +28,6 @@ namespace tensor_ext {
#define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

template <typename Op>
FailureOr<int64_t> get1DExtractionIndex(Op op) {
auto insertIndices = op.getIndices();
if (insertIndices.size() != 1) return failure();

// Each index must be constant; this may require running --canonicalize or
// -sccp before this pass to apply folding rules (use -sccp if you need to
// fold constants through control flow).
Value insertIndex = *insertIndices.begin();
auto insertIndexConstOp = insertIndex.getDefiningOp<arith::ConstantIndexOp>();
if (!insertIndexConstOp) return failure();

auto insertOffsetAttr =
llvm::dyn_cast<IntegerAttr>(insertIndexConstOp.getValue());
if (!insertOffsetAttr) return failure();

return insertOffsetAttr.getInt();
}

/// A pattern that searches for sequences of extract + insert, where the
/// indices extracted and inserted have the same offset, and replaced them with
/// a single rotate operation.
Expand Down
42 changes: 37 additions & 5 deletions lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
Expand Up @@ -2,14 +2,21 @@

#include <utility>

#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "insert-rotate"

namespace mlir {
namespace heir {
namespace tensor_ext {
Expand All @@ -33,6 +40,31 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

SymbolTableCollection symbolTable;
DataFlowSolver solver;
// These two upstream analyses are required dependencies for any sparse
// dataflow analysis, or else the analysis will be a no-op. Cf.
// https://github.com/llvm/llvm-project/issues/58922
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable);
if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
}

LLVM_DEBUG({
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<target_slot_analysis::TargetSlotLattice>(
op->getResult(0));
llvm::dbgs() << "Target slot for op " << *op << ": "
<< targetSlotLattice->getValue() << "\n";
});
});

alignment::populateWithGenerated(patterns);
canonicalization::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
Expand Down

0 comments on commit 989363f

Please sign in to comment.