// InsertRotate.cpp
//
// Implementation of the tensor_ext InsertRotate pass.
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h"
#include <utility>
#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#define DEBUG_TYPE "insert-rotate"
namespace mlir {
namespace heir {
namespace tensor_ext {
#define GEN_PASS_DEF_INSERTROTATE
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"
namespace alignment {
// In an inner namespace to avoid conflicts with canonicalization patterns
#include "include/Dialect/TensorExt/Transforms/InsertRotate.cpp.inc"
} // namespace alignment
namespace canonicalization {
#include "include/Dialect/TensorExt/IR/TensorExtCanonicalization.cpp.inc"
} // namespace canonicalization
/// The InsertRotate pass.
///
/// Runs the TargetSlotAnalysis dataflow analysis over the pass's root op and
/// records each analyzed op's target slot as a "target_slot" index attribute.
/// It then greedily applies the DRR-generated patterns from the `alignment`
/// namespace together with the TensorExt canonicalization patterns. Those
/// patterns can match on the recorded attribute.
struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
  using InsertRotateBase::InsertRotateBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    SymbolTableCollection symbolTable;
    DataFlowSolver solver;

    // These two upstream analyses are required dependencies for any sparse
    // dataflow analysis, or else the analysis will be a no-op. Cf.
    // https://github.com/llvm/llvm-project/issues/58922
    solver.load<dataflow::DeadCodeAnalysis>();
    solver.load<dataflow::SparseConstantPropagation>();
    solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable);
    if (failed(solver.initializeAndRun(getOperation()))) {
      getOperation()->emitOpError() << "Failed to run the analysis.\n";
      signalPassFailure();
      return;
    }

    // Annotate every op that has at least one result with the target slot of
    // its first result, so that it can be matched in the DRR rules.
    // NOTE(review): these "target_slot" attributes are not removed by this
    // pass; confirm downstream consumers tolerate or strip them.
    OpBuilder builder(context);
    getOperation()->walk([&](Operation *op) {
      if (op->getNumResults() == 0) return;
      const auto *targetSlotLattice =
          solver.lookupState<target_slot_analysis::TargetSlotLattice>(
              op->getResult(0));
      // lookupState returns nullptr when the solver holds no state for this
      // value. Skip such ops instead of dereferencing a null pointer, and
      // likewise skip values whose lattice was never initialized.
      if (!targetSlotLattice || !targetSlotLattice->getValue().isInitialized())
        return;
      op->setAttr(
          "target_slot",
          builder.getIndexAttr(targetSlotLattice->getValue().getValue()));
    });

    alignment::populateWithGenerated(patterns);
    canonicalization::populateWithGenerated(patterns);
    // The result is intentionally discarded: failure of the greedy driver
    // (e.g. not converging) is not treated as a pass failure here.
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
} // namespace tensor_ext
} // namespace heir
} // namespace mlir