Skip to content

Commit

Permalink
try adding tensor ext patterns for inserting rotations
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 4, 2024
1 parent 5a0525d commit f881d16
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 0 deletions.
19 changes: 19 additions & 0 deletions include/Dialect/TensorExt/IR/BUILD
Expand Up @@ -11,6 +11,7 @@ exports_files(
[
"TensorExtDialect.h",
"TensorExtOps.h",
"TensorExtPatterns.h",
],
)

Expand Down Expand Up @@ -77,3 +78,21 @@ gentbl_cc_library(
"@heir//include/Dialect/Polynomial/IR:td_files",
],
)

gentbl_cc_library(
name = "patterns_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"TensorExtPatterns.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorExtPatterns.td",
deps = [
":ops_inc_gen",
":td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:TensorOpsTdFiles",
],
)
38 changes: 38 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtPatterns.td
@@ -0,0 +1,38 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_

include "TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

def GetZeroAttr : NativeCodeCall<"$_builder.getIndexAttr(0)">;

def InsertRotations : Pattern<
(Arith_AddIOp
(Tensor_ExtractOp $t1, (variadic $i1)),
(Tensor_ExtractOp $t2, (variadic $i2)),
$overflow),
[
(Arith_ConstantOp:$zero (GetZeroAttr)),
(TensorExt_RotateOp:$r1 $t1, $i1),
(TensorExt_RotateOp:$r2 $t2, $i2),
(Arith_AddIOp:$addResult $r1, $r2, $overflow),
(Tensor_ExtractOp $addResult, $zero),
]
>;

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_


// (Tensor_ExtractOp
// (Arith_AddIOp
// (Tensor_ExtractOp $t1, (variadic $i1)),
// (Tensor_ExtractOp $t2, (variadic $i2))),
// (variadic $targetSlot))
// [
// (TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $targetSlot)),
// (TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $targetSlot)),
// (Arith_AddIOp:$addResult $r1, $r2),
// (Tensor_ExtractOp $addResult, $targetSlot),
// ]
36 changes: 36 additions & 0 deletions include/Dialect/TensorExt/Transforms/BUILD
@@ -0,0 +1,36 @@
# InsertRotate tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=TensorExt",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"TensorExtPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

exports_files([
"Passes.h",
"InsertRotate.h",
])
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/Transforms/InsertRotate.h
@@ -0,0 +1,17 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_DECL_INSERTROTATE
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_
18 changes: 18 additions & 0 deletions include/Dialect/TensorExt/Transforms/Passes.h
@@ -0,0 +1,18 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_

#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h"

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_REGISTRATION
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_
14 changes: 14 additions & 0 deletions include/Dialect/TensorExt/Transforms/Passes.td
@@ -0,0 +1,14 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_

include "mlir/Pass/PassBase.td"

def InsertRotate : Pass<"insert-rotate"> {
// FIXME: add add summary/description
let summary = "";
let description = [{
}];
let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"];
}

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_
35 changes: 35 additions & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
@@ -0,0 +1,35 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = [
"@heir//include/Dialect/TensorExt/Transforms:Passes.h",
],
deps = [
":InsertRotate",
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "InsertRotate",
srcs = ["InsertRotate.cpp"],
hdrs = [
"@heir//include/Dialect/TensorExt/Transforms:InsertRotate.h",
],
deps = [
"@heir//include/Dialect/TensorExt/IR:patterns_inc_gen",
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
33 changes: 33 additions & 0 deletions lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
@@ -0,0 +1,33 @@
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h"

#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

// Required at the end to absorb includes above
#include "include/Dialect/TensorExt/IR/TensorExtPatterns.cpp.inc"

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_DEF_INSERTROTATE
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
using InsertRotateBase::InsertRotateBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

patterns.add<InsertRotations>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace tensor_ext
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions tests/tensor_ext/insert_rotations.mlir
@@ -0,0 +1,17 @@
// RUN: heir-opt --insert-rotate %s | FileCheck %s

func.func @test_insert_rotation_for_add(%arg1: tensor<16xi32>) -> tensor<16xi32> {
%c4 = arith.constant 4 : index
%c11 = arith.constant 11 : index
%c15 = arith.constant 15 : index

%extracted = tensor.extract %arg1[%c11] : tensor<16xi32>
%extracted_0 = tensor.extract %arg1[%c15] : tensor<16xi32>
%1 = arith.addi %extracted, %extracted_0 : i32

%extracted_1 = tensor.extract %arg1[%c4] : tensor<16xi32>
%2 = arith.addi %1, %extracted_1 : i32

%inserted = tensor.insert %2 into %arg1[%c4] : tensor<16xi32>
return %inserted : tensor<16xi32>
}
1 change: 1 addition & 0 deletions tools/BUILD
Expand Up @@ -53,6 +53,7 @@ cc_binary(
"@heir//lib/Dialect/Secret/Transforms",
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Dialect/TensorExt/Transforms",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
"@heir//lib/Transforms/ForwardStoreToLoad",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Expand Up @@ -20,6 +20,7 @@
#include "include/Dialect/Secret/Transforms/DistributeGeneric.h"
#include "include/Dialect/Secret/Transforms/Passes.h"
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "include/Dialect/TensorExt/Transforms/Passes.h"
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"
#include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
Expand Down Expand Up @@ -291,6 +292,7 @@ int main(int argc, char **argv) {
cggi::registerCGGIPasses();
lwe::registerLWEPasses();
secret::registerSecretPasses();
tensor_ext::registerTensorExtPasses();
registerSecretizePasses();
registerFullLoopUnrollPasses();
registerForwardStoreToLoadPasses();
Expand Down

0 comments on commit f881d16

Please sign in to comment.