Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
try adding tensor ext patterns for inserting rotations
- Loading branch information
Showing
11 changed files
with
230 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ | ||
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ | ||
|
||
include "TensorExtOps.td" | ||
include "mlir/Dialect/Arith/IR/ArithOps.td" | ||
include "mlir/Dialect/Tensor/IR/TensorOps.td" | ||
include "mlir/IR/PatternBase.td" | ||
|
||
def GetZeroAttr : NativeCodeCall<"$_builder.getIndexAttr(0)">; | ||
|
||
def InsertRotations : Pattern< | ||
(Arith_AddIOp | ||
(Tensor_ExtractOp $t1, (variadic $i1)), | ||
(Tensor_ExtractOp $t2, (variadic $i2)), | ||
$overflow), | ||
[ | ||
(Arith_ConstantOp:$zero (GetZeroAttr)), | ||
(TensorExt_RotateOp:$r1 $t1, $i1), | ||
(TensorExt_RotateOp:$r2 $t2, $i2), | ||
(Arith_AddIOp:$addResult $r1, $r2, $overflow), | ||
(Tensor_ExtractOp $addResult, $zero), | ||
] | ||
>; | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ | ||
|
||
|
||
// (Tensor_ExtractOp | ||
// (Arith_AddIOp | ||
// (Tensor_ExtractOp $t1, (variadic $i1)), | ||
// (Tensor_ExtractOp $t2, (variadic $i2))), | ||
// (variadic $targetSlot)) | ||
// [ | ||
// (TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $targetSlot)), | ||
// (TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $targetSlot)), | ||
// (Arith_AddIOp:$addResult $r1, $r2), | ||
// (Tensor_ExtractOp $addResult, $targetSlot), | ||
// ] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# InsertRotate tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=TensorExt", | ||
], | ||
"Passes.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"TensorExtPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "Passes.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) | ||
|
||
exports_files([ | ||
"Passes.h", | ||
"InsertRotate.h", | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ | ||
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
#define GEN_PASS_DECL_INSERTROTATE | ||
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ | ||
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ | ||
|
||
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" | ||
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ | ||
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def InsertRotate : Pass<"insert-rotate"> { | ||
// FIXME: add add summary/description | ||
let summary = ""; | ||
let description = [{ | ||
}]; | ||
let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"]; | ||
} | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "Transforms", | ||
hdrs = [ | ||
"@heir//include/Dialect/TensorExt/Transforms:Passes.h", | ||
], | ||
deps = [ | ||
":InsertRotate", | ||
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", | ||
"@heir//lib/Dialect/TensorExt/IR:Dialect", | ||
"@llvm-project//mlir:IR", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "InsertRotate", | ||
srcs = ["InsertRotate.cpp"], | ||
hdrs = [ | ||
"@heir//include/Dialect/TensorExt/Transforms:InsertRotate.h", | ||
], | ||
deps = [ | ||
"@heir//include/Dialect/TensorExt/IR:patterns_inc_gen", | ||
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", | ||
"@heir//lib/Dialect/TensorExt/IR:Dialect", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:TensorDialect", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h" | ||
|
||
#include "include/Dialect/TensorExt/IR/TensorExtOps.h" | ||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
|
||
// Required at the end to absorb includes above | ||
#include "include/Dialect/TensorExt/IR/TensorExtPatterns.cpp.inc" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
#define GEN_PASS_DEF_INSERTROTATE | ||
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" | ||
|
||
struct InsertRotate : impl::InsertRotateBase<InsertRotate> { | ||
using InsertRotateBase::InsertRotateBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
|
||
patterns.add<InsertRotations>(context); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// RUN: heir-opt --insert-rotate %s | FileCheck %s | ||
|
||
func.func @test_insert_rotation_for_add(%arg1: tensor<16xi32>) -> tensor<16xi32> { | ||
%c4 = arith.constant 4 : index | ||
%c11 = arith.constant 11 : index | ||
%c15 = arith.constant 15 : index | ||
|
||
%extracted = tensor.extract %arg1[%c11] : tensor<16xi32> | ||
%extracted_0 = tensor.extract %arg1[%c15] : tensor<16xi32> | ||
%1 = arith.addi %extracted, %extracted_0 : i32 | ||
|
||
%extracted_1 = tensor.extract %arg1[%c4] : tensor<16xi32> | ||
%2 = arith.addi %1, %extracted_1 : i32 | ||
|
||
%inserted = tensor.insert %2 into %arg1[%c4] : tensor<16xi32> | ||
return %inserted : tensor<16xi32> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters