Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
217 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# TargetSlotAnalysis analysis pass | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files( | ||
["TargetSlotAnalysis.h"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ | ||
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ | ||
|
||
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace target_slot_analysis { | ||
|
||
class TargetSlot { | ||
public: | ||
TargetSlot() : value(std::nullopt) {} | ||
TargetSlot(int64_t value) : value(value) {} | ||
~TargetSlot() = default; | ||
|
||
/// Whether the slot target is initialized. It can be uninitialized when the | ||
/// state hasn't been set during the analysis. | ||
bool isInitialized() const { return value.has_value(); } | ||
|
||
/// Get a known slot target. | ||
const int64_t &getValue() const { | ||
assert(isInitialized()); | ||
return *value; | ||
} | ||
|
||
bool operator==(const TargetSlot &rhs) const { return value == rhs.value; } | ||
|
||
/// Join two target slots. | ||
static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) { | ||
if (!lhs.isInitialized()) return rhs; | ||
if (!rhs.isInitialized()) return lhs; | ||
// If they are both initialized, use an arbitrary deterministic rule to | ||
// select one. A more sophisticated analysis could try to determine which | ||
// slot is more likely to lead to beneficial optimizations. | ||
return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue() | ||
: rhs.getValue()}; | ||
} | ||
|
||
void print(raw_ostream &os) const { os << value; } | ||
|
||
private: | ||
/// The target slot, if known. | ||
std::optional<int64_t> value; | ||
|
||
friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic, | ||
const TargetSlot &foo) { | ||
if (foo.isInitialized()) { | ||
return diagnostic << foo.getValue(); | ||
} | ||
return diagnostic << "uninitialized"; | ||
} | ||
}; | ||
|
||
inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) { | ||
v.print(os); | ||
return os; | ||
} | ||
|
||
class TargetSlotLattice : public dataflow::Lattice<TargetSlot> { | ||
public: | ||
using Lattice::Lattice; | ||
}; | ||
|
||
/// An analysis that identifies a target slot for an SSA value in a program. | ||
/// This is used by downstream passes to determine how to align rotations in | ||
/// vectorized FHE schemes. | ||
/// | ||
/// We use a backward dataflow analysis because the target slot propagates | ||
/// backward from its final use to the arithmetic operations at which rotations | ||
/// can be optimized. | ||
class TargetSlotAnalysis | ||
: public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> { | ||
public: | ||
explicit TargetSlotAnalysis(DataFlowSolver &solver, | ||
SymbolTableCollection &symbolTable) | ||
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {} | ||
~TargetSlotAnalysis() override = default; | ||
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; | ||
|
||
// Given the computed results of the operation, update its operand lattice | ||
// values. | ||
void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands, | ||
ArrayRef<const TargetSlotLattice *> results) override; | ||
|
||
void visitBranchOperand(OpOperand &operand) override{}; | ||
void visitCallOperand(OpOperand &operand) override{}; | ||
void setToExitState(TargetSlotLattice *lattice) override{}; | ||
}; | ||
|
||
} // namespace target_slot_analysis | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# TargetSlotAnalysis analysis pass | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "TargetSlotAnalysis", | ||
srcs = ["TargetSlotAnalysis.cpp"], | ||
hdrs = ["@heir//include/Analysis/TargetSlotAnalysis:TargetSlotAnalysis.h"], | ||
deps = [ | ||
"@heir//lib/Dialect:Utils", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:Analysis", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:TensorDialect", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h" | ||
|
||
#include "lib/Dialect/Utils.h" | ||
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "target-slot-analysis" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace target_slot_analysis { | ||
|
||
void TargetSlotAnalysis::visitOperation( | ||
Operation *op, ArrayRef<TargetSlotLattice *> operands, | ||
ArrayRef<const TargetSlotLattice *> results) { | ||
llvm::TypeSwitch<Operation &>(*op) | ||
.Case<tensor::InsertOp>([&](auto insertOp) { | ||
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; }); | ||
auto insertIndexRes = get1DExtractionIndex<tensor::InsertOp>(insertOp); | ||
// If the target slot can't be statically determined, we can't | ||
// propagate anything through the IR. | ||
if (failed(insertIndexRes)) return; | ||
|
||
// The target slot propagates to the value inserted, which is the first | ||
// positional argument | ||
TargetSlotLattice *lattice = operands[0]; | ||
TargetSlot newSlot = TargetSlot{insertIndexRes.value()}; | ||
LLVM_DEBUG({ | ||
llvm::dbgs() << "Joining " << lattice->getValue() << " and " | ||
<< newSlot << " --> " | ||
<< TargetSlot::join(lattice->getValue(), newSlot) | ||
<< "\n"; | ||
}); | ||
ChangeResult changed = lattice->join(newSlot); | ||
propagateIfChanged(lattice, changed); | ||
}) | ||
.Default([&](Operation &op) { | ||
// By default, an op propagates its result target slots to all its | ||
// operands. | ||
for (const TargetSlotLattice *r : results) { | ||
for (TargetSlotLattice *operand : operands) { | ||
ChangeResult result = operand->join(*r); | ||
propagateIfChanged(operand, result); | ||
} | ||
} | ||
}); | ||
} | ||
|
||
} // namespace target_slot_analysis | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters