Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorExt dialect with rotate op #478

Merged
merged 1 commit into from Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 79 additions & 0 deletions include/Dialect/TensorExt/IR/BUILD
@@ -0,0 +1,79 @@
# TensorExt tablegen and headers

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files(
[
"TensorExtDialect.h",
"TensorExtOps.h",
],
)

td_library(
name = "td_files",
srcs = [
"TensorExtDialect.td",
"TensorExtOps.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
)

gentbl_cc_library(
name = "dialect_inc_gen",
tbl_outs = [
(
[
"-gen-dialect-decls",
],
"TensorExtDialect.h.inc",
),
(
[
"-gen-dialect-defs",
],
"TensorExtDialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorExtDialect.td",
deps = [
":td_files",
],
)

gentbl_cc_library(
name = "ops_inc_gen",
tbl_outs = [
(
["-gen-op-decls"],
"TensorExtOps.h.inc",
),
(
["-gen-op-defs"],
"TensorExtOps.cpp.inc",
),
(
["-gen-op-doc"],
"TensorExtOps.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorExtOps.td",
deps = [
":dialect_inc_gen",
":td_files",
"@heir//include/Dialect/Polynomial/IR:td_files",
],
)
10 changes: 10 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtDialect.h
@@ -0,0 +1,10 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_

#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project

// Generated headers (block clang-format from messing up order)
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h.inc"

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_H_
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtDialect.td
@@ -0,0 +1,17 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_

include "mlir/IR/DialectBase.td"

def TensorExt_Dialect : Dialect {
let name = "tensor_ext";
let description = [{
The `tensor_ext` dialect contains operations on plaintext tensors that
correspond to the computation model of certain FHE schemes, but are
unlikely to be upstreamed to MLIR due to their specificity to FHE.
}];

let cppNamespace = "::mlir::heir::tensor_ext";
}

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
11 changes: 11 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtOps.h
@@ -0,0 +1,11 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_

#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

#define GET_OP_CLASSES
#include "include/Dialect/TensorExt/IR/TensorExtOps.h.inc"

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
40 changes: 40 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtOps.td
@@ -0,0 +1,40 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_

include "include/Dialect/TensorExt/IR/TensorExtDialect.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"


class TensorExt_Op<string mnemonic, list<Trait> traits = []> :
Op<TensorExt_Dialect, mnemonic, traits> {
let cppNamespace = "::mlir::heir::tensor_ext";
}

def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor", "output"]>]> {
let summary = "Rotate a tensor some number of indices left.";
let description = [{
This op represents a left-rotation of a tensor by given number of indices.
j2kun marked this conversation as resolved.
Show resolved Hide resolved
Negative shift values are interpreted as right-rotations.

This corresponds to the `rotate` operation in arithmetic FHE schemes like
BGV.

Examples:

```mlir
%0 = ... : tensor<16xi32>
%c7 = arith.constant 7 : i32
%1 = tensor_ext.rotate %0, %c7 : tensor<16xi32>, i32
```
}];

let arguments = (ins AnyTensor:$tensor, SignlessIntegerLike:$shift);
let results = (outs AnyTensor:$output);
let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)";
}

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
41 changes: 41 additions & 0 deletions lib/Dialect/TensorExt/IR/BUILD
@@ -0,0 +1,41 @@
# TensorExt dialect implementation

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Dialect",
srcs = [
"TensorExtDialect.cpp",
],
hdrs = [
"@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h",
"@heir//include/Dialect/TensorExt/IR:TensorExtOps.h",
],
deps = [
":TensorExtOps",
"@heir//include/Dialect/TensorExt/IR:dialect_inc_gen",
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "TensorExtOps",
srcs = [
"TensorExtOps.cpp",
],
hdrs = [
"@heir//include/Dialect/TensorExt/IR:TensorExtDialect.h",
"@heir//include/Dialect/TensorExt/IR:TensorExtOps.h",
],
deps = [
"@heir//include/Dialect/TensorExt/IR:dialect_inc_gen",
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
],
)
27 changes: 27 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtDialect.cpp
@@ -0,0 +1,27 @@
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"

#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// NOLINTNEXTLINE(misc-include-cleaner): Required to define TensorExtOps
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"

// Generated definitions
#include "include/Dialect/TensorExt/IR/TensorExtDialect.cpp.inc"

#define GET_OP_CLASSES
#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc"

namespace mlir {
namespace heir {
namespace tensor_ext {

void TensorExtDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "include/Dialect/TensorExt/IR/TensorExtOps.cpp.inc"
>();
}

} // namespace tensor_ext
} // namespace heir
} // namespace mlir
7 changes: 7 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtOps.cpp
@@ -0,0 +1,7 @@
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"

namespace mlir {
namespace heir {
namespace tensor_ext {} // namespace tensor_ext
} // namespace heir
} // namespace mlir
10 changes: 10 additions & 0 deletions tests/tensor_ext/BUILD
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
9 changes: 9 additions & 0 deletions tests/tensor_ext/ops.mlir
@@ -0,0 +1,9 @@
// RUN: heir-opt %s

// Test for syntax

func.func @test_rotate(%0: tensor<16xi32>) -> tensor<16xi32> {
%c1 = arith.constant 1 : i32
%1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32
return %1 : tensor<16xi32>
}
1 change: 1 addition & 0 deletions tools/BUILD
Expand Up @@ -52,6 +52,7 @@ cc_binary(
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/Secret/Transforms",
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
"@heir//lib/Transforms/ForwardStoreToLoad",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Expand Up @@ -19,6 +19,7 @@
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/Transforms/DistributeGeneric.h"
#include "include/Dialect/Secret/Transforms/Passes.h"
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"
#include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
Expand Down Expand Up @@ -268,6 +269,7 @@ int main(int argc, char **argv) {
registry.insert<tfhe_rust::TfheRustDialect>();
registry.insert<tfhe_rust_bool::TfheRustBoolDialect>();
registry.insert<openfhe::OpenfheDialect>();
registry.insert<tensor_ext::TensorExtDialect>();

// Add expected MLIR dialects to the registry.
registry.insert<affine::AffineDialect>();
Expand Down