Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
54 changed files
with
2,175 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# TargetSlotAnalysis analysis pass | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files( | ||
["TargetSlotAnalysis.h"], | ||
) |
134 changes: 134 additions & 0 deletions
134
include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ | ||
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ | ||
|
||
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace target_slot_analysis { | ||
|
||
/// A target slot is an identification of a downstream tensor index at which an | ||
/// SSA value will be used. To make the previous sentence even mildly | ||
/// comprehensible, consider it in the following example. | ||
/// | ||
/// %c3 = arith.constant 3 : index | ||
/// %c4 = arith.constant 4 : index | ||
/// %c11 = arith.constant 11 : index | ||
/// %c15 = arith.constant 15 : index | ||
/// %v11 = tensor.extract %arg1[%c11] : tensor<16xi32> | ||
/// %v15 = tensor.extract %arg1[%c15] : tensor<16xi32> | ||
/// %1 = arith.addi %v11, %v15: i32 | ||
/// %v3 = tensor.extract %arg1[%c3] : tensor<16xi32> | ||
/// %2 = arith.addi %v3, %1 : i32 | ||
/// %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32> | ||
/// | ||
/// In vectorized FHE schemes like BGV, the computation model does not | ||
/// efficiently support extracting values at particular indices; instead, it | ||
/// supports SIMD additions of entire vectors, and cyclic rotations of vectors | ||
/// by constant shifts. To optimize the above computation, we want to convert | ||
/// the extractions to rotations, and minimize rotations as much as possible. | ||
/// | ||
/// A naive conversion convert tensor.extract %arg1[Z] to arith.rotate %arg1, | ||
/// Z, always placing the needed values in the zero-th slot. However, the last | ||
/// line above indicates that the downstream dependencies of these computations | ||
/// are ultimately needed in slot 4 of the %output tensor. So one could reduce | ||
/// the number of rotations by rotating instead to slot 4, so that the final | ||
/// rotation is not needed. | ||
/// | ||
/// This analysis identifies that downstream insertion index, and propagates it | ||
/// backward through the IR to attach it to each SSA value, enabling later | ||
/// optimization passes to access it easily. | ||
/// | ||
/// As it turns out, if the IR is well-structured, such as an unrolled affine | ||
/// for loop with simple iteration strides, then aligning to target slots in | ||
/// this way leads to many common sub-expressions that can be eliminated. Cf. | ||
/// the insert-rotate pass for more on that. | ||
|
||
class TargetSlot { | ||
public: | ||
TargetSlot() : value(std::nullopt) {} | ||
TargetSlot(int64_t value) : value(value) {} | ||
~TargetSlot() = default; | ||
|
||
/// Whether the slot target is initialized. It can be uninitialized when the | ||
/// state hasn't been set during the analysis. | ||
bool isInitialized() const { return value.has_value(); } | ||
|
||
/// Get a known slot target. | ||
const int64_t &getValue() const { | ||
assert(isInitialized()); | ||
return *value; | ||
} | ||
|
||
bool operator==(const TargetSlot &rhs) const { return value == rhs.value; } | ||
|
||
/// Join two target slots. | ||
static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) { | ||
if (!lhs.isInitialized()) return rhs; | ||
if (!rhs.isInitialized()) return lhs; | ||
// If they are both initialized, use an arbitrary deterministic rule to | ||
// select one. A more sophisticated analysis could try to determine which | ||
// slot is more likely to lead to beneficial optimizations. | ||
return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue() | ||
: rhs.getValue()}; | ||
} | ||
|
||
void print(raw_ostream &os) const { os << value; } | ||
|
||
private: | ||
/// The target slot, if known. | ||
std::optional<int64_t> value; | ||
|
||
friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic, | ||
const TargetSlot &foo) { | ||
if (foo.isInitialized()) { | ||
return diagnostic << foo.getValue(); | ||
} | ||
return diagnostic << "uninitialized"; | ||
} | ||
}; | ||
|
||
inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) { | ||
v.print(os); | ||
return os; | ||
} | ||
|
||
class TargetSlotLattice : public dataflow::Lattice<TargetSlot> { | ||
public: | ||
using Lattice::Lattice; | ||
}; | ||
|
||
/// An analysis that identifies a target slot for an SSA value in a program. | ||
/// This is used by downstream passes to determine how to align rotations in | ||
/// vectorized FHE schemes. | ||
/// | ||
/// We use a backward dataflow analysis because the target slot propagates | ||
/// backward from its final use to the arithmetic operations at which rotations | ||
/// can be optimized. | ||
class TargetSlotAnalysis | ||
: public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> { | ||
public: | ||
explicit TargetSlotAnalysis(DataFlowSolver &solver, | ||
SymbolTableCollection &symbolTable) | ||
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {} | ||
~TargetSlotAnalysis() override = default; | ||
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; | ||
|
||
// Given the computed results of the operation, update its operand lattice | ||
// values. | ||
void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands, | ||
ArrayRef<const TargetSlotLattice *> results) override; | ||
|
||
void visitBranchOperand(OpOperand &operand) override{}; | ||
void visitCallOperand(OpOperand &operand) override{}; | ||
void setToExitState(TargetSlotLattice *lattice) override{}; | ||
}; | ||
|
||
} // namespace target_slot_analysis | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ | ||
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
#define GEN_PASS_DECL_ROTATEANDREDUCE | ||
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# ElementwiseToAffine tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files([ | ||
"ElementwiseToAffine.h", | ||
]) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=ElementwiseToAffine", | ||
], | ||
"ElementwiseToAffine.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"ElementwiseToAffinePasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "ElementwiseToAffine.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
18 changes: 18 additions & 0 deletions
18
include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ | ||
#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc" | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ |
18 changes: 18 additions & 0 deletions
18
include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ | ||
#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def ElementwiseToAffine : Pass<"convert-elementwise-to-affine"> { | ||
let summary = "This pass lowers ElementwiseMappable operations to Affine loops."; | ||
let description = [{ | ||
This pass lowers ElementwiseMappable operations over tensors | ||
to affine loop nests that instead apply the operation to the underlying scalar values. | ||
}]; | ||
let dependentDialects = [ | ||
"mlir::affine::AffineDialect", | ||
"mlir::tensor::TensorDialect" | ||
]; | ||
} | ||
|
||
#endif // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ |
Oops, something went wrong.