Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a more thorough target slot analysis #526

Merged
merged 1 commit into from Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,9 @@
# TargetSlotAnalysis analysis pass
package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

# Expose the header so that library targets under lib/ can reference it in
# their hdrs (cf. @heir//lib/Analysis/TargetSlotAnalysis).
exports_files(
    ["TargetSlotAnalysis.h"],
)
134 changes: 134 additions & 0 deletions include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
@@ -0,0 +1,134 @@
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_

#include <cassert>
#include <cstdint>
#include <optional>

#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h"    // from @llvm-project
#include "mlir/include/mlir/IR/Value.h"        // from @llvm-project

namespace mlir {
namespace heir {
namespace target_slot_analysis {

/// A target slot is an identification of a downstream tensor index at which an
/// SSA value will be used. To make the previous sentence even mildly
/// comprehensible, consider it in the following example.
///
/// %c3 = arith.constant 3 : index
/// %c4 = arith.constant 4 : index
/// %c11 = arith.constant 11 : index
/// %c15 = arith.constant 15 : index
/// %v11 = tensor.extract %arg1[%c11] : tensor<16xi32>
/// %v15 = tensor.extract %arg1[%c15] : tensor<16xi32>
/// %1 = arith.addi %v11, %v15: i32
/// %v3 = tensor.extract %arg1[%c3] : tensor<16xi32>
/// %2 = arith.addi %v3, %1 : i32
/// %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32>
///
/// In vectorized FHE schemes like BGV, the computation model does not
/// efficiently support extracting values at particular indices; instead, it
/// supports SIMD additions of entire vectors, and cyclic rotations of vectors
/// by constant shifts. To optimize the above computation, we want to convert
/// the extractions to rotations, and minimize rotations as much as possible.
///
/// A naive conversion converts tensor.extract %arg1[Z] to arith.rotate %arg1,
/// Z, always placing the needed values in the zero-th slot. However, the last
/// line above indicates that the downstream dependencies of these computations
/// are ultimately needed in slot 4 of the %output tensor. So one could reduce
/// the number of rotations by rotating instead to slot 4, so that the final
/// rotation is not needed.
///
/// This analysis identifies that downstream insertion index, and propagates it
/// backward through the IR to attach it to each SSA value, enabling later
/// optimization passes to access it easily.
///
/// As it turns out, if the IR is well-structured, such as an unrolled affine
/// for loop with simple iteration strides, then aligning to target slots in
/// this way leads to many common sub-expressions that can be eliminated. Cf.
/// the insert-rotate pass for more on that.

/// The lattice value tracked for each SSA value: the tensor index at which
/// the value is ultimately needed downstream, if that index can be
/// statically determined. Default-constructed instances are uninitialized
/// (no known slot).
class TargetSlot {
 public:
  TargetSlot() : value(std::nullopt) {}
  TargetSlot(int64_t value) : value(value) {}
  ~TargetSlot() = default;

  /// Whether the slot target is initialized. It can be uninitialized when the
  /// state hasn't been set during the analysis.
  bool isInitialized() const { return value.has_value(); }

  /// Get a known slot target. Asserts that the slot is initialized.
  const int64_t &getValue() const {
    assert(isInitialized());
    return *value;
  }

  bool operator==(const TargetSlot &rhs) const { return value == rhs.value; }

  /// Join two target slots. An uninitialized slot acts as the identity.
  static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) {
    if (!lhs.isInitialized()) return rhs;
    if (!rhs.isInitialized()) return lhs;
    // If they are both initialized, use an arbitrary deterministic rule to
    // select one. A more sophisticated analysis could try to determine which
    // slot is more likely to lead to beneficial optimizations.
    return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue()
                                                      : rhs.getValue()};
  }

  void print(raw_ostream &os) const { os << value; }

 private:
  /// The target slot, if known.
  std::optional<int64_t> value;

  // Allow TargetSlot values to be streamed into pass diagnostics.
  friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic,
                                      const TargetSlot &foo) {
    if (foo.isInitialized()) {
      return diagnostic << foo.getValue();
    }
    return diagnostic << "uninitialized";
  }
};

/// Streams a TargetSlot to a raw_ostream via TargetSlot::print.
inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) {
  v.print(os);
  return os;
}

/// The dataflow lattice that attaches a TargetSlot to each SSA value.
class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
 public:
  using Lattice::Lattice;
};

/// An analysis that identifies a target slot for an SSA value in a program.
/// This is used by downstream passes to determine how to align rotations in
/// vectorized FHE schemes.
///
/// We use a backward dataflow analysis because the target slot propagates
/// backward from its final use to the arithmetic operations at which rotations
/// can be optimized.
class TargetSlotAnalysis
    : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
 public:
  explicit TargetSlotAnalysis(DataFlowSolver &solver,
                              SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  ~TargetSlotAnalysis() override = default;
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  // Given the computed results of the operation, update its operand lattice
  // values.
  void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
                      ArrayRef<const TargetSlotLattice *> results) override;

  // These hooks exist for adding special propagation behavior at control-flow
  // entry and exit points, where the framework performs no propagation by
  // default (cf. the upstream LivenessAnalysis for a non-trivial example).
  // For now we assume the IR contains no control flow, so they are no-ops.
  void visitBranchOperand(OpOperand &operand) override {}
  void visitCallOperand(OpOperand &operand) override {}
  void setToExitState(TargetSlotLattice *lattice) override {}
};

} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
16 changes: 16 additions & 0 deletions include/Dialect/BUILD
Expand Up @@ -10,6 +10,7 @@ package(
exports_files(
[
"HEIRInterfaces.h",
"Utils.h",
],
)

Expand Down Expand Up @@ -43,3 +44,18 @@ gentbl_cc_library(
":td_files",
],
)

# Header-only utility library. The header belongs in hdrs (not srcs): hdrs
# declares the public interface dependents may include, while listing the same
# file in srcs as well is redundant for a cc_library.
cc_library(
    name = "Utils",
    hdrs = [
        "Utils.h",
    ],
    deps = [
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)
18 changes: 18 additions & 0 deletions lib/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,18 @@
# TargetSlotAnalysis analysis pass
package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

# Backward sparse dataflow analysis computing the downstream insertion slot
# for each SSA value; consumed by the insert-rotate pass.
cc_library(
    name = "TargetSlotAnalysis",
    srcs = ["TargetSlotAnalysis.cpp"],
    hdrs = ["@heir//include/Analysis/TargetSlotAnalysis:TargetSlotAnalysis.h"],
    deps = [
        "@heir//lib/Dialect:Utils",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TensorDialect",
    ],
)
55 changes: 55 additions & 0 deletions lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
@@ -0,0 +1,55 @@
#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"

#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

#define DEBUG_TYPE "target-slot-analysis"

namespace mlir {
namespace heir {
namespace target_slot_analysis {

// Transfer function of the backward analysis. A tensor.insert with a
// statically-known insertion index seeds the lattice of the inserted value
// with that index; every other op propagates each result's target slot back
// to all of its operands unchanged.
void TargetSlotAnalysis::visitOperation(
    Operation *op, ArrayRef<TargetSlotLattice *> operands,
    ArrayRef<const TargetSlotLattice *> results) {
  llvm::TypeSwitch<Operation &>(*op)
      .Case<tensor::InsertOp>([&](auto insertOp) {
        LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
        auto indexResult = get1DExtractionIndex<tensor::InsertOp>(insertOp);
        // A slot that cannot be statically determined gives us nothing to
        // propagate through the IR.
        if (failed(indexResult)) return;

        // The inserted scalar is the op's first positional argument; the
        // discovered slot attaches to it.
        TargetSlotLattice *scalarLattice = operands[0];
        TargetSlot discoveredSlot{indexResult.value()};
        LLVM_DEBUG({
          llvm::dbgs() << "Joining " << scalarLattice->getValue() << " and "
                       << discoveredSlot << " --> "
                       << TargetSlot::join(scalarLattice->getValue(),
                                           discoveredSlot)
                       << "\n";
        });
        propagateIfChanged(scalarLattice, scalarLattice->join(discoveredSlot));
      })
      .Default([&](Operation &op) {
        // By default, an op propagates its result target slots to all its
        // operands.
        for (const TargetSlotLattice *resultLattice : results) {
          for (TargetSlotLattice *operandLattice : operands) {
            propagateIfChanged(operandLattice,
                               operandLattice->join(*resultLattice));
          }
        }
      });
}

} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir
12 changes: 12 additions & 0 deletions lib/Dialect/BUILD
Expand Up @@ -18,3 +18,15 @@ cc_library(
"@llvm-project//mlir:IR",
],
)

# Header-only utility library. The header must be declared in hdrs so that
# dependent targets (e.g. @heir//lib/Analysis/TargetSlotAnalysis, which
# includes "lib/Dialect/Utils.h") may include it; files listed only in srcs
# are private to this target and trip Bazel's layering check in dependents.
cc_library(
    name = "Utils",
    hdrs = [
        "Utils.h",
    ],
    deps = [
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)
2 changes: 2 additions & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
Expand Up @@ -28,6 +28,7 @@ cc_library(
"@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen",
"@heir//include/Dialect/TensorExt/Transforms:insert_rotate_inc_gen",
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Analysis/TargetSlotAnalysis",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
Expand All @@ -45,6 +46,7 @@ cc_library(
],
deps = [
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Dialect:Utils",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down
20 changes: 1 addition & 19 deletions lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp
Expand Up @@ -4,6 +4,7 @@
#include <utility>

#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
Expand All @@ -27,25 +28,6 @@ namespace tensor_ext {
#define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

/// Returns the single constant index of a 1-D tensor access op (e.g.
/// tensor.insert or tensor.extract), or failure when the op does not have
/// exactly one index or that index is not a statically-known constant.
template <typename Op>
FailureOr<int64_t> get1DExtractionIndex(Op op) {
  auto indices = op.getIndices();
  if (indices.size() != 1) return failure();

  // Each index must be constant; this may require running --canonicalize or
  // -sccp before this pass to apply folding rules (use -sccp if you need to
  // fold constants through control flow).
  Value index = *indices.begin();
  auto indexConstOp = index.getDefiningOp<arith::ConstantIndexOp>();
  if (!indexConstOp) return failure();

  auto offsetAttr = llvm::dyn_cast<IntegerAttr>(indexConstOp.getValue());
  if (!offsetAttr) return failure();

  return offsetAttr.getInt();
}

/// A pattern that searches for sequences of extract + insert, where the
/// indices extracted and inserted have the same offset, and replaced them with
/// a single rotate operation.
Expand Down
42 changes: 37 additions & 5 deletions lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
Expand Up @@ -2,14 +2,21 @@

#include <utility>

#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "insert-rotate"

namespace mlir {
namespace heir {
namespace tensor_ext {
Expand All @@ -33,6 +40,31 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

SymbolTableCollection symbolTable;
DataFlowSolver solver;
// These two upstream analyses are required dependencies for any sparse
// dataflow analysis, or else the analysis will be a no-op. Cf.
// https://github.com/llvm/llvm-project/issues/58922
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable);
if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
}

LLVM_DEBUG({
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<target_slot_analysis::TargetSlotLattice>(
op->getResult(0));
llvm::dbgs() << "Target slot for op " << *op << ": "
<< targetSlotLattice->getValue() << "\n";
asraa marked this conversation as resolved.
Show resolved Hide resolved
});
});

alignment::populateWithGenerated(patterns);
canonicalization::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
Expand Down