diff --git a/include/Analysis/TargetSlotAnalysis/BUILD b/include/Analysis/TargetSlotAnalysis/BUILD new file mode 100644 index 000000000..085e170b2 --- /dev/null +++ b/include/Analysis/TargetSlotAnalysis/BUILD @@ -0,0 +1,9 @@ +# TargetSlotAnalysis analysis pass +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + ["TargetSlotAnalysis.h"], +) diff --git a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h new file mode 100644 index 000000000..984fb338c --- /dev/null +++ b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h @@ -0,0 +1,134 @@ +#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ +#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ + +#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace target_slot_analysis { + +/// A target slot is an identification of a downstream tensor index at which an +/// SSA value will be used. To make the previous sentence even mildly +/// comprehensible, consider it in the following example. 
+/// +/// %c3 = arith.constant 3 : index +/// %c4 = arith.constant 4 : index +/// %c11 = arith.constant 11 : index +/// %c15 = arith.constant 15 : index +/// %v11 = tensor.extract %arg1[%c11] : tensor<16xi32> +/// %v15 = tensor.extract %arg1[%c15] : tensor<16xi32> +/// %1 = arith.addi %v11, %v15: i32 +/// %v3 = tensor.extract %arg1[%c3] : tensor<16xi32> +/// %2 = arith.addi %v3, %1 : i32 +/// %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32> +/// +/// In vectorized FHE schemes like BGV, the computation model does not +/// efficiently support extracting values at particular indices; instead, it +/// supports SIMD additions of entire vectors, and cyclic rotations of vectors +/// by constant shifts. To optimize the above computation, we want to convert +/// the extractions to rotations, and minimize rotations as much as possible. +/// +/// A naive conversion converts tensor.extract %arg1[Z] to arith.rotate %arg1, +/// Z, always placing the needed values in the zero-th slot. However, the last +/// line above indicates that the downstream dependencies of these computations +/// are ultimately needed in slot 4 of the %output tensor. So one could reduce +/// the number of rotations by rotating instead to slot 4, so that the final +/// rotation is not needed. +/// +/// This analysis identifies that downstream insertion index, and propagates it +/// backward through the IR to attach it to each SSA value, enabling later +/// optimization passes to access it easily. +/// +/// As it turns out, if the IR is well-structured, such as an unrolled affine +/// for loop with simple iteration strides, then aligning to target slots in +/// this way leads to many common sub-expressions that can be eliminated. Cf. +/// the insert-rotate pass for more on that. + +class TargetSlot { + public: + TargetSlot() : value(std::nullopt) {} + TargetSlot(int64_t value) : value(value) {} + ~TargetSlot() = default; + + /// Whether the slot target is initialized. 
It can be uninitialized when the + /// state hasn't been set during the analysis. + bool isInitialized() const { return value.has_value(); } + + /// Get a known slot target. + const int64_t &getValue() const { + assert(isInitialized()); + return *value; + } + + bool operator==(const TargetSlot &rhs) const { return value == rhs.value; } + + /// Join two target slots. + static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) { + if (!lhs.isInitialized()) return rhs; + if (!rhs.isInitialized()) return lhs; + // If they are both initialized, use an arbitrary deterministic rule to + // select one. A more sophisticated analysis could try to determine which + // slot is more likely to lead to beneficial optimizations. + return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue() + : rhs.getValue()}; + } + + void print(raw_ostream &os) const { os << value; } + + private: + /// The target slot, if known. + std::optional<int64_t> value; + + friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic, + const TargetSlot &foo) { + if (foo.isInitialized()) { + return diagnostic << foo.getValue(); + } + return diagnostic << "uninitialized"; + } +}; + +inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) { + v.print(os); + return os; +} + +class TargetSlotLattice : public dataflow::Lattice<TargetSlot> { + public: + using Lattice::Lattice; +}; + +/// An analysis that identifies a target slot for an SSA value in a program. +/// This is used by downstream passes to determine how to align rotations in +/// vectorized FHE schemes. +/// +/// We use a backward dataflow analysis because the target slot propagates +/// backward from its final use to the arithmetic operations at which rotations +/// can be optimized. 
+class TargetSlotAnalysis + : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> { + public: + explicit TargetSlotAnalysis(DataFlowSolver &solver, + SymbolTableCollection &symbolTable) + : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} + ~TargetSlotAnalysis() override = default; + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + + // Given the computed results of the operation, update its operand lattice + // values. + void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands, + ArrayRef<const TargetSlotLattice *> results) override; + + void visitBranchOperand(OpOperand &operand) override{}; + void visitCallOperand(OpOperand &operand) override{}; + void setToExitState(TargetSlotLattice *lattice) override{}; +}; + +} // namespace target_slot_analysis +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ diff --git a/include/Dialect/BUILD b/include/Dialect/BUILD index bebb71ecf..3eb3ce231 100644 --- a/include/Dialect/BUILD +++ b/include/Dialect/BUILD @@ -10,6 +10,7 @@ package( exports_files( [ "HEIRInterfaces.h", + "Utils.h", ], ) @@ -43,3 +44,18 @@ gentbl_cc_library( ":td_files", ], ) + +cc_library( + name = "Utils", + srcs = [ + "Utils.h", + ], + hdrs = [ + "Utils.h", + ], + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Analysis/TargetSlotAnalysis/BUILD b/lib/Analysis/TargetSlotAnalysis/BUILD new file mode 100644 index 000000000..cf7438e35 --- /dev/null +++ b/lib/Analysis/TargetSlotAnalysis/BUILD @@ -0,0 +1,18 @@ +# TargetSlotAnalysis analysis pass +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "TargetSlotAnalysis", + srcs = ["TargetSlotAnalysis.cpp"], + hdrs = ["@heir//include/Analysis/TargetSlotAnalysis:TargetSlotAnalysis.h"], + deps = [ + "@heir//lib/Dialect:Utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + 
"@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorDialect", + ], +) diff --git a/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp new file mode 100644 index 000000000..cdb4b41eb --- /dev/null +++ b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp @@ -0,0 +1,55 @@ +#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h" + +#include "lib/Dialect/Utils.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project + +#define DEBUG_TYPE "target-slot-analysis" + +namespace mlir { +namespace heir { +namespace target_slot_analysis { + +void TargetSlotAnalysis::visitOperation( + Operation *op, ArrayRef<TargetSlotLattice *> operands, + ArrayRef<const TargetSlotLattice *> results) { + llvm::TypeSwitch<Operation &>(*op) + .Case<tensor::InsertOp>([&](auto insertOp) { + LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; }); + auto insertIndexRes = get1DExtractionIndex(insertOp); + // If the target slot can't be statically determined, we can't + // propagate anything through the IR. + if (failed(insertIndexRes)) return; + + // The target slot propagates to the value inserted, which is the first + // positional argument + TargetSlotLattice *lattice = operands[0]; + TargetSlot newSlot = TargetSlot{insertIndexRes.value()}; + LLVM_DEBUG({ + llvm::dbgs() << "Joining " << lattice->getValue() << " and " + << newSlot << " --> " + << TargetSlot::join(lattice->getValue(), newSlot) + << "\n"; + }); + ChangeResult changed = lattice->join(newSlot); + propagateIfChanged(lattice, changed); + }) + .Default([&](Operation &op) { + // By default, an op propagates its result target slots to all its + // operands. 
+ for (const TargetSlotLattice *r : results) { + for (TargetSlotLattice *operand : operands) { + ChangeResult result = operand->join(*r); + propagateIfChanged(operand, result); + } + } + }); +} + +} // namespace target_slot_analysis +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/BUILD b/lib/Dialect/BUILD index fb5356f75..9d031b48e 100644 --- a/lib/Dialect/BUILD +++ b/lib/Dialect/BUILD @@ -18,3 +18,15 @@ cc_library( "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "Utils", + srcs = [ + "Utils.h", + ], + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD index 4258b08ff..bfcc3fad2 100644 --- a/lib/Dialect/TensorExt/Transforms/BUILD +++ b/lib/Dialect/TensorExt/Transforms/BUILD @@ -28,6 +28,7 @@ cc_library( "@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen", "@heir//include/Dialect/TensorExt/Transforms:insert_rotate_inc_gen", "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Analysis/TargetSlotAnalysis", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", @@ -45,6 +46,7 @@ cc_library( ], deps = [ "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect:Utils", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp index 63c2a9983..f64337780 100644 --- a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp +++ b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp @@ -4,6 +4,7 @@ #include #include "include/Dialect/TensorExt/IR/TensorExtOps.h" +#include "lib/Dialect/Utils.h" #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include 
"llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -27,25 +28,6 @@ namespace tensor_ext { #define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS #include "include/Dialect/TensorExt/Transforms/Passes.h.inc" -template <typename Op> -FailureOr<int64_t> get1DExtractionIndex(Op op) { - auto insertIndices = op.getIndices(); - if (insertIndices.size() != 1) return failure(); - - // Each index must be constant; this may require running --canonicalize or - // -sccp before this pass to apply folding rules (use -sccp if you need to - // fold constants through control flow). - Value insertIndex = *insertIndices.begin(); - auto insertIndexConstOp = insertIndex.getDefiningOp<arith::ConstantOp>(); - if (!insertIndexConstOp) return failure(); - - auto insertOffsetAttr = - llvm::dyn_cast<IntegerAttr>(insertIndexConstOp.getValue()); - if (!insertOffsetAttr) return failure(); - - return insertOffsetAttr.getInt(); -} - /// A pattern that searches for sequences of extract + insert, where the /// indices extracted and inserted have the same offset, and replaced them with /// a single rotate operation. 
diff --git a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp index c829ba34d..2641792dd 100644 --- a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp +++ b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp @@ -2,14 +2,21 @@ #include +#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h" #include "include/Dialect/TensorExt/IR/TensorExtOps.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#define DEBUG_TYPE "insert-rotate" + namespace mlir { namespace heir { namespace tensor_ext { @@ -33,6 +40,31 @@ struct InsertRotate : impl::InsertRotateBase { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); + SymbolTableCollection symbolTable; + DataFlowSolver solver; + // These two upstream analyses are required dependencies for any sparse + // dataflow analysis, or else the analysis will be a no-op. 
Cf. + // https://github.com/llvm/llvm-project/issues/58922 + solver.load<dataflow::DeadCodeAnalysis>(); + solver.load<dataflow::SparseConstantPropagation>(); + solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable); + if (failed(solver.initializeAndRun(getOperation()))) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } + + LLVM_DEBUG({ + getOperation()->walk([&](Operation *op) { + if (op->getNumResults() == 0) return; + auto *targetSlotLattice = + solver.lookupState<target_slot_analysis::TargetSlotLattice>( + op->getResult(0)); + llvm::dbgs() << "Target slot for op " << *op << ": " + << targetSlotLattice->getValue() << "\n"; + }); + }); + alignment::populateWithGenerated(patterns); canonicalization::populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/lib/Dialect/Utils.h b/lib/Dialect/Utils.h new file mode 100644 index 000000000..4c25a5e04 --- /dev/null +++ b/lib/Dialect/Utils.h @@ -0,0 +1,37 @@ +#ifndef INCLUDE_DIALECT_UTILS_H_ +#define INCLUDE_DIALECT_UTILS_H_ + +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace heir { + +/// Given a tensor::InsertOp or tensor::ExtractOp, and assuming the shape +/// of the input tensor is 1-dimensional and the input index is constant, +/// return the constant index value. If any of these conditions are not +/// met, return a failure. +template <typename Op> +FailureOr<int64_t> get1DExtractionIndex(Op op) { + auto insertIndices = op.getIndices(); + if (insertIndices.size() != 1) return failure(); + + // Each index must be constant; this may require running --canonicalize or + // -sccp before this pass to apply folding rules (use -sccp if you need to + // fold constants through control flow). 
+ Value insertIndex = *insertIndices.begin(); + auto insertIndexConstOp = insertIndex.getDefiningOp<arith::ConstantOp>(); + if (!insertIndexConstOp) return failure(); + + auto insertOffsetAttr = + llvm::dyn_cast<IntegerAttr>(insertIndexConstOp.getValue()); + if (!insertOffsetAttr) return failure(); + + return insertOffsetAttr.getInt(); +} + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_UTILS_H_