From 9574c3983b039e105e8eee25221211f0f528e9ad Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 4 Mar 2024 09:29:09 -0800 Subject: [PATCH] add TensorExt dialect with rotate op --- include/Dialect/TensorExt/IR/BUILD | 79 +++++++++++++++++++ .../Dialect/TensorExt/IR/TensorExtDialect.h | 10 +++ .../Dialect/TensorExt/IR/TensorExtDialect.td | 17 ++++ include/Dialect/TensorExt/IR/TensorExtOps.h | 11 +++ include/Dialect/TensorExt/IR/TensorExtOps.td | 40 ++++++++++ lib/Dialect/TensorExt/IR/BUILD | 41 ++++++++++ lib/Dialect/TensorExt/IR/TensorExtDialect.cpp | 27 +++++++ lib/Dialect/TensorExt/IR/TensorExtOps.cpp | 7 ++ tests/tensor_ext/BUILD | 10 +++ tests/tensor_ext/ops.mlir | 9 +++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 12 files changed, 254 insertions(+) create mode 100644 include/Dialect/TensorExt/IR/BUILD create mode 100644 include/Dialect/TensorExt/IR/TensorExtDialect.h create mode 100644 include/Dialect/TensorExt/IR/TensorExtDialect.td create mode 100644 include/Dialect/TensorExt/IR/TensorExtOps.h create mode 100644 include/Dialect/TensorExt/IR/TensorExtOps.td create mode 100644 lib/Dialect/TensorExt/IR/BUILD create mode 100644 lib/Dialect/TensorExt/IR/TensorExtDialect.cpp create mode 100644 lib/Dialect/TensorExt/IR/TensorExtOps.cpp create mode 100644 tests/tensor_ext/BUILD create mode 100644 tests/tensor_ext/ops.mlir diff --git a/include/Dialect/TensorExt/IR/BUILD b/include/Dialect/TensorExt/IR/BUILD new file mode 100644 index 000000000..3f546387e --- /dev/null +++ b/include/Dialect/TensorExt/IR/BUILD @@ -0,0 +1,79 @@ +# TensorExt tablegen and headers + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + [ + "TensorExtDialect.h", + "TensorExtOps.h", + ], +) + +td_library( + name = "td_files", + srcs = [ + "TensorExtDialect.td", + "TensorExtOps.td", + ], + # include from the heir-root to enable fully-qualified include-paths + includes = ["../../../.."], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "dialect_inc_gen", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + ], + "TensorExtDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + ], + "TensorExtDialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TensorExtDialect.td", + deps = [ + ":td_files", + ], +) + +gentbl_cc_library( + name = "ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "TensorExtOps.h.inc", + ), + ( + ["-gen-op-defs"], + "TensorExtOps.cpp.inc", + ), + ( + ["-gen-op-doc"], + "TensorExtOps.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TensorExtOps.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + "@heir//include/Dialect/Polynomial/IR:td_files", + ], +) diff --git a/include/Dialect/TensorExt/IR/TensorExtDialect.h b/include/Dialect/TensorExt/IR/TensorExtDialect.h new file mode 100644 index 000000000..a71742389 --- /dev/null +++ b/include/Dialect/TensorExt/IR/TensorExtDialect.h @@ -0,0 +1,10 @@ +#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_ +#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_ + +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project + +// Generated headers (block clang-format from messing up order) +#include "include/Dialect/TensorExt/IR/TensorExtDialect.h.inc" + +#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_ diff --git a/include/Dialect/TensorExt/IR/TensorExtDialect.td b/include/Dialect/TensorExt/IR/TensorExtDialect.td new file mode 100644 index 000000000..218d00d6e --- /dev/null +++ b/include/Dialect/TensorExt/IR/TensorExtDialect.td @@ -0,0 +1,17 @@ +#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ +#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ + +include "mlir/IR/DialectBase.td" + +def TensorExt_Dialect : Dialect { + let name = "tensor_ext"; + let description = [{ + The `tensor_ext` dialect contains operations on plaintext tensors that + correspond to the computation model of certain FHE schemes, but are + unlikely to be upstreamed to MLIR due to their specificity to FHE. + }]; + + let cppNamespace = "::mlir::heir::tensor_ext"; +} + +#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ diff --git a/include/Dialect/TensorExt/IR/TensorExtOps.h b/include/Dialect/TensorExt/IR/TensorExtOps.h new file mode 100644 index 000000000..026a97b6c --- /dev/null +++ b/include/Dialect/TensorExt/IR/TensorExtOps.h @@ -0,0 +1,11 @@ +#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ +#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ + +#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project + +#define GET_OP_CLASSES +#include "include/Dialect/TensorExt/IR/TensorExtOps.h.inc" + +#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ diff --git a/include/Dialect/TensorExt/IR/TensorExtOps.td b/include/Dialect/TensorExt/IR/TensorExtOps.td new file mode 100644 index 000000000..eb56fc9aa --- /dev/null +++ b/include/Dialect/TensorExt/IR/TensorExtOps.td @@ -0,0 +1,40 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_ +#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_ + +include "include/Dialect/TensorExt/IR/TensorExtDialect.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/CommonTypeConstraints.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + + +class TensorExt_Op traits = []> : + Op { + let cppNamespace = "::mlir::heir::tensor_ext"; +} + +def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor", "output"]>]> { + let summary = "Rotate a tensor some number of indices left."; + let description = [{ + This op represents a left-rotation of a tensor by given number of indices. + Negative shift values are interpreted as right-rotations. + + This corresponds to the `rotate` operation in arithmetic FHE schemes like + BGV. + + Examples: + + ```mlir + %0 = ... : tensor<16xi32> + %c7 = arith.constant 7 : i32 + %1 = tensor_ext.rotate %0, %c7 : tensor<16xi32>, i32 + ``` + }]; + + let arguments = (ins AnyTensor:$tensor, SignlessIntegerLike:$shift); + let results = (outs AnyTensor:$output); + let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)"; +} + +#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_ diff --git a/lib/Dialect/TensorExt/IR/BUILD b/lib/Dialect/TensorExt/IR/BUILD new file mode 100644 index 000000000..211de3931 --- /dev/null +++ b/lib/Dialect/TensorExt/IR/BUILD @@ -0,0 +1,41 @@ +# TensorExt dialect implementation + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = [ + "TensorExtDialect.cpp", + ], + hdrs = [ + "@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h", + "@heir//include/Dialect/TensorExt/IR:TensorExtOps.h", + ], + deps = [ + ":TensorExtOps", + "@heir//include/Dialect/TensorExt/IR:dialect_inc_gen", + "@heir//include/Dialect/TensorExt/IR:ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "TensorExtOps", + srcs = [ + "TensorExtOps.cpp", + ], + hdrs = [ + "@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h", + "@heir//include/Dialect/TensorExt/IR:TensorExtOps.h", + ], + deps = [ + "@heir//include/Dialect/TensorExt/IR:dialect_inc_gen", + "@heir//include/Dialect/TensorExt/IR:ops_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) diff --git a/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp b/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp new file mode 100644 index 000000000..25f9e41c6 --- /dev/null +++ b/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp @@ -0,0 +1,27 @@ +#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" + +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project + +// NOLINTNEXTLINE(misc-include-cleaner): Required to define TensorExtOps +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" + +// Generated definitions +#include "include/Dialect/TensorExt/IR/TensorExtDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc" + +namespace mlir { +namespace heir { +namespace tensor_ext { + +void TensorExtDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc" + >(); +} + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/TensorExt/IR/TensorExtOps.cpp b/lib/Dialect/TensorExt/IR/TensorExtOps.cpp new file mode 100644 index 000000000..a7da6e668 --- /dev/null +++ b/lib/Dialect/TensorExt/IR/TensorExtOps.cpp @@ -0,0 +1,7 @@ +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" + +namespace mlir { +namespace heir { +namespace tensor_ext {} // namespace tensor_ext +} // namespace heir +} // namespace mlir diff --git a/tests/tensor_ext/BUILD b/tests/tensor_ext/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/tensor_ext/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/tensor_ext/ops.mlir b/tests/tensor_ext/ops.mlir new file mode 100644 index 000000000..ab7d39229 --- /dev/null +++ b/tests/tensor_ext/ops.mlir @@ -0,0 +1,9 @@ +// RUN: heir-opt %s + +// Test for syntax + +func.func @test_rotate(%0: tensor<16xi32>) -> tensor<16xi32> { + %c1 = arith.constant 1 : i32 + %1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32 + return %1 : tensor<16xi32> +} diff --git a/tools/BUILD b/tools/BUILD index 823635c91..f8b07c2dd 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -52,6 +52,7 @@ cc_binary( "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Dialect/Secret/Transforms", "@heir//lib/Dialect/Secret/Transforms:DistributeGeneric", + "@heir//lib/Dialect/TensorExt/IR:Dialect", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", "@heir//lib/Transforms/ForwardStoreToLoad", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 1e9d54b42..0f6e90eb9 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -19,6 +19,7 @@ #include "include/Dialect/Secret/IR/SecretDialect.h" #include "include/Dialect/Secret/Transforms/DistributeGeneric.h" #include "include/Dialect/Secret/Transforms/Passes.h" +#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" @@ -268,6 +269,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); // Add expected MLIR dialects to the registry. registry.insert();