Skip to content

Commit

Permalink
add TensorExt dialect with rotate op
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 5, 2024
1 parent 1939b79 commit 9574c39
Show file tree
Hide file tree
Showing 12 changed files with 254 additions and 0 deletions.
79 changes: 79 additions & 0 deletions include/Dialect/TensorExt/IR/BUILD
@@ -0,0 +1,79 @@
# TensorExt tablegen and headers

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files(
[
"TensorExtDialect.h",
"TensorExtOps.h",
],
)

td_library(
name = "td_files",
srcs = [
"TensorExtDialect.td",
"TensorExtOps.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
)

gentbl_cc_library(
name = "dialect_inc_gen",
tbl_outs = [
(
[
"-gen-dialect-decls",
],
"TensorExtDialect.h.inc",
),
(
[
"-gen-dialect-defs",
],
"TensorExtDialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorExtDialect.td",
deps = [
":td_files",
],
)

gentbl_cc_library(
name = "ops_inc_gen",
tbl_outs = [
(
["-gen-op-decls"],
"TensorExtOps.h.inc",
),
(
["-gen-op-defs"],
"TensorExtOps.cpp.inc",
),
(
["-gen-op-doc"],
"TensorExtOps.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorExtOps.td",
deps = [
":dialect_inc_gen",
":td_files",
"@heir//include/Dialect/Polynomial/IR:td_files",
],
)
10 changes: 10 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtDialect.h
@@ -0,0 +1,10 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_

#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project

// Generated headers (block clang-format from messing up order)
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h.inc"

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtDialect.td
@@ -0,0 +1,17 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_

include "mlir/IR/DialectBase.td"

def TensorExt_Dialect : Dialect {
let name = "tensor_ext";
let description = [{
The `tensor_ext` dialect contains operations on plaintext tensors that
correspond to the computation model of certain FHE schemes, but are
unlikely to be upstreamed to MLIR due to their specificity to FHE.
}];

let cppNamespace = "::mlir::heir::tensor_ext";
}

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
11 changes: 11 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtOps.h
@@ -0,0 +1,11 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_

#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

#define GET_OP_CLASSES
#include "include/Dialect/TensorExt/IR/TensorExtOps.h.inc"

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
40 changes: 40 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtOps.td
@@ -0,0 +1,40 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_

include "include/Dialect/TensorExt/IR/TensorExtDialect.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"


class TensorExt_Op<string mnemonic, list<Trait> traits = []> :
Op<TensorExt_Dialect, mnemonic, traits> {
let cppNamespace = "::mlir::heir::tensor_ext";
}

def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor", "output"]>]> {
let summary = "Rotate a tensor some number of indices left.";
let description = [{
This op represents a left-rotation of a tensor by given number of indices.
Negative shift values are interpreted as right-rotations.

This corresponds to the `rotate` operation in arithmetic FHE schemes like
BGV.

Examples:

```mlir
%0 = ... : tensor<16xi32>
%c7 = arith.constant 7 : i32
%1 = tensor_ext.rotate %0, %c7 : tensor<16xi32>, i32
```
}];

let arguments = (ins AnyTensor:$tensor, SignlessIntegerLike:$shift);
let results = (outs AnyTensor:$output);
let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)";
}

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
41 changes: 41 additions & 0 deletions lib/Dialect/TensorExt/IR/BUILD
@@ -0,0 +1,41 @@
# TensorExt dialect implementation

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Dialect",
srcs = [
"TensorExtDialect.cpp",
],
hdrs = [
"@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h",
"@heir//include/Dialect/TensorExt/IR:TensorExtOps.h",
],
deps = [
":TensorExtOps",
"@heir//include/Dialect/TensorExt/IR:dialect_inc_gen",
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "TensorExtOps",
srcs = [
"TensorExtOps.cpp",
],
hdrs = [
"@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h",
"@heir//include/Dialect/TensorExt/IR:TensorExtOps.h",
],
deps = [
"@heir//include/Dialect/TensorExt/IR:dialect_inc_gen",
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
],
)
27 changes: 27 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtDialect.cpp
@@ -0,0 +1,27 @@
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"

#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// NOLINTNEXTLINE(misc-include-cleaner): Required to define TensorExtOps
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"

// Generated definitions
#include "include/Dialect/TensorExt/IR/TensorExtDialect.cpp.inc"

#define GET_OP_CLASSES
#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc"

namespace mlir {
namespace heir {
namespace tensor_ext {

void TensorExtDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc"
>();
}

} // namespace tensor_ext
} // namespace heir
} // namespace mlir
7 changes: 7 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtOps.cpp
@@ -0,0 +1,7 @@
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"

namespace mlir {
namespace heir {
namespace tensor_ext {} // namespace tensor_ext
} // namespace heir
} // namespace mlir
10 changes: 10 additions & 0 deletions tests/tensor_ext/BUILD
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
9 changes: 9 additions & 0 deletions tests/tensor_ext/ops.mlir
@@ -0,0 +1,9 @@
// RUN: heir-opt %s

// Test for syntax

func.func @test_rotate(%0: tensor<16xi32>) -> tensor<16xi32> {
%c1 = arith.constant 1 : i32
%1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32
return %1 : tensor<16xi32>
}
1 change: 1 addition & 0 deletions tools/BUILD
Expand Up @@ -52,6 +52,7 @@ cc_binary(
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/Secret/Transforms",
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
"@heir//lib/Transforms/ForwardStoreToLoad",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Expand Up @@ -19,6 +19,7 @@
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/Transforms/DistributeGeneric.h"
#include "include/Dialect/Secret/Transforms/Passes.h"
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"
#include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
Expand Down Expand Up @@ -268,6 +269,7 @@ int main(int argc, char **argv) {
registry.insert<tfhe_rust::TfheRustDialect>();
registry.insert<tfhe_rust_bool::TfheRustBoolDialect>();
registry.insert<openfhe::OpenfheDialect>();
registry.insert<tensor_ext::TensorExtDialect>();

// Add expected MLIR dialects to the registry.
registry.insert<affine::AffineDialect>();
Expand Down

0 comments on commit 9574c39

Please sign in to comment.