-
Notifications
You must be signed in to change notification settings - Fork 25
/
TargetSlotAnalysis.h
134 lines (116 loc) · 5.39 KB
/
TargetSlotAnalysis.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
namespace mlir {
namespace heir {
namespace target_slot_analysis {
/// A target slot is an identification of a downstream tensor index at which an
/// SSA value will be used. To make the previous sentence even mildly
/// comprehensible, consider it in the following example.
///
/// %c3 = arith.constant 3 : index
/// %c4 = arith.constant 4 : index
/// %c11 = arith.constant 11 : index
/// %c15 = arith.constant 15 : index
/// %v11 = tensor.extract %arg1[%c11] : tensor<16xi32>
/// %v15 = tensor.extract %arg1[%c15] : tensor<16xi32>
/// %1 = arith.addi %v11, %v15: i32
/// %v3 = tensor.extract %arg1[%c3] : tensor<16xi32>
/// %2 = arith.addi %v3, %1 : i32
/// %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32>
///
/// In vectorized FHE schemes like BGV, the computation model does not
/// efficiently support extracting values at particular indices; instead, it
/// supports SIMD additions of entire vectors, and cyclic rotations of vectors
/// by constant shifts. To optimize the above computation, we want to convert
/// the extractions to rotations, and minimize rotations as much as possible.
///
/// A naive conversion converts tensor.extract %arg1[Z] to arith.rotate %arg1,
/// Z, always placing the needed values in the zero-th slot. However, the last
/// line above indicates that the downstream dependencies of these computations
/// are ultimately needed in slot 4 of the %output tensor. So one could reduce
/// the number of rotations by rotating instead to slot 4, so that the final
/// rotation is not needed.
///
/// This analysis identifies that downstream insertion index, and propagates it
/// backward through the IR to attach it to each SSA value, enabling later
/// optimization passes to access it easily.
///
/// As it turns out, if the IR is well-structured, such as an unrolled affine
/// for loop with simple iteration strides, then aligning to target slots in
/// this way leads to many common sub-expressions that can be eliminated. Cf.
/// the insert-rotate pass for more on that.
class TargetSlot {
public:
TargetSlot() : value(std::nullopt) {}
TargetSlot(int64_t value) : value(value) {}
~TargetSlot() = default;
/// Whether the slot target is initialized. It can be uninitialized when the
/// state hasn't been set during the analysis.
bool isInitialized() const { return value.has_value(); }
/// Get a known slot target.
const int64_t &getValue() const {
assert(isInitialized());
return *value;
}
bool operator==(const TargetSlot &rhs) const { return value == rhs.value; }
/// Join two target slots.
static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) {
if (!lhs.isInitialized()) return rhs;
if (!rhs.isInitialized()) return lhs;
// If they are both initialized, use an arbitrary deterministic rule to
// select one. A more sophisticated analysis could try to determine which
// slot is more likely to lead to beneficial optimizations.
return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue()
: rhs.getValue()};
}
void print(raw_ostream &os) const { os << value; }
private:
/// The target slot, if known.
std::optional<int64_t> value;
friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic,
const TargetSlot &foo) {
if (foo.isInitialized()) {
return diagnostic << foo.getValue();
}
return diagnostic << "uninitialized";
}
};
/// Stream a TargetSlot onto an LLVM raw_ostream by delegating to
/// TargetSlot::print.
inline raw_ostream &operator<<(raw_ostream &stream, const TargetSlot &slot) {
  slot.print(stream);
  return stream;
}
/// The dataflow lattice that attaches a TargetSlot to each SSA value.
/// All lattice machinery (join, equality, change tracking) is inherited from
/// dataflow::Lattice, which in turn delegates the join to TargetSlot::join.
class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
 public:
  using Lattice::Lattice;
};
/// An analysis that identifies a target slot for an SSA value in a program.
/// This is used by downstream passes to determine how to align rotations in
/// vectorized FHE schemes.
///
/// We use a backward dataflow analysis because the target slot propagates
/// backward from its final use to the arithmetic operations at which rotations
/// can be optimized.
class TargetSlotAnalysis
    : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
 public:
  explicit TargetSlotAnalysis(DataFlowSolver &solver,
                              SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  ~TargetSlotAnalysis() override = default;

  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  // Given the computed results of the operation, update its operand lattice
  // values. Defined out-of-line (in the corresponding .cpp); this is where
  // the backward propagation of target slots happens.
  void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
                      ArrayRef<const TargetSlotLattice *> results) override;

  // The hooks below are intentionally no-ops: this analysis does not
  // propagate target slots across branch or call boundaries, and it has no
  // meaningful exit state to set.
  void visitBranchOperand(OpOperand &operand) override{};
  void visitCallOperand(OpOperand &operand) override{};
  void setToExitState(TargetSlotLattice *lattice) override{};
};
} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir
#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_