Skip to content

Commit

Permalink
Adding cggi-tfhe-rs-bool pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Wouter Legiest authored and Wouter Legiest committed Mar 6, 2024
1 parent d0b5d80 commit da8ba6d
Show file tree
Hide file tree
Showing 19 changed files with 720 additions and 14 deletions.
7 changes: 6 additions & 1 deletion .gitignore
Expand Up @@ -15,4 +15,9 @@ hugo_stats.json
venv

# for rust codegen tests
Cargo.lock
**/Cargo.lock
tests/**/**/target/
tests/tfhe_rust_bool/end_to_end_fpga/

# vscode
.vscode/**
37 changes: 37 additions & 0 deletions include/Conversion/CGGIToTfheRustBool/BUILD
@@ -0,0 +1,37 @@
# CGGIToTfheRustBool tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files(
[
"CGGIToTfheRustBool.h",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=CGGIToTfheRustBool",
],
"CGGIToTfheRustBool.h.inc",
),
(
["-gen-pass-doc"],
"CGGIToTfheRustBool.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "CGGIToTfheRustBool.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
16 changes: 16 additions & 0 deletions include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h
@@ -0,0 +1,16 @@
#ifndef INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_H_
#define INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::heir {

#define GEN_PASS_DECL
#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc"

#define GEN_PASS_REGISTRATION
#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc"

} // namespace mlir::heir

#endif // INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_H_
16 changes: 16 additions & 0 deletions include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.td
@@ -0,0 +1,16 @@
#ifndef INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_TD_
#define INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_TD_

include "mlir/Pass/PassBase.td"

def CGGIToTfheRustBool : Pass<"cggi-to-tfhe-rust-bool"> {
let summary = "Lower `cggi` to `tfhe_rust_bool` dialect.";
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::cggi::CGGIDialect",
"mlir::heir::lwe::LWEDialect",
"mlir::heir::tfhe_rust_bool::TfheRustBoolDialect",
];
}

#endif // INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_TD_
14 changes: 14 additions & 0 deletions include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td
Expand Up @@ -8,6 +8,7 @@ include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"


class TfheRustBool_Op<string mnemonic, list<Trait> traits = []> :
Expand Down Expand Up @@ -43,6 +44,19 @@ def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two
def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; }


def AndPackedOp : TfheRustBool_Op<"and_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}


def NotOp : TfheRustBool_Op<"not", [
Pure,
AllTypesMatch<["input", "output"]>
Expand Down
2 changes: 2 additions & 0 deletions include/Target/TfheRustBool/TfheRustBoolEmitter.h
Expand Up @@ -59,6 +59,8 @@ class TfheRustBoolEmitter {
LogicalResult printOperation(XorOp op);
LogicalResult printOperation(XnorOp op);

LogicalResult printOperation(AndPackedOp op);

// Helpers for above
LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks,
::mlir::ValueRange nonSksOperands,
Expand Down
28 changes: 28 additions & 0 deletions lib/Conversion/CGGIToTfheRustBool/BUILD
@@ -0,0 +1,28 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "CGGIToTfheRustBool",
srcs = ["CGGIToTfheRustBool.cpp"],
hdrs = [
"@heir//include/Conversion/CGGIToTfheRustBool:CGGIToTfheRustBool.h",
],
deps = [
"@heir//include/Conversion/CGGIToTfheRustBool:pass_inc_gen",
"@heir//lib/Conversion:Utils",
"@heir//lib/Dialect/CGGI/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)

0 comments on commit da8ba6d

Please sign in to comment.