-
Notifications
You must be signed in to change notification settings - Fork 25
/
TargetSlotAnalysis.cpp
55 lines (49 loc) · 2.23 KB
/
TargetSlotAnalysis.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#define DEBUG_TYPE "target-slot-analysis"
namespace mlir {
namespace heir {
namespace target_slot_analysis {

/// Transfer function of the target-slot analysis for a single operation.
///
/// `operands` are the mutable lattice states of `op`'s operands and `results`
/// the read-only lattice states of its results; information flows from the
/// results into the operands.
///
/// - `tensor.insert`: if `get1DExtractionIndex` can statically resolve the
///   insertion index, that index is joined into the lattice of the inserted
///   value (operand 0). If it cannot, nothing is propagated for this op.
/// - Any other op: every result's target slot is joined into every operand.
///
/// Each lattice that actually changes is re-enqueued via `propagateIfChanged`.
void TargetSlotAnalysis::visitOperation(
    Operation *op, ArrayRef<TargetSlotLattice *> operands,
    ArrayRef<const TargetSlotLattice *> results) {
  // Trace every visited op (not only tensor.insert) so that the default
  // propagation path is also visible under -debug-only=target-slot-analysis.
  LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });

  llvm::TypeSwitch<Operation &>(*op)
      .Case<tensor::InsertOp>([&](tensor::InsertOp insertOp) {
        auto insertIndexRes = get1DExtractionIndex<tensor::InsertOp>(insertOp);
        // If the target slot can't be statically determined, we can't
        // propagate anything through the IR.
        if (failed(insertIndexRes)) return;

        // The target slot propagates to the value inserted, which is the first
        // positional operand of tensor.insert.
        TargetSlotLattice *lattice = operands[0];
        TargetSlot newSlot{insertIndexRes.value()};
        LLVM_DEBUG({
          llvm::dbgs() << "Joining " << lattice->getValue() << " and "
                       << newSlot << " --> "
                       << TargetSlot::join(lattice->getValue(), newSlot)
                       << "\n";
        });
        ChangeResult changed = lattice->join(newSlot);
        propagateIfChanged(lattice, changed);
      })
      // The operation argument is intentionally unnamed: it is unused here and
      // naming it `op` would shadow the function parameter of the same name.
      .Default([&](Operation &) {
        // By default, an op propagates its result target slots to all its
        // operands.
        for (const TargetSlotLattice *r : results) {
          for (TargetSlotLattice *operand : operands) {
            ChangeResult result = operand->join(*r);
            propagateIfChanged(operand, result);
          }
        }
      });
}

}  // namespace target_slot_analysis
}  // namespace heir
}  // namespace mlir