/
TargetSlotAnalysis.h
97 lines (80 loc) · 3.44 KB
/
TargetSlotAnalysis.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
namespace mlir {
namespace heir {
namespace target_slot_analysis {
/// The value tracked per SSA value by the target slot analysis: either a known
/// target slot, or "uninitialized" when no slot has been recorded.
///
/// The class follows the Rule of Zero: its only member is a
/// std::optional<int64_t>, so the compiler-generated copy, move and destructor
/// are all correct and none of them is declared here.
class TargetSlot {
 public:
  /// Creates an uninitialized target slot.
  TargetSlot() : value(std::nullopt) {}

  /// Creates a target slot holding the known slot `value`.
  /// Left implicit so that call sites passing a bare integer keep compiling.
  TargetSlot(int64_t value) : value(value) {}

  /// Whether the slot target is initialized. It can be uninitialized when the
  /// state hasn't been set during the analysis.
  bool isInitialized() const { return value.has_value(); }

  /// Get a known slot target. Must only be called when isInitialized() is
  /// true; this precondition is checked by an assertion.
  const int64_t &getValue() const {
    assert(isInitialized() &&
           "getValue() called on an uninitialized TargetSlot");
    return *value;
  }

  /// Two target slots are equal when both are uninitialized or both hold the
  /// same slot.
  bool operator==(const TargetSlot &rhs) const { return value == rhs.value; }

  /// Join two target slots.
  ///
  /// If exactly one side is initialized, that side is returned. If neither is
  /// initialized, the result is uninitialized.
  static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) {
    if (!lhs.isInitialized()) return rhs;
    if (!rhs.isInitialized()) return lhs;
    // If they are both initialized, use an arbitrary deterministic rule to
    // select one: the smaller slot wins. A more sophisticated analysis could
    // try to determine which slot is more likely to lead to beneficial
    // optimizations.
    return TargetSlot{std::min(lhs.getValue(), rhs.getValue())};
  }

  /// Streams the underlying optional value to `os`.
  void print(raw_ostream &os) const { os << value; }

  /// Streams `slot` into an MLIR diagnostic: the slot itself when it is
  /// initialized, and the literal text "uninitialized" otherwise.
  friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic,
                                      const TargetSlot &slot) {
    if (slot.isInitialized()) {
      return diagnostic << slot.getValue();
    }
    return diagnostic << "uninitialized";
  }

 private:
  /// The target slot, if known.
  std::optional<int64_t> value;
};
/// Stream insertion for TargetSlot: delegates to TargetSlot::print and returns
/// the same stream so that insertions can be chained.
inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &slot) {
  slot.print(os);
  return os;
}
/// Lattice element whose value type is TargetSlot. It adds no members of its
/// own; it only re-exports the constructors of dataflow::Lattice<TargetSlot>.
class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
 public:
  using dataflow::Lattice<TargetSlot>::Lattice;
};
/// Dataflow analysis that determines a target slot for SSA values in a
/// program. Downstream passes consume the result to decide how to align
/// rotations in vectorized FHE schemes.
///
/// The analysis is a backward one: a target slot originates at a value's
/// final use and flows back toward the arithmetic operations at which
/// rotations can be optimized.
class TargetSlotAnalysis
    : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
 public:
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  /// Builds the analysis, forwarding `solver` and `symbolTable` to the base
  /// class unchanged.
  explicit TargetSlotAnalysis(DataFlowSolver &solver,
                              SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}

  ~TargetSlotAnalysis() override = default;

  /// Given the computed results of the operation, update its operand lattice
  /// values. Only declared here; the definition is not part of this header.
  void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
                      ArrayRef<const TargetSlotLattice *> results) override;

  /// No-op: branch operands are not processed by this analysis.
  void visitBranchOperand(OpOperand &operand) override {}

  /// No-op: call operands are not processed by this analysis.
  void visitCallOperand(OpOperand &operand) override {}

  /// No-op: the lattice passed in is left untouched.
  void setToExitState(TargetSlotLattice *lattice) override {}
};
} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir
#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_