{# ConversionPass.cpp.jinja: template for the C++ source file of an MLIR dialect conversion pass. #}
#include "include/Conversion/{{ pass_name }}/{{ pass_name }}.h"
#include "include/Dialect/{{ source_dialect_name }}/IR/{{ source_dialect_name }}Dialect.h"
#include "include/Dialect/{{ target_dialect_name }}/IR/{{ target_dialect_name }}Dialect.h"
#include "lib/Conversion/Utils.h"
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
namespace mlir::heir {
#define GEN_PASS_DEF_{{ pass_name | upper }}
#include "include/Conversion/{{ pass_name }}/{{ pass_name }}.h.inc"
// Type converter used by this pass. Delete the whole class when the lowering
// does not need to convert any types.
class {{ pass_name }}TypeConverter : public TypeConverter {
 public:
  {{ pass_name }}TypeConverter(MLIRContext *ctx) {
    // Identity rule: hands back whatever type it is given, unchanged.
    addConversion([](Type passthrough) -> Type { return passthrough; });

    // FIXME: implement. Substitute FooType with the type that has to be
    // converted, or delete this class. As written the rule is a placeholder
    // that returns its input untouched; `ctx` is captured so the real
    // implementation has the context at hand.
    addConversion([ctx](FooType fooType) -> Type {
      return fooType;
    });
  }
};
// FIXME: rename to Convert<OpName>Op
//
// Conversion pattern for FooOp. All constructors are inherited from
// OpConversionPattern, so the pattern can be created from a context alone or
// from a type converter plus a context (the form the pass below uses). No
// hand-written constructor is needed: one taking only a context would merely
// duplicate the inherited constructor with the same signature.
struct ConvertFooOp : public OpConversionPattern<FooOp> {
  using OpConversionPattern::OpConversionPattern;

  // Placeholder rewrite: reports failure for every FooOp until implemented.
  //   op       - the operation being matched.
  //   adaptor  - adaptor for `op` supplied by the conversion framework.
  //   rewriter - rewriter through which all IR changes must be made.
  LogicalResult matchAndRewrite(
      FooOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // FIXME: implement
    return failure();
  }
};
// Pass implementation: runs a partial conversion that marks the target dialect
// legal and the source dialect illegal.
struct {{ pass_name }} : public impl::{{ pass_name }}Base<{{ pass_name }}> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    auto *root = getOperation();

    {{ pass_name }}TypeConverter converter(ctx);
    RewritePatternSet patternSet(ctx);
    ConversionTarget conversionTarget(*ctx);

    // Legality: everything in the target dialect is allowed to remain,
    // everything in the source dialect is marked illegal.
    conversionTarget.addLegalDialect<{{ target_dialect_namespace }}::{{ target_dialect_name }}Dialect>();
    conversionTarget.addIllegalDialect<{{ source_dialect_namespace }}::{{ source_dialect_name }}Dialect>();

    // Op-specific pattern first, then the shared structural patterns from
    // lib/Conversion/Utils.h (which also receives the conversion target).
    patternSet.add<ConvertFooOp>(converter, ctx);
    addStructuralConversionPatterns(converter, patternSet, conversionTarget);

    const LogicalResult status =
        applyPartialConversion(root, conversionTarget, std::move(patternSet));
    if (failed(status)) {
      signalPassFailure();
    }
  }
};
} // namespace mlir::heir