/
RotationAnalysis.cpp
75 lines (65 loc) · 3.11 KB
/
RotationAnalysis.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include "include/Analysis/RotationAnalysis/RotationAnalysis.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
namespace mlir {
namespace heir {
namespace rotation_analysis {
// Transfer function of the rotation analysis: updates the lattices in
// `results` from the lattices in `operands` for the operation `op`.
//
//  * tensor_ext::RotateOp whose shift is produced by an arith::ConstantOp
//    holding an IntegerAttr: the rotation sets of the tensor operand (or the
//    sets built from the block argument, when the tensor is one) are rotated
//    by the shift and joined into every result lattice.
//  * tensor_ext::RotateOp with any other kind of shift: nothing is
//    propagated.
//  * Every other op: each operand lattice is joined into every result
//    lattice, followed by a join with RotationSets::from(operand value).
//
// Whenever a join reports a change, propagateIfChanged is called for the
// affected result lattice.
void RotationAnalysis::visitOperation(
    Operation *op, ArrayRef<const RotationLattice *> operands,
    ArrayRef<RotationLattice *> results) {
  llvm::TypeSwitch<Operation &>(*op)
      .Case<tensor_ext::RotateOp>([&](auto rotateOp) {
        LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });

        auto shiftConstantOp =
            rotateOp.getShift().template getDefiningOp<arith::ConstantOp>();
        // If the rotation shift can't be statically determined, we can't
        // propagate anything through the IR.
        if (!shiftConstantOp) return;

        // dyn_cast yields a null attribute when the constant does not hold an
        // IntegerAttr. Treat that like a non-static shift instead of calling
        // getInt() on a null attribute.
        auto shiftAttr = dyn_cast<IntegerAttr>(shiftConstantOp.getValue());
        if (!shiftAttr) return;
        int64_t shiftValue = shiftAttr.getInt();

        // The target slot propagates from the tensor argument to the result;
        // the tensor argument is first in the tablegen definition.
        const RotationLattice *lattice = operands[0];
        RotationSets latticeRotations = lattice->getValue();

        // If it's a block argument, then there is no initialized lattice value
        // and we can override it with a "zero rotation"
        auto blockArg = dyn_cast<BlockArgument>(rotateOp.getTensor());
        if (blockArg) {
          latticeRotations = RotationSets::from(blockArg);
        }

        RotationSets rotated =
            RotationSets::rotate(latticeRotations, shiftValue);
        for (RotationLattice *r : results) {
          ChangeResult result = r->join(rotated);
          propagateIfChanged(r, result);
        }
      })
      .Default([&](Operation &genericOp) {
        // By default, the lattice of every operand is joined into the lattice
        // of every result.
        for (OpOperand &operand : genericOp.getOpOperands()) {
          auto *latticeOperand = operands[operand.getOperandNumber()];
          for (RotationLattice *r : results) {
            ChangeResult result = r->join(*latticeOperand);
            // If the operand is a block arg, this additionally treats this as
            // a zero rotation. If the underlying tensor differs across
            // operands, this will also cause a Status::TooManyTensors.
            // Otherwise, the join is a no-op.
            result |= r->join(RotationSets::from(operand.get()));
            propagateIfChanged(r, result);
          }
        }
      });
}
// Resets `lattice` for use as an entry state: fetches the value held by the
// lattice and calls clear() on it.
void RotationAnalysis::setToEntryState(RotationLattice *lattice) {
  auto &&heldRotations = lattice->getValue();
  heldRotations.clear();
}
} // namespace rotation_analysis
} // namespace heir
} // namespace mlir