From 63410a8c2a85bb3a6faf1ad348eee9098637748e Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Tue, 12 Mar 2024 17:05:11 +0000 Subject: [PATCH 01/11] Updating the tests for e2e tfhe-rs-bool and starting with the tfhe-rs-bool fpga tests --- .gitignore | 2 +- .../TfheRustBool/IR/TfheRustBoolOps.td | 11 +++ tests/cggi_to_tfhe_rust_bool/add_bool.mlir | 2 - .../cggi_to_tfhe_rust_bool/add_one_bool.mlir | 2 - tests/tfhe_rust_bool/end_to_end/BUILD | 1 + tests/tfhe_rust_bool/end_to_end/Cargo.toml | 4 + .../end_to_end/src/main_bool_add.rs | 51 ++++++++++++ .../end_to_end/test_bool_add.mlir | 69 +++++++++++++++ tests/tfhe_rust_bool/end_to_end_fpga/BUILD | 23 +++++ .../tfhe_rust_bool/end_to_end_fpga/Cargo.toml | 21 +++++ .../tfhe_rust_bool/end_to_end_fpga/README.md | 41 +++++++++ .../end_to_end_fpga/src/main.rs | 83 +++++++++++++++++++ .../end_to_end_fpga/test_add_one_bool.mlir | 73 ++++++++++++++++ .../end_to_end_fpga/test_packed_and.mlir | 13 +++ 14 files changed, 391 insertions(+), 5 deletions(-) create mode 100644 tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs create mode 100644 tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/BUILD create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/README.md create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir diff --git a/.gitignore b/.gitignore index 5cee65586..318187095 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,7 @@ venv # for rust codegen tests **/Cargo.lock tests/**/**/target/ -tests/tfhe_rust_bool/end_to_end_fpga/ +tests/tfhe_rust_bool/end_to_end_fpga/tfhe-rs # vscode .vscode/** diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index c711f3d93..92dd1033f 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -56,6 +56,17 @@ def AndPackedOp : TfheRustBool_Op<"and_packed", [ let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); } +def XorPackedOp : TfheRustBool_Op<"xor_packed", [ + Pure, + AllTypesMatch<["lhs", "rhs", "output"]> +]> { + let arguments = (ins + TfheRustBool_ServerKey:$serverKey, + TensorOf<[TfheRustBool_Encrypted]>:$lhs, + TensorOf<[TfheRustBool_Encrypted]>:$rhs + ); + let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); +} def NotOp : TfheRustBool_Op<"not", [ Pure, diff --git a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir index 646c68e81..034d04e8c 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir @@ -1,11 +1,9 @@ // RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir index 8a6e62cd8..6d4cc5fb0 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir @@ -1,11 +1,9 @@ // RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_one_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/tfhe_rust_bool/end_to_end/BUILD b/tests/tfhe_rust_bool/end_to_end/BUILD index a189be648..3c47de4e2 100644 --- a/tests/tfhe_rust_bool/end_to_end/BUILD +++ b/tests/tfhe_rust_bool/end_to_end/BUILD @@ -12,6 +12,7 @@ glob_lit_tests( data = [ "Cargo.toml", "src/main.rs", + "src/main_bool_add.rs", "@heir//tests:test_utilities", ], default_tags = [ diff --git a/tests/tfhe_rust_bool/end_to_end/Cargo.toml b/tests/tfhe_rust_bool/end_to_end/Cargo.toml index 975466a91..6e6fa0a3c 100644 --- a/tests/tfhe_rust_bool/end_to_end/Cargo.toml +++ b/tests/tfhe_rust_bool/end_to_end/Cargo.toml @@ -12,3 +12,7 @@ tfhe = { version = "0.4.1", features = ["boolean", "x86_64-unix"] } [[bin]] name = "main" path = "src/main.rs" + +[[bin]] +name = "main_bool_add" +path = "src/main_bool_add.rs" diff --git a/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs new file mode 100644 index 000000000..02f61b2ef --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs @@ -0,0 +1,51 @@ +use clap::Parser; +use tfhe::boolean::prelude::*; + +mod fn_under_test; + +// TODO(https://github.com/google/heir/issues/235): improve generality +#[derive(Parser, Debug)] +struct Args { + /// arguments to forward to function under test + #[arg(id = "input_1", index = 1, action)] + input1: u8, + + #[arg(id = "input_2", index = 2, action)] + input2: u8, +} + +// Encrypt a u8 +pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec { + let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 ); + + let res: Vec = arr.iter() + .map(|bit| client_key.encrypt(if *bit != 0u8 { true } else { false })) + .collect(); + res +} + +// Decrypt a u8 +pub fn decrypt(ciphertexts: &Vec, client_key: &ClientKey) -> u8 { + let mut accum = 0u8; + for (i, ct) in ciphertexts.iter().enumerate() { + let bit = client_key.decrypt(ct); + accum |= (bit as u8) << i; + } + accum.reverse_bits() + +} + +fn main() { + let flags = Args::parse(); + let (client_key, server_key) = tfhe::boolean::gen_keys(); + + let ct_1 = encrypt(flags.input1.into(), &client_key); + let ct_2 = encrypt(flags.input2.into(), &client_key); + + + let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + + let output = decrypt(&result, &client_key); + + println!("{:08b}", output); +} diff --git a/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir b/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir new file mode 100644 index 000000000..3fe76bfd0 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir @@ -0,0 +1,69 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_bool_add -- 15 3 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK: 00010010 +func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> + %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb + %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb + %fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb + %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb + %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb + %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb + %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb + %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + return %from_elements : tensor<8x!eb> +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/BUILD b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD new file mode 100644 index 000000000..a189be648 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD @@ -0,0 +1,23 @@ +# See README.md for setup required to run these tests + +load("//bazel:lit.bzl", "glob_lit_tests") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +glob_lit_tests( + name = "all_tests", + data = [ + "Cargo.toml", + "src/main.rs", + "@heir//tests:test_utilities", + ], + default_tags = [ + "manual", + "notap", + ], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml b/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml new file mode 100644 index 000000000..07a449700 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "heir-tfhe-rust-integration-test" +version = "0.1.0" +edition = "2021" +default-run = "main" + +[dependencies] +clap = { version = "4.1.8", features = ["derive"] } +rayon = "1.6.1" +serde = { version = "1.0.152", features = ["derive"] } +tfhe = { path = "tfhe-rs/tfhe", features = [ + "boolean", + "x86_64-unix", +] } + +[features] +fpga = ["tfhe/fpga"] + +[[bin]] +name = "main" +path = "src/main.rs" diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md new file mode 100644 index 000000000..4eb3d1103 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -0,0 +1,41 @@ +# End to end Rust codegen tests - Boolean FPGA + +These tests exercise Rust codegen for the +[tfhe-rs](https://github.com/zama-ai/tfhe-rs) backend library, including +compiling the generated Rust source and running the resulting binary. This sets +tests are specifically of the boolean plaintexts and the accompanying library. + +To avoid introducing these large dependencies into the entire project, these +tests are manual, and require the system they're running on to have +[Cargo](https://doc.rust-lang.org/cargo/index.html) installed. During the test, +cargo will fetch and build the required dependencies, and `Cargo.toml` in this +directory effectively pins the version of `tfhe` supported. + +Use the following command to run the tests in this directory, where the default +Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`, +if you overrode the default option when installing Cargo. + +```bash +bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \ + | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@" +``` + +The `manual` tag is added to the targets in this directory to ensure that they +are not run when someone runs a glob test like `bazel test //...`. + +If you don't do this correctly, you will see an error like this: + +``` +# .---command stderr------------ +# | Updating crates.io index +# | Downloading crates ... +# | Downloaded memoffset v0.9.0 +# | error: failed to download replaced source registry `crates-io` +# | +# | Caused by: +# | failed to open `/home/you/.cargo/registry/cache/index.crates.io-6f17d22bba15001f/memoffset-0.9.0.crate` +# | +# | Caused by: +# | Read-only file system (os error 30) +# `----------------------------- +``` diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs new file mode 100644 index 000000000..47392eaa9 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs @@ -0,0 +1,83 @@ +use clap::Parser; +use tfhe::boolean::prelude::*; + +use tfhe::boolean::engine::BooleanEngine; +use tfhe::boolean::prelude::*; +use std::time::Instant; + +#[cfg(feature = "fpga")] +use tfhe::boolean::server_key::FpgaGates; + + +mod fn_under_test; + +// TODO(https://github.com/google/heir/issues/235): improve generality +#[derive(Parser, Debug)] +struct Args { + /// arguments to forward to function under test + #[arg(id = "input_1", index = 1, action)] + input1: u8, + + #[arg(id = "input_2", index = 2, action)] + input2: u8, +} + +// Encrypt a u8 +pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec { + let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 ); + + let res: Vec = arr.iter() + .map(|bit| client_key.encrypt(if *bit != 0u8 { true } else { false })) + .collect(); + res +} + +// Decrypt a u8 +pub fn decrypt(ciphertexts: &Vec, client_key: &ClientKey) -> u8 { + let mut accum = 0u8; + for (i, ct) in ciphertexts.iter().enumerate() { + let bit = client_key.decrypt(ct); + accum |= (bit as u8) << i; + } + accum + +} + +fn main() { + let flags = Args::parse(); + + let params; + let client_key; + + let mut boolean_engine = BooleanEngine::new(); + + #[cfg(feature = "fpga")] + { + params = tfhe::boolean::engine::fpga::parameters::DEFAULT_PARAMETERS_KS_PBS; + client_key = boolean_engine.create_client_key(*params); + } + + #[cfg(not(feature = "fpga"))] + { + params = tfhe::boolean::parameters::DEFAULT_PARAMETERS_KS_PBS; + client_key = boolean_engine.create_client_key(params); + } + + // generate the server key, only the SW needs this + let server_key = boolean_engine.create_server_key(&client_key); + + #[cfg(feature = "fpga")] + server_key.enable_fpga(params); + + let ct_1 = encrypt(flags.input1.into(), &client_key); + let ct_2 = encrypt(flags.input2.into(), &client_key); + + let ct_1= ct_1.iter().collect(); + let ct_2= ct_2.iter().collect(); + + let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + + let output = decrypt(&result, &client_key); + + println!("{:08b}", output); +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir new file mode 100644 index 000000000..d70eb52f3 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir @@ -0,0 +1,73 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_add_one -- 1 1 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK-LABEL: pub fn fn_under_test( +// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, +// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, +// CHECK-NEXT: ) -> Vec { +func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> + %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb + %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb + %fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb + %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb + %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb + %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb + %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb + %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + return %from_elements : tensor<8x!eb> +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir new file mode 100644 index 000000000..8ecca7cad --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir @@ -0,0 +1,13 @@ +// This test ensures the testing harness is working properly with minimal codegen. + +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 1 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK: 1 +func.func @fn_under_test(%bsks : !bsks, %a: tensor<8x!eb>, %b: tensor<8x!eb>) -> tensor<8x!eb> { + %res = tfhe_rust_bool.and_packed %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb> + return %res : tensor<8x!eb> +} From 4440b802e15942445592fd18ab314ada51f9529f Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Mon, 18 Mar 2024 12:55:53 +0000 Subject: [PATCH 02/11] Update readme --- tests/tfhe_rust_bool/end_to_end_fpga/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md index 4eb3d1103..378da0f17 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/README.md +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -5,6 +5,9 @@ These tests exercise Rust codegen for the compiling the generated Rust source and running the resulting binary. This sets tests are specifically of the boolean plaintexts and the accompanying library. +This specific e2e tests are designed for the [FPT](https://eprint.iacr.org/2022/1635) accelerator, made by COSIC. + + To avoid introducing these large dependencies into the entire project, these tests are manual, and require the system they're running on to have [Cargo](https://doc.rust-lang.org/cargo/index.html) installed. During the test, From 2af32bea894adfbc14989579330f11fc6385e035 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Mon, 18 Mar 2024 16:28:10 +0000 Subject: [PATCH 03/11] Starting fpga emitter; function header half correct --- lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp | 10 +++++++++- tests/tfhe_rust_bool/add_one_bool.mlir | 6 +++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 8894ce615..4ae346e8e 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -274,6 +274,11 @@ LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { {op.getLhs(), op.getRhs()}, "and_packed"); } +LogicalResult TfheRustBoolEmitter::printOperation(XorPackedOp op) { + return printSksMethod(op.getResult(), op.getServerKey(), + {op.getLhs(), op.getRhs()}, "xor_packed"); +} + FailureOr TfheRustBoolEmitter::convertType(Type type) { // Note: these are probably not the right type names to use exactly, and they // will need to chance to the right values once we try to compile it against @@ -283,7 +288,10 @@ FailureOr TfheRustBoolEmitter::convertType(Type type) { // FIXME: why can't both types be FailureOr? auto elementTy = convertType(shapedType.getElementType()); if (failed(elementTy)) return failure(); - return std::string("Vec<" + elementTy.value() + ">"); + auto refprefix = + shapedType.getElementType().hasTrait() ? "&" : ""; + return std::string(std::string("Vec<") + refprefix + elementTy.value() + + ">"); } return llvm::TypeSwitch>(type) .Case( diff --git a/tests/tfhe_rust_bool/add_one_bool.mlir b/tests/tfhe_rust_bool/add_one_bool.mlir index e9018007d..cd828bb54 100644 --- a/tests/tfhe_rust_bool/add_one_bool.mlir +++ b/tests/tfhe_rust_bool/add_one_bool.mlir @@ -5,9 +5,9 @@ // CHECK-LABEL: pub fn fn_under_test( // CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, -// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, -// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, -// CHECK-NEXT: ) -> Vec { +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec<&Ciphertext>, +// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec<&Ciphertext>, +// CHECK-NEXT: ) -> Vec<&Ciphertext> { func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { %c7 = arith.constant 7 : index %c6 = arith.constant 6 : index From 3c818ba8907716b4df55ed06d1f6d38934651e39 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Tue, 12 Mar 2024 17:05:11 +0000 Subject: [PATCH 04/11] Updating the tests for e2e tfhe-rs-bool and starting with the tfhe-rs-bool fpga tests --- .gitignore | 2 +- .../TargetSlotAnalysis/TargetSlotAnalysis.h | 6 +- .../TfheRustBool/IR/TfheRustBoolOps.td | 11 +++ tests/cggi_to_tfhe_rust_bool/add_bool.mlir | 2 - .../cggi_to_tfhe_rust_bool/add_one_bool.mlir | 2 - tests/tfhe_rust_bool/end_to_end/BUILD | 1 + tests/tfhe_rust_bool/end_to_end/Cargo.toml | 4 + .../end_to_end/src/main_bool_add.rs | 51 ++++++++++++ .../end_to_end/test_bool_add.mlir | 69 +++++++++++++++ tests/tfhe_rust_bool/end_to_end_fpga/BUILD | 23 +++++ .../tfhe_rust_bool/end_to_end_fpga/Cargo.toml | 21 +++++ .../tfhe_rust_bool/end_to_end_fpga/README.md | 44 ++++++++++ .../end_to_end_fpga/src/main.rs | 83 +++++++++++++++++++ .../end_to_end_fpga/test_add_one_bool.mlir | 73 ++++++++++++++++ .../end_to_end_fpga/test_packed_and.mlir | 13 +++ 15 files changed, 397 insertions(+), 8 deletions(-) create mode 100644 tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs create mode 100644 tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/BUILD create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/README.md create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir diff --git a/.gitignore b/.gitignore index 5cee65586..318187095 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,7 @@ venv # for rust codegen tests **/Cargo.lock tests/**/**/target/ -tests/tfhe_rust_bool/end_to_end_fpga/ +tests/tfhe_rust_bool/end_to_end_fpga/tfhe-rs # vscode .vscode/** diff --git a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h index cc7487a98..984fb338c 100644 --- a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h +++ b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h @@ -122,9 +122,9 @@ class TargetSlotAnalysis void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override; - void visitBranchOperand(OpOperand &operand) override {}; - void visitCallOperand(OpOperand &operand) override {}; - void setToExitState(TargetSlotLattice *lattice) override {}; + void visitBranchOperand(OpOperand &operand) override{}; + void visitCallOperand(OpOperand &operand) override{}; + void setToExitState(TargetSlotLattice *lattice) override{}; }; } // namespace target_slot_analysis diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index c711f3d93..92dd1033f 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -56,6 +56,17 @@ def AndPackedOp : TfheRustBool_Op<"and_packed", [ let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); } +def XorPackedOp : TfheRustBool_Op<"xor_packed", [ + Pure, + AllTypesMatch<["lhs", "rhs", "output"]> +]> { + let arguments = (ins + TfheRustBool_ServerKey:$serverKey, + TensorOf<[TfheRustBool_Encrypted]>:$lhs, + TensorOf<[TfheRustBool_Encrypted]>:$rhs + ); + let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); +} def NotOp : TfheRustBool_Op<"not", [ Pure, diff --git a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir index 646c68e81..034d04e8c 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir @@ -1,11 +1,9 @@ // RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir index 8a6e62cd8..6d4cc5fb0 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir @@ -1,11 +1,9 @@ // RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_one_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/tfhe_rust_bool/end_to_end/BUILD b/tests/tfhe_rust_bool/end_to_end/BUILD index a189be648..3c47de4e2 100644 --- a/tests/tfhe_rust_bool/end_to_end/BUILD +++ b/tests/tfhe_rust_bool/end_to_end/BUILD @@ -12,6 +12,7 @@ glob_lit_tests( data = [ "Cargo.toml", "src/main.rs", + "src/main_bool_add.rs", "@heir//tests:test_utilities", ], default_tags = [ diff --git a/tests/tfhe_rust_bool/end_to_end/Cargo.toml b/tests/tfhe_rust_bool/end_to_end/Cargo.toml index 975466a91..6e6fa0a3c 100644 --- a/tests/tfhe_rust_bool/end_to_end/Cargo.toml +++ b/tests/tfhe_rust_bool/end_to_end/Cargo.toml @@ -12,3 +12,7 @@ tfhe = { version = "0.4.1", features = ["boolean", "x86_64-unix"] } [[bin]] name = "main" path = "src/main.rs" + +[[bin]] +name = "main_bool_add" +path = "src/main_bool_add.rs" diff --git a/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs new file mode 100644 index 000000000..02f61b2ef --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs @@ -0,0 +1,51 @@ +use clap::Parser; +use tfhe::boolean::prelude::*; + +mod fn_under_test; + +// TODO(https://github.com/google/heir/issues/235): improve generality +#[derive(Parser, Debug)] +struct Args { + /// arguments to forward to function under test + #[arg(id = "input_1", index = 1, action)] + input1: u8, + + #[arg(id = "input_2", index = 2, action)] + input2: u8, +} + +// Encrypt a u8 +pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec { + let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 ); + + let res: Vec = arr.iter() + .map(|bit| client_key.encrypt(if *bit != 0u8 { true } else { false })) + .collect(); + res +} + +// Decrypt a u8 +pub fn decrypt(ciphertexts: &Vec, client_key: &ClientKey) -> u8 { + let mut accum = 0u8; + for (i, ct) in ciphertexts.iter().enumerate() { + let bit = client_key.decrypt(ct); + accum |= (bit as u8) << i; + } + accum.reverse_bits() + +} + +fn main() { + let flags = Args::parse(); + let (client_key, server_key) = tfhe::boolean::gen_keys(); + + let ct_1 = encrypt(flags.input1.into(), &client_key); + let ct_2 = encrypt(flags.input2.into(), &client_key); + + + let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + + let output = decrypt(&result, &client_key); + + println!("{:08b}", output); +} diff --git a/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir b/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir new file mode 100644 index 000000000..3fe76bfd0 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir @@ -0,0 +1,69 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_bool_add -- 15 3 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK: 00010010 +func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> + %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb + %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb + %fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb + %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb + %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb + %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb + %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb + %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + return %from_elements : tensor<8x!eb> +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/BUILD b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD new file mode 100644 index 000000000..a189be648 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD @@ -0,0 +1,23 @@ +# See README.md for setup required to run these tests + +load("//bazel:lit.bzl", "glob_lit_tests") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +glob_lit_tests( + name = "all_tests", + data = [ + "Cargo.toml", + "src/main.rs", + "@heir//tests:test_utilities", + ], + default_tags = [ + "manual", + "notap", + ], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml b/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml new file mode 100644 index 000000000..07a449700 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "heir-tfhe-rust-integration-test" +version = "0.1.0" +edition = "2021" +default-run = "main" + +[dependencies] +clap = { version = "4.1.8", features = ["derive"] } +rayon = "1.6.1" +serde = { version = "1.0.152", features = ["derive"] } +tfhe = { path = "tfhe-rs/tfhe", features = [ + "boolean", + "x86_64-unix", +] } + +[features] +fpga = ["tfhe/fpga"] + +[[bin]] +name = "main" +path = "src/main.rs" diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md new file mode 100644 index 000000000..42be8b572 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -0,0 +1,44 @@ +# End to end Rust codegen tests - Boolean FPGA + +These tests exercise Rust codegen for the +[tfhe-rs](https://github.com/zama-ai/tfhe-rs) backend library, including +compiling the generated Rust source and running the resulting binary. This sets +tests are specifically of the boolean plaintexts, accompanying COSIC-KU Leuven version of the library, +and the [FPT-accelerator](https://eprint.iacr.org/2022/1635). + +> :warning: Not possible to run these tests without the COSIC extension of TFHE-rs and FPT-accelerator + +To avoid introducing these large dependencies into the entire project, these +tests are manual, and require the system they're running on to have +[Cargo](https://doc.rust-lang.org/cargo/index.html) installed. During the test, +cargo will fetch and build the required dependencies, and `Cargo.toml` in this +directory effectively pins the version of `tfhe` supported. + +Use the following command to run the tests in this directory, where the default +Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`, +if you overrode the default option when installing Cargo. + +```bash +bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \ + | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@" +``` + +The `manual` tag is added to the targets in this directory to ensure that they +are not run when someone runs a glob test like `bazel test //...`. + +If you don't do this correctly, you will see an error like this: + +``` +# .---command stderr------------ +# | Updating crates.io index +# | Downloading crates ... +# | Downloaded memoffset v0.9.0 +# | error: failed to download replaced source registry `crates-io` +# | +# | Caused by: +# | failed to open `/home/you/.cargo/registry/cache/index.crates.io-6f17d22bba15001f/memoffset-0.9.0.crate` +# | +# | Caused by: +# | Read-only file system (os error 30) +# `----------------------------- +``` diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs new file mode 100644 index 000000000..47392eaa9 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs @@ -0,0 +1,83 @@ +use clap::Parser; +use tfhe::boolean::prelude::*; + +use tfhe::boolean::engine::BooleanEngine; +use tfhe::boolean::prelude::*; +use std::time::Instant; + +#[cfg(feature = "fpga")] +use tfhe::boolean::server_key::FpgaGates; + + +mod fn_under_test; + +// TODO(https://github.com/google/heir/issues/235): improve generality +#[derive(Parser, Debug)] +struct Args { + /// arguments to forward to function under test + #[arg(id = "input_1", index = 1, action)] + input1: u8, + + #[arg(id = "input_2", index = 2, action)] + input2: u8, +} + +// Encrypt a u8 +pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec { + let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 ); + + let res: Vec = arr.iter() + .map(|bit| client_key.encrypt(if *bit != 0u8 { true } else { false })) + .collect(); + res +} + +// Decrypt a u8 +pub fn decrypt(ciphertexts: &Vec, client_key: &ClientKey) -> u8 { + let mut accum = 0u8; + for (i, ct) in ciphertexts.iter().enumerate() { + let bit = client_key.decrypt(ct); + accum |= (bit as u8) << i; + } + accum + +} + +fn main() { + let flags = Args::parse(); + + let params; + let client_key; + + let mut boolean_engine = BooleanEngine::new(); + + #[cfg(feature = "fpga")] + { + params = tfhe::boolean::engine::fpga::parameters::DEFAULT_PARAMETERS_KS_PBS; + client_key = boolean_engine.create_client_key(*params); + } + + #[cfg(not(feature = "fpga"))] + { + params = tfhe::boolean::parameters::DEFAULT_PARAMETERS_KS_PBS; + client_key = boolean_engine.create_client_key(params); + } + + // generate the server key, only the SW needs this + let server_key = boolean_engine.create_server_key(&client_key); + + #[cfg(feature = "fpga")] + server_key.enable_fpga(params); + + let ct_1 = encrypt(flags.input1.into(), &client_key); + let ct_2 = encrypt(flags.input2.into(), &client_key); + + let ct_1= ct_1.iter().collect(); + let ct_2= ct_2.iter().collect(); + + let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + + let output = decrypt(&result, &client_key); + + println!("{:08b}", output); +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir new file mode 100644 index 000000000..d70eb52f3 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir @@ -0,0 +1,73 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_add_one -- 1 1 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK-LABEL: pub fn fn_under_test( +// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, +// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, +// CHECK-NEXT: ) -> Vec { +func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> + %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb + %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb + %fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb + %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb + %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb + %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb + %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb + %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + return %from_elements : tensor<8x!eb> +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir new file mode 100644 index 000000000..8ecca7cad --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir @@ -0,0 +1,13 @@ +// This test ensures the testing harness is working properly with minimal codegen. + +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 1 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK: 1 +func.func @fn_under_test(%bsks : !bsks, %a: tensor<8x!eb>, %b: tensor<8x!eb>) -> tensor<8x!eb> { + %res = tfhe_rust_bool.and_packed %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb> + return %res : tensor<8x!eb> +} From f706e3b544d9130ae7b5a4f687bd14b1e5bca8bb Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Tue, 19 Mar 2024 13:13:35 +0000 Subject: [PATCH 05/11] Working simple packed and --- include/Target/TfheRustBool/TfheRustBoolEmitter.h | 1 + lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp | 15 +++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/include/Target/TfheRustBool/TfheRustBoolEmitter.h b/include/Target/TfheRustBool/TfheRustBoolEmitter.h index fca9fe291..95c6c41e4 100644 --- a/include/Target/TfheRustBool/TfheRustBoolEmitter.h +++ b/include/Target/TfheRustBool/TfheRustBoolEmitter.h @@ -60,6 +60,7 @@ class TfheRustBoolEmitter { LogicalResult printOperation(XnorOp op); LogicalResult printOperation(AndPackedOp op); + LogicalResult printOperation(XorPackedOp op); // Helpers for above LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks, diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 4ae346e8e..e93ceae2c 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -64,8 +64,8 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) { // Arith ops .Case([&](auto op) { return printOperation(op); }) // TfheRustBool ops - .Case( - [&](auto op) { return printOperation(op); }) + .Case([&](auto op) { return printOperation(op); }) // Tensor ops .Case( [&](auto op) { return printOperation(op); }) @@ -270,6 +270,10 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) { } LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { + os << "let " << variableNames->getNameForValue(op.getLhs()) << " = " + << variableNames->getNameForValue(op.getLhs()) << ".iter().collect();\n"; + os << "let " << variableNames->getNameForValue(op.getRhs()) << " = " + << variableNames->getNameForValue(op.getRhs()) << ".iter().collect();\n"; return printSksMethod(op.getResult(), op.getServerKey(), {op.getLhs(), op.getRhs()}, "and_packed"); } @@ -288,10 +292,9 @@ FailureOr TfheRustBoolEmitter::convertType(Type type) { // FIXME: why can't both types be FailureOr? auto elementTy = convertType(shapedType.getElementType()); if (failed(elementTy)) return failure(); - auto refprefix = - shapedType.getElementType().hasTrait() ? "&" : ""; - return std::string(std::string("Vec<") + refprefix + elementTy.value() + - ">"); + // auto refprefix = + // shapedType.getElementType().hasTrait() ? "&" : ""; + return std::string(std::string("Vec<") + elementTy.value() + ">"); } return llvm::TypeSwitch>(type) .Case( From 22b5ed1f6740e4369b434f44043d9c5b7002d8a5 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Tue, 19 Mar 2024 18:03:55 +0100 Subject: [PATCH 06/11] Working FPGA generation code --- .../Target/TfheRustBool/TfheRustBoolEmitter.h | 2 + .../TfheRustBool/TfheRustBoolEmitter.cpp | 94 +++++++-- tests/tfhe_rust_bool/add_one_bool.mlir | 6 +- .../tfhe_rust_bool/end_to_end_fpga/README.md | 2 +- .../end_to_end_fpga/src/fn_under_test.rs | 192 +++++++++++------- .../end_to_end_fpga/test_add_one_bool.mlir | 105 +++++----- .../end_to_end_fpga/test_packed_and.mlir | 2 +- 7 files changed, 258 insertions(+), 145 deletions(-) diff --git a/include/Target/TfheRustBool/TfheRustBoolEmitter.h b/include/Target/TfheRustBool/TfheRustBoolEmitter.h index 95c6c41e4..8a60783e0 100644 --- a/include/Target/TfheRustBool/TfheRustBoolEmitter.h +++ b/include/Target/TfheRustBool/TfheRustBoolEmitter.h @@ -51,7 +51,9 @@ class TfheRustBoolEmitter { LogicalResult printOperation(::mlir::func::ReturnOp op); LogicalResult printOperation(CreateTrivialOp op); LogicalResult printOperation(tensor::ExtractOp op); + LogicalResult printOperation(tensor::ExtractSliceOp op); LogicalResult printOperation(tensor::FromElementsOp op); + LogicalResult printOperation(tensor::ConcatOp op); LogicalResult printOperation(AndOp op); LogicalResult printOperation(NandOp op); LogicalResult printOperation(OrOp op); diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index e93ceae2c..fdc73eb67 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -67,7 +67,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) { .Case([&](auto op) { return printOperation(op); }) // Tensor ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) .Default([&](Operation &) { @@ -176,13 +176,17 @@ LogicalResult TfheRustBoolEmitter::printSksMethod( auto *prefix = value.getType().hasTrait() ? "&" : ""; // First check if a DefiningOp exists // if not: comes from function definition - mlir::Operation *op = value.getDefiningOp(); - if (op) { - prefix = isa(op) ? "" : prefix; + + mlir::Operation *opParent = value.getDefiningOp(); + if (opParent) { + prefix = isa(opParent) ? "" : prefix; } else { prefix = ""; } + prefix = op.find("packed") ? "&" : prefix; + + return prefix + variableNames->getNameForValue(value) + (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); }); @@ -226,6 +230,12 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) { return success(); } +LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractSliceOp op) { + emitAssignPrefix(op.getResult()); + os << "vec![&" << variableNames->getNameForValue(op.getSource()) << "[" << op.getStaticOffsets()[0] << "]];\n"; + return success(); +} + LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { emitAssignPrefix(op.getResult()); os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) { @@ -239,6 +249,18 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { return success(); } +LogicalResult TfheRustBoolEmitter::printOperation(tensor::ConcatOp op) { + auto varName = variableNames->getNameForValue(op.getResult()); + os << "let mut " << varName << ": Vec = vec![];\n"; + ValueRange values = op.getOperands(); + for(Value a: values){ + os << varName << ".extend(" << variableNames->getNameForValue(a) << "_ref);\n"; + } + + return success(); +} + + LogicalResult TfheRustBoolEmitter::printOperation(AndOp op) { return printSksMethod(op.getResult(), op.getServerKey(), {op.getLhs(), op.getRhs()}, "and"); @@ -270,17 +292,65 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) { } LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { - os << "let " << variableNames->getNameForValue(op.getLhs()) << " = " - << variableNames->getNameForValue(op.getLhs()) << ".iter().collect();\n"; - os << "let " << variableNames->getNameForValue(op.getRhs()) << " = " - << variableNames->getNameForValue(op.getRhs()) << ".iter().collect();\n"; - return printSksMethod(op.getResult(), op.getServerKey(), - {op.getLhs(), op.getRhs()}, "and_packed"); + os << "let " << variableNames->getNameForValue(op.getResult()) << "_ref = "; + std::string_view opName = "and_packed"; + + os << variableNames->getNameForValue(op.getServerKey()) << "." << opName << "("; + os << commaSeparatedValues({op.getLhs(), op.getRhs()}, [&](Value value) { + auto *prefix = value.getType().hasTrait() ? "&" : ""; + // First check if a DefiningOp exists + // if not: comes from function definition + + mlir::Operation *opParent = value.getDefiningOp(); + if (opParent) { + prefix = isa(opParent) ? "" : prefix; + } else { + prefix = ""; + } + + prefix = opName.find("packed") ? "&" : prefix; + + + return prefix + variableNames->getNameForValue(value); + }); + os << ");\n"; + + os << "let " << variableNames->getNameForValue(op.getResult()) << ": Vec<&Ciphertext> = " + << variableNames->getNameForValue(op.getResult()) << "_ref.iter().collect();\n"; + return success(); } LogicalResult TfheRustBoolEmitter::printOperation(XorPackedOp op) { - return printSksMethod(op.getResult(), op.getServerKey(), - {op.getLhs(), op.getRhs()}, "xor_packed"); + // os << "let " << variableNames->getNameForValue(op.getLhs()) << " = " + // << variableNames->getNameForValue(op.getLhs()) << ".iter().collect();\n"; + // os << "let " << variableNames->getNameForValue(op.getRhs()) << " = " + // << variableNames->getNameForValue(op.getRhs()) << ".iter().collect();\n"; + os << "let " << variableNames->getNameForValue(op.getResult()) << "_ref = "; + std::string_view opName = "xor_packed"; + + os << variableNames->getNameForValue(op.getServerKey()) << "." << opName << "("; + os << commaSeparatedValues({op.getLhs(), op.getRhs()}, [&](Value value) { + auto *prefix = value.getType().hasTrait() ? "&" : ""; + // First check if a DefiningOp exists + // if not: comes from function definition + + mlir::Operation *opParent = value.getDefiningOp(); + if (opParent) { + prefix = isa(opParent) ? "" : prefix; + } else { + prefix = ""; + } + + prefix = opName.find("packed") ? "&" : prefix; + + + return prefix + variableNames->getNameForValue(value); + }); + os << ");\n"; + + os << "let " << variableNames->getNameForValue(op.getResult()) << ": Vec<&Ciphertext> = " + << variableNames->getNameForValue(op.getResult()) << "_ref.iter().collect();\n"; + return success(); } FailureOr TfheRustBoolEmitter::convertType(Type type) { diff --git a/tests/tfhe_rust_bool/add_one_bool.mlir b/tests/tfhe_rust_bool/add_one_bool.mlir index cd828bb54..e9018007d 100644 --- a/tests/tfhe_rust_bool/add_one_bool.mlir +++ b/tests/tfhe_rust_bool/add_one_bool.mlir @@ -5,9 +5,9 @@ // CHECK-LABEL: pub fn fn_under_test( // CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, -// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec<&Ciphertext>, -// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec<&Ciphertext>, -// CHECK-NEXT: ) -> Vec<&Ciphertext> { +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, +// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, +// CHECK-NEXT: ) -> Vec { func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { %c7 = arith.constant 7 : index %c6 = arith.constant 6 : index diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md index 42be8b572..5d6021216 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/README.md +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -19,7 +19,7 @@ Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`, if you overrode the default option when installing Cargo. ```bash -bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \ +bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" \ | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@" ``` diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs index 2109a435e..7ed9b2548 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs @@ -1,84 +1,124 @@ use tfhe::boolean::prelude::*; +// pub fn fn_under_test( +// v0: &ServerKey, +// v1: &Vec, +// v2: &Vec, +// ) -> Vec { +// let v1 = v1.iter().collect(); +// let v2 = v2.iter().collect(); +// let v3 = v0.xor_packed(&v1, &v2); +// v3 +// } + +use tfhe::boolean::prelude::*; + pub fn fn_under_test( v0: &ServerKey, v1: &Vec, v2: &Vec, ) -> Vec { - let v1 = v1.iter().collect(); - let v2 = v2.iter().collect(); - let v3 = v0.xor_packed(&v1, &v2); - v3 -} - - -// pub fn fn_under_test( -// v0: &ServerKey, -// v1: &Vec<&Ciphertext>, -// v2: &Vec<&Ciphertext>, -// ) -> Vec { -// let v3 = 7; -// let v4 = 6; -// let v5 = 5; -// let v6 = 4; -// let v7 = 3; -// let v8 = 2; -// let v9 = 1; -// let v10 = 0; -// let v11 = vec![v1[v10]]; -// let v12 = vec![v1[v9]]; -// let v13 = vec![v1[v8]]; -// let v14 = vec![v1[v7]]; -// let v15 = vec![v1[v6]]; -// let v16 = vec![v1[v5]]; -// let v17 = vec![v1[v4]]; -// let v18 = vec![v1[v3]]; -// let v19 = vec![v2[v10]]; -// let v20 = vec![v2[v9]]; -// let v21 = vec![v2[v8]]; -// let v22 = vec![v2[v7]]; -// let v23 = vec![v2[v6]]; -// let v24 = vec![v2[v5]]; -// let v25 = vec![v2[v4]]; -// let v26 = vec![v2[v3]]; -// let v27 = v0.xor_packed(&v11, &v19); -// let v27 = v27.iter().collect(); -// let v28 = v0.and_packed(&v11, &v19); -// let v29 = v0.xor_packed(&v12, &v20); -// let v30 = v0.and_packed(&v12, &v20); -// let v28 = v28.iter().collect(); -// let v29 = v29.iter().collect(); -// let v31 = v0.and_packed(&v29, &v28); -// let v32 = v0.xor_packed(&v29, &v28); -// let v33 = v0.xor_packed(&v30, &v31); -// let v34 = v0.xor_packed(&v13, &v21); -// let v35 = v0.and_packed(&v13, &v21); -// let v36 = v0.and_packed(&v34, &v33); -// let v37 = v0.xor_packed(&v34, &v33); -// let v38 = v0.xor_packed(&v35, &v36); -// let v39 = v0.xor_packed(&v14, &v22); -// let v40 = v0.and_packed(&v14, &v22); -// let v41 = v0.and_packed(&v39, &v38); -// let v42 = v0.xor_packed(&v39, &v38); -// let v43 = v0.xor_packed(&v40, &v41); -// let v44 = v0.xor_packed(&v15, &v23); -// let v45 = v0.and_packed(&v15, &v23); -// let v46 = v0.and_packed(&v44, &v43); -// let v47 = v0.xor_packed(&v44, &v43); -// let v48 = v0.xor_packed(&v45, &v46); -// let v49 = v0.xor_packed(&v16, &v24); -// let v50 = v0.and_packed(&v16, &v24); -// let v51 = v0.and_packed(&v49, &v48); -// let v52 = v0.xor_packed(&v49, &v48); -// let v53 = v0.xor_packed(&v50, &v51); -// let v54 = v0.xor_packed(&v17, &v25); -// let v55 = v0.and_packed(&v17, &v25); -// let v56 = v0.and_packed(&v54, &v53); -// let v57 = v0.xor_packed(&v54, &v53); -// let v58 = v0.xor_packed(&v55, &v56); -// let v59 = v0.xor_packed(&v18, &v26); -// let v60 = v0.xor_packed(&v59, &v58); -// let v61 = vec![v60[0], v57[0], v52[0], v47[0], v42[0], v37[0], v32[0], v27[0]]; -// v61 -// } + let v3 = 7; + let v4 = 6; + let v5 = 5; + let v6 = 4; + let v7 = 3; + let v8 = 2; + let v9 = 1; + let v10 = 0; + let v11 = vec![&v1[0]]; + let v12 = vec![&v1[1]]; + let v13 = vec![&v1[2]]; + let v14 = vec![&v1[3]]; + let v15 = vec![&v1[4]]; + let v16 = vec![&v1[5]]; + let v17 = vec![&v1[6]]; + let v18 = vec![&v1[7]]; + let v19 = vec![&v2[0]]; + let v20 = vec![&v2[1]]; + let v21 = vec![&v2[2]]; + let v22 = vec![&v2[3]]; + let v23 = vec![&v2[4]]; + let v24 = vec![&v2[5]]; + let v25 = vec![&v2[6]]; + let v26 = vec![&v2[7]]; + let v27_ref = v0.xor_packed(&v11, &v19); + let v27: Vec<&Ciphertext> = v27_ref.iter().collect(); + let v28_ref = v0.and_packed(&v11, &v19); + let v28: Vec<&Ciphertext> = v28_ref.iter().collect(); + let v29_ref = v0.xor_packed(&v12, &v20); + let v29: Vec<&Ciphertext> = v29_ref.iter().collect(); + let v30_ref = v0.and_packed(&v12, &v20); + let v30: Vec<&Ciphertext> = v30_ref.iter().collect(); + let v31_ref = v0.and_packed(&v29, &v28); + let v31: Vec<&Ciphertext> = v31_ref.iter().collect(); + let v32_ref = v0.xor_packed(&v29, &v28); + let v32: Vec<&Ciphertext> = v32_ref.iter().collect(); + let v33_ref = v0.xor_packed(&v30, &v31); + let v33: Vec<&Ciphertext> = v33_ref.iter().collect(); + let v34_ref = v0.xor_packed(&v13, &v21); + let v34: Vec<&Ciphertext> = v34_ref.iter().collect(); + let v35_ref = v0.and_packed(&v13, &v21); + let v35: Vec<&Ciphertext> = v35_ref.iter().collect(); + let v36_ref = v0.and_packed(&v34, &v33); + let v36: Vec<&Ciphertext> = v36_ref.iter().collect(); + let v37_ref = v0.xor_packed(&v34, &v33); + let v37: Vec<&Ciphertext> = v37_ref.iter().collect(); + let v38_ref = v0.xor_packed(&v35, &v36); + let v38: Vec<&Ciphertext> = v38_ref.iter().collect(); + let v39_ref = v0.xor_packed(&v14, &v22); + let v39: Vec<&Ciphertext> = v39_ref.iter().collect(); + let v40_ref = v0.and_packed(&v14, &v22); + let v40: Vec<&Ciphertext> = v40_ref.iter().collect(); + let v41_ref = v0.and_packed(&v39, &v38); + let v41: Vec<&Ciphertext> = v41_ref.iter().collect(); + let v42_ref = v0.xor_packed(&v39, &v38); + let v42: Vec<&Ciphertext> = v42_ref.iter().collect(); + let v43_ref = v0.xor_packed(&v40, &v41); + let v43: Vec<&Ciphertext> = v43_ref.iter().collect(); + let v44_ref = v0.xor_packed(&v15, &v23); + let v44: Vec<&Ciphertext> = v44_ref.iter().collect(); + let v45_ref = v0.and_packed(&v15, &v23); + let v45: Vec<&Ciphertext> = v45_ref.iter().collect(); + let v46_ref = v0.and_packed(&v44, &v43); + let v46: Vec<&Ciphertext> = v46_ref.iter().collect(); + let v47_ref = v0.xor_packed(&v44, &v43); + let v47: Vec<&Ciphertext> = v47_ref.iter().collect(); + let v48_ref = v0.xor_packed(&v45, &v46); + let v48: Vec<&Ciphertext> = v48_ref.iter().collect(); + let v49_ref = v0.xor_packed(&v16, &v24); + let v49: Vec<&Ciphertext> = v49_ref.iter().collect(); + let v50_ref = v0.and_packed(&v16, &v24); + let v50: Vec<&Ciphertext> = v50_ref.iter().collect(); + let v51_ref = v0.and_packed(&v49, &v48); + let v51: Vec<&Ciphertext> = v51_ref.iter().collect(); + let v52_ref = v0.xor_packed(&v49, &v48); + let v52: Vec<&Ciphertext> = v52_ref.iter().collect(); + let v53_ref = v0.xor_packed(&v50, &v51); + let v53: Vec<&Ciphertext> = v53_ref.iter().collect(); + let v54_ref = v0.xor_packed(&v17, &v25); + let v54: Vec<&Ciphertext> = v54_ref.iter().collect(); + let v55_ref = v0.and_packed(&v17, &v25); + let v55: Vec<&Ciphertext> = v55_ref.iter().collect(); + let v56_ref = v0.and_packed(&v54, &v53); + let v56: Vec<&Ciphertext> = v56_ref.iter().collect(); + let v57_ref = v0.xor_packed(&v54, &v53); + let v57: Vec<&Ciphertext> = v57_ref.iter().collect(); + let v58_ref = v0.xor_packed(&v55, &v56); + let v58: Vec<&Ciphertext> = v58_ref.iter().collect(); + let v59_ref = v0.xor_packed(&v18, &v26); + let v59: Vec<&Ciphertext> = v59_ref.iter().collect(); + let v60_ref = v0.xor_packed(&v59, &v58); + let v60: Vec<&Ciphertext> = v60_ref.iter().collect(); + let mut v61: Vec = vec![]; + v61.extend(v60_ref); + v61.extend(v57_ref); + v61.extend(v52_ref); + v61.extend(v47_ref); + v61.extend(v42_ref); + v61.extend(v37_ref); + v61.extend(v32_ref); + v61.extend(v27_ref); + v61 +} \ No newline at end of file diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir index d70eb52f3..74c7ee6d4 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir @@ -1,5 +1,5 @@ // RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs -// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_add_one -- 1 1 | FileCheck %s +// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s !bsks = !tfhe_rust_bool.server_key !eb = !tfhe_rust_bool.eb @@ -18,56 +18,57 @@ func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x! %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb> - %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> - %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> - %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> - %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb> - %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> - %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> - %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> - %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> - %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> - %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> - %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> - %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> - %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> - %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> - %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> - %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb - %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb - %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb - %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb - %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb - %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb - %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb - %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb - %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb - %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb - %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb - %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb - %fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb - %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb - %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb - %fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb - %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb - %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb - %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb - %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb - %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb - %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb - %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb - %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb - %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb - %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb - %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb - %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb - %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb - %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb - %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb - %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb - %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb - %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb - %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + %extracted_00 = tensor.extract_slice %arg0 [0][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_01 = tensor.extract_slice %arg0 [1][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_02 = tensor.extract_slice %arg0 [2][1][1]: tensor<8x!eb> to tensor<1x!eb> + %extracted_03 = tensor.extract_slice %arg0 [3][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_04 = tensor.extract_slice %arg0 [4][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_05 = tensor.extract_slice %arg0 [5][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_06 = tensor.extract_slice %arg0 [6][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_07 = tensor.extract_slice %arg0 [7][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_10 = tensor.extract_slice %arg1 [0][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_11 = tensor.extract_slice %arg1 [1][1][1]: tensor<8x!eb> to tensor<1x!eb> + %extracted_12 = tensor.extract_slice %arg1 [2][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_13 = tensor.extract_slice %arg1 [3][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_14 = tensor.extract_slice %arg1 [4][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_15 = tensor.extract_slice %arg1 [5][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_16 = tensor.extract_slice %arg1 [6][1][1] : tensor<8x!eb> to tensor<1x!eb> + %extracted_17 = tensor.extract_slice %arg1 [7][1][1]: tensor<8x!eb> to tensor<1x!eb> + %ha_s = tfhe_rust_bool.xor_packed %bsks, %extracted_00, %extracted_10 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %ha_c = tfhe_rust_bool.and_packed %bsks, %extracted_00, %extracted_10: (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa0_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_01, %extracted_11 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa0_2 = tfhe_rust_bool.and_packed %bsks, %extracted_01, %extracted_11 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa0_3 = tfhe_rust_bool.and_packed %bsks, %fa0_1, %ha_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa0_s = tfhe_rust_bool.xor_packed %bsks, %fa0_1, %ha_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa0_c = tfhe_rust_bool.xor_packed %bsks, %fa0_2, %fa0_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa1_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_02, %extracted_12 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa1_2 = tfhe_rust_bool.and_packed %bsks, %extracted_02, %extracted_12 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa1_3 = tfhe_rust_bool.and_packed %bsks, %fa1_1, %fa0_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa1_s = tfhe_rust_bool.xor_packed %bsks, %fa1_1, %fa0_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa1_c = tfhe_rust_bool.xor_packed %bsks, %fa1_2, %fa1_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa2_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_03, %extracted_13 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa2_2 = tfhe_rust_bool.and_packed %bsks, %extracted_03, %extracted_13 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa2_3 = tfhe_rust_bool.and_packed %bsks, %fa2_1, %fa1_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa2_s = tfhe_rust_bool.xor_packed %bsks, %fa2_1, %fa1_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa2_c = tfhe_rust_bool.xor_packed %bsks, %fa2_2, %fa2_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa3_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_04, %extracted_14 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa3_2 = tfhe_rust_bool.and_packed %bsks, %extracted_04, %extracted_14 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa3_3 = tfhe_rust_bool.and_packed %bsks, %fa3_1, %fa2_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa3_s = tfhe_rust_bool.xor_packed %bsks, %fa3_1, %fa2_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa3_c = tfhe_rust_bool.xor_packed %bsks, %fa3_2, %fa3_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa4_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_05, %extracted_15 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa4_2 = tfhe_rust_bool.and_packed %bsks, %extracted_05, %extracted_15 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa4_3 = tfhe_rust_bool.and_packed %bsks, %fa4_1, %fa3_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa4_s = tfhe_rust_bool.xor_packed %bsks, %fa4_1, %fa3_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa4_c = tfhe_rust_bool.xor_packed %bsks, %fa4_2, %fa4_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa5_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_06, %extracted_16 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa5_2 = tfhe_rust_bool.and_packed %bsks, %extracted_06, %extracted_16 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa5_3 = tfhe_rust_bool.and_packed %bsks, %fa5_1, %fa4_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa5_s = tfhe_rust_bool.xor_packed %bsks, %fa5_1, %fa4_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa5_c = tfhe_rust_bool.xor_packed %bsks, %fa5_2, %fa5_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa6_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_07, %extracted_17 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %fa6_s = tfhe_rust_bool.xor_packed %bsks, %fa6_1, %fa5_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> + %from_elements = tensor.concat dim(0) %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s + : (tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>) -> tensor<8x!eb> return %from_elements : tensor<8x!eb> } diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir index 8ecca7cad..9a2e0e5b5 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir @@ -1,7 +1,7 @@ // This test ensures the testing harness is working properly with minimal codegen. // RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs -// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 1 | FileCheck %s +// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s !bsks = !tfhe_rust_bool.server_key !eb = !tfhe_rust_bool.eb From c36eaeb3991fa126077283939248ddd9d55131be Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Sat, 23 Mar 2024 14:09:57 +0000 Subject: [PATCH 07/11] New vector test --- .../test_vectorize.mlir | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir diff --git a/tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir b/tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir new file mode 100644 index 000000000..e4f3f7d61 --- /dev/null +++ b/tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir @@ -0,0 +1,47 @@ +// RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s + +#encoding = #lwe.unspecified_bit_field_encoding +!ct_ty = !lwe.lwe_ciphertext +!pt_ty = !lwe.lwe_plaintext + +// CHECK-LABEL: add_bool +// CHECK-NOT: cggi +// CHECK-NOT: lwe +func.func @add_bool(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { + %true = arith.constant true + %false = arith.constant false + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!ct_ty> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!ct_ty> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!ct_ty> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!ct_ty> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!ct_ty> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!ct_ty> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!ct_ty> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!ct_ty> + %0 = cggi.xor %extracted_00, %extracted_10 : !ct_ty + %1 = cggi.and %extracted_02, %extracted_12 : !ct_ty + %2 = cggi.xor %extracted_01, %extracted_11 : !ct_ty + %3 = cggi.and %extracted_03, %extracted_13 : !ct_ty + %4 = cggi.xor %extracted_05, %extracted_15 : !ct_ty + %5 = cggi.and %extracted_04, %extracted_14 : !ct_ty + %6 = cggi.xor %extracted_06, %extracted_16 : !ct_ty + %7 = cggi.and %extracted_07, %extracted_17 : !ct_ty + %from_elements = tensor.from_elements %0, %2, %4, %6, %1, %3, %5, %7 : tensor<8x!ct_ty> + return %from_elements : tensor<8x!ct_ty> +} From 6ccd7cdb573449c86dc56259a246f148415eb6cf Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Sun, 24 Mar 2024 17:23:36 +0100 Subject: [PATCH 08/11] tensor From Elements Ops now checks if the defining op if of type tfhe_rust_bool --- .../TfheRustBool/IR/TfheRustBoolOps.td | 20 +- .../TfheRustBool/TfheRustBoolEmitter.cpp | 102 +++-- tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir | 2 +- .../end_to_end_fpga/src/fn_under_test.rs | 353 ++++++++++++++++++ .../end_to_end_fpga/src/main.rs | 6 +- .../end_to_end_fpga/test_add_one_bool.mlir | 2 +- 6 files changed, 448 insertions(+), 37 deletions(-) diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index 92dd1033f..17d24d4b4 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -24,16 +24,20 @@ def CreateTrivialOp : TfheRustBool_Op<"create_trivial", [Pure]> { let results = (outs TfheRustBool_Encrypted:$output); } +// --- Operations for a gate-bootstrapping API of a CGGI library --- +def TfheRustBoolLike : TypeOrContainer; + + class TfheRustBool_BinaryGateOp : TfheRustBool_Op ]> { let arguments = (ins TfheRustBool_ServerKey:$serverKey, - TfheRustBool_Encrypted:$lhs, - TfheRustBool_Encrypted:$rhs + TfheRustBoolLike:$lhs, + TfheRustBoolLike:$rhs ); - let results = (outs TfheRustBool_Encrypted:$output); + let results = (outs TfheRustBoolLike:$output); } def AndOp : TfheRustBool_BinaryGateOp<"and"> { let summary = "Logical AND of two TFHE-rs Bool ciphertexts."; } @@ -43,6 +47,16 @@ def NorOp : TfheRustBool_BinaryGateOp<"nor"> { let summary = "Logical NOR of t def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two TFHE-rs Bool ciphertexts."; } def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; } +// def TfheRustBool_Ops : +// AnyTypeOp<[ +// AndOp, +// NandOp, +// OrOp, +// NorOp, +// XorOp, +// XnorOp, +// ]>; + def AndPackedOp : TfheRustBool_Op<"and_packed", [ Pure, diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index fdc73eb67..3e2269f9c 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -67,7 +67,8 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) { .Case([&](auto op) { return printOperation(op); }) // Tensor ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) .Default([&](Operation &) { @@ -168,7 +169,40 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) { LogicalResult TfheRustBoolEmitter::printSksMethod( ::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands, std::string_view op, SmallVector operandTypes) { - emitAssignPrefix(result); + + if (isa(nonSksOperands[0].getType())) { + os << "let " << variableNames->getNameForValue(result) << "_ref = "; + std::string_view opName = "and_packed"; + + os << variableNames->getNameForValue(sks) << "." << opName << "("; + os << commaSeparatedValues( + {nonSksOperands[0], nonSksOperands[1]}, [&](Value value) { + auto *prefix = value.getType().hasTrait() ? "&" : ""; + // First check if a DefiningOp exists + // if not: comes from function definition + + // getDefiningOp look for a gedefining op using a specific + // type + mlir::Operation *opParent = value.getDefiningOp(); + if (opParent) { + prefix = !isa(opParent) ? "" : prefix; + } else { + prefix = ""; + } + + prefix = opName.find("packed") ? "&" : prefix; + + return prefix + variableNames->getNameForValue(value); + }); + os << ");\n"; + + os << "let " << variableNames->getNameForValue(result) + << ": Vec<&Ciphertext> = " << variableNames->getNameForValue(result) + << "_ref.iter().collect();\n"; + return success(); + + } else { + emitAssignPrefix(result); auto operandTypesIt = operandTypes.begin(); os << variableNames->getNameForValue(sks) << "." << op << "("; @@ -176,22 +210,19 @@ LogicalResult TfheRustBoolEmitter::printSksMethod( auto *prefix = value.getType().hasTrait() ? "&" : ""; // First check if a DefiningOp exists // if not: comes from function definition - - mlir::Operation *opParent = value.getDefiningOp(); - if (opParent) { - prefix = isa(opParent) ? "" : prefix; + mlir::Operation *op = value.getDefiningOp(); + if (op) { + prefix = isa(op) ? "" : prefix; } else { prefix = ""; } - prefix = op.find("packed") ? "&" : prefix; - - return prefix + variableNames->getNameForValue(value) + (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); }); os << ");\n"; return success(); + } } LogicalResult TfheRustBoolEmitter::printOperation(CreateTrivialOp op) { @@ -219,6 +250,7 @@ LogicalResult TfheRustBoolEmitter::printOperation(arith::ConstantOp op) { return success(); } +// Produces a &Ciphertext LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) { // We assume here that the indices are SSA values (not integer attributes). emitAssignPrefix(op.getResult()); @@ -232,35 +264,37 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) { LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractSliceOp op) { emitAssignPrefix(op.getResult()); - os << "vec![&" << variableNames->getNameForValue(op.getSource()) << "[" << op.getStaticOffsets()[0] << "]];\n"; + os << "vec![&" << variableNames->getNameForValue(op.getSource()) << "[" + << op.getStaticOffsets()[0] << "]];\n"; return success(); } +// Need to produce a Vec<&Ciphertext> LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { emitAssignPrefix(op.getResult()); os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) { // Check if block argument, if so, clone. - auto cloneStr = ""; - if (isa(value)) { - cloneStr = ".clone()"; - } - return variableNames->getNameForValue(value) + cloneStr; + auto cloneStr = isa(value) ? ".clone()": ""; + // Get the name of defining operation its dialect + auto tfhe_op = value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool"; + auto prefix = tfhe_op ? "&" : ""; + return std::string(prefix) + variableNames->getNameForValue(value) + cloneStr; }) << "];\n"; return success(); } LogicalResult TfheRustBoolEmitter::printOperation(tensor::ConcatOp op) { auto varName = variableNames->getNameForValue(op.getResult()); - os << "let mut " << varName << ": Vec = vec![];\n"; + os << "let mut " << varName << ": Vec = vec![];\n"; ValueRange values = op.getOperands(); - for(Value a: values){ - os << varName << ".extend(" << variableNames->getNameForValue(a) << "_ref);\n"; + for (Value a : values) { + os << varName << ".extend(" << variableNames->getNameForValue(a) + << "_ref);\n"; } return success(); } - LogicalResult TfheRustBoolEmitter::printOperation(AndOp op) { return printSksMethod(op.getResult(), op.getServerKey(), {op.getLhs(), op.getRhs()}, "and"); @@ -295,12 +329,13 @@ LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { os << "let " << variableNames->getNameForValue(op.getResult()) << "_ref = "; std::string_view opName = "and_packed"; - os << variableNames->getNameForValue(op.getServerKey()) << "." << opName << "("; + os << variableNames->getNameForValue(op.getServerKey()) << "." << opName + << "("; os << commaSeparatedValues({op.getLhs(), op.getRhs()}, [&](Value value) { auto *prefix = value.getType().hasTrait() ? "&" : ""; // First check if a DefiningOp exists // if not: comes from function definition - + mlir::Operation *opParent = value.getDefiningOp(); if (opParent) { prefix = isa(opParent) ? "" : prefix; @@ -310,30 +345,34 @@ LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { prefix = opName.find("packed") ? "&" : prefix; - return prefix + variableNames->getNameForValue(value); }); os << ");\n"; - os << "let " << variableNames->getNameForValue(op.getResult()) << ": Vec<&Ciphertext> = " - << variableNames->getNameForValue(op.getResult()) << "_ref.iter().collect();\n"; + os << "let " << variableNames->getNameForValue(op.getResult()) + << ": Vec<&Ciphertext> = " + << variableNames->getNameForValue(op.getResult()) + << "_ref.iter().collect();\n"; return success(); } LogicalResult TfheRustBoolEmitter::printOperation(XorPackedOp op) { // os << "let " << variableNames->getNameForValue(op.getLhs()) << " = " - // << variableNames->getNameForValue(op.getLhs()) << ".iter().collect();\n"; + // << variableNames->getNameForValue(op.getLhs()) << + // ".iter().collect();\n"; // os << "let " << variableNames->getNameForValue(op.getRhs()) << " = " - // << variableNames->getNameForValue(op.getRhs()) << ".iter().collect();\n"; + // << variableNames->getNameForValue(op.getRhs()) << + // ".iter().collect();\n"; os << "let " << variableNames->getNameForValue(op.getResult()) << "_ref = "; std::string_view opName = "xor_packed"; - os << variableNames->getNameForValue(op.getServerKey()) << "." << opName << "("; + os << variableNames->getNameForValue(op.getServerKey()) << "." << opName + << "("; os << commaSeparatedValues({op.getLhs(), op.getRhs()}, [&](Value value) { auto *prefix = value.getType().hasTrait() ? "&" : ""; // First check if a DefiningOp exists // if not: comes from function definition - + mlir::Operation *opParent = value.getDefiningOp(); if (opParent) { prefix = isa(opParent) ? "" : prefix; @@ -343,13 +382,14 @@ LogicalResult TfheRustBoolEmitter::printOperation(XorPackedOp op) { prefix = opName.find("packed") ? "&" : prefix; - return prefix + variableNames->getNameForValue(value); }); os << ");\n"; - os << "let " << variableNames->getNameForValue(op.getResult()) << ": Vec<&Ciphertext> = " - << variableNames->getNameForValue(op.getResult()) << "_ref.iter().collect();\n"; + os << "let " << variableNames->getNameForValue(op.getResult()) + << ": Vec<&Ciphertext> = " + << variableNames->getNameForValue(op.getResult()) + << "_ref.iter().collect();\n"; return success(); } diff --git a/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir b/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir index 1a96d3b25..e6dd85760 100644 --- a/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir +++ b/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir @@ -8,7 +8,7 @@ // CHECK-NEXT: [[input1:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: [[input2:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: ) -> Ciphertext { -// CHECK-NEXT: let [[v0:.*]] = [[bsks]].and([[input1]], [[input2]]); +// CHECK-NEXT: let [[v0:.*]] = [[bsks]].and(&[[input1]], &[[input2]]); // CHECK-NEXT: [[v0]] // CHECK-NEXT: } func.func @test_and(%bsks : !bsks, %input1 : !eb, %input2 : !eb) -> !eb { diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs index 7ed9b2548..2cc7e1749 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs @@ -14,6 +14,168 @@ use tfhe::boolean::prelude::*; use tfhe::boolean::prelude::*; + +pub fn fn_under_test_fpga( + v0: &ServerKey, + v1: &Vec, + v2: &Vec, +) -> Vec { + let v3 = 7; + let v4 = 6; + let v5 = 5; + let v6 = 4; + let v7 = 3; + let v8 = 2; + let v9 = 1; + let v10 = 0; + let v11 = &v1[v10]; + let v12 = &v1[v9]; + let v13 = &v1[v8]; + let v14 = &v1[v7]; + let v15 = &v1[v6]; + let v16 = &v1[v5]; + let v17 = &v1[v4]; + let v18 = &v1[v3]; + let v19 = &v2[v10]; + let v20 = &v2[v9]; + let v21 = &v2[v8]; + let v22 = &v2[v7]; + let v23 = &v2[v6]; + let v24 = &v2[v5]; + let v25 = &v2[v4]; + let v26 = &v2[v3]; + let v27 = v0.and(v11, v19); + let v28 = v0.xor(v12, v20); + let v29 = vec![&v28, v12]; + let v30 = vec![&v27, v20]; + let v31_ref = v0.and_packed(&v29, &v30); + let v31: Vec<&Ciphertext> = v31_ref.iter().collect(); + let v32 = 0; + let v33 = v31[v32]; + let v34 = 1; + let v35 = v31[v34]; + let v36 = vec![v13, v35]; + let v37 = vec![v21, v33]; + let v38_ref = v0.and_packed(&v36, &v37); + let v38: Vec<&Ciphertext> = v38_ref.iter().collect(); + let v39 = 0; + let v40 = v38[v39]; + let v41 = 1; + let v42 = v38[v41]; + let v43 = vec![v13, v40]; + let v44 = vec![v21, v42]; + let v45_ref = v0.and_packed(&v43, &v44); + let v45: Vec<&Ciphertext> = v45_ref.iter().collect(); + let v46 = 0; + let v47 = v45[v46]; + let v48 = 1; + let v49 = v45[v48]; + let v50 = vec![v47, v14]; + let v51 = vec![v49, v22]; + let v52_ref = v0.and_packed(&v50, &v51); + let v52: Vec<&Ciphertext> = v52_ref.iter().collect(); + let v53 = 0; + let v54 = v52[v53]; + let v55 = 1; + let v56 = v52[v55]; + let v57 = vec![v56, v14]; + let v58 = vec![v54, v22]; + let v59_ref = v0.and_packed(&v57, &v58); + let v59: Vec<&Ciphertext> = v59_ref.iter().collect(); + let v60 = 0; + let v61 = v59[v60]; + let v62 = 1; + let v63 = v59[v62]; + let v64 = vec![v15, v63]; + let v65 = vec![v23, v61]; + let v66_ref = v0.and_packed(&v64, &v65); + let v66: Vec<&Ciphertext> = v66_ref.iter().collect(); + let v67 = 0; + let v68 = v66[v67]; + let v69 = 1; + let v70 = v66[v69]; + let v71 = vec![v15, v68]; + let v72 = vec![v23, v70]; + let v73_ref = v0.and_packed(&v71, &v72); + let v73: Vec<&Ciphertext> = v73_ref.iter().collect(); + let v74 = 0; + let v75 = v73[v74]; + let v76 = 1; + let v77 = v73[v76]; + let v78 = vec![v16, v75]; + let v79 = vec![v24, v77]; + let v80_ref = v0.and_packed(&v78, &v79); + let v80: Vec<&Ciphertext> = v80_ref.iter().collect(); + let v81 = 0; + let v82 = v80[v81]; + let v83 = 1; + let v84 = v80[v83]; + let v85 = vec![v16, v82]; + let v86 = vec![v24, v84]; + let v87_ref = v0.and_packed(&v85, &v86); + let v87: Vec<&Ciphertext> = v87_ref.iter().collect(); + let v88 = 0; + let v89 = v87[v88]; + let v90 = 1; + let v91 = v87[v90]; + let v92 = 0; + let v93 = 1; + let v94 = 2; + let v95 = 3; + let v96 = 4; + let v97 = 5; + let v98 = 6; + let v99 = 7; + let v100 = vec![v89, v17]; + let v101 = vec![v91, v25]; + let v102_ref = v0.and_packed(&v100, &v101); + let v102: Vec<&Ciphertext> = v102_ref.iter().collect(); + let v103 = 0; + let v104 = v102[v103]; + let v105 = 1; + let v106 = v102[v105]; + let v107 = vec![v106, v17]; + let v108 = vec![v104, v25]; + let v109_ref = v0.and_packed(&v107, &v108); + let v109: Vec<&Ciphertext> = v109_ref.iter().collect(); + let v110 = 0; + let v111 = v109[v110]; + let v112 = 1; + let v113 = v109[v112]; + let v114 = vec![v18, v113]; + let v115 = vec![v26, v111]; + let v116_ref = v0.and_packed(&v114, &v115); + let v116: Vec<&Ciphertext> = v116_ref.iter().collect(); + let v117 = 0; + let v118 = v116[v117]; + let v119 = 1; + let v120 = v116[v119]; + let v121 = vec![v118, v18]; + let v122 = vec![v120, v26]; + let v123_ref = v0.and_packed(&v121, &v122); + let v123: Vec<&Ciphertext> = v123_ref.iter().collect(); + let v124 = 0; + let v125 = v123[v124]; + let v126 = 1; + let v127 = v123[v126]; + let v128 = vec![&v28, v118, v40, v68, v106, v56, v11, v82]; + let v129 = vec![&v27, v120, v42, v70, v104, v54, v19, v84]; + let v130_ref = v0.and_packed(&v128, &v129); + let v130: Vec<&Ciphertext> = v130_ref.iter().collect(); + let v131 = v130_ref[v92].clone(); + let v132 = v130_ref[v93].clone(); + let v133 = v130_ref[v94].clone(); + let v134 = v130_ref[v95].clone(); + let v135 = v130_ref[v96].clone(); + let v136 = v130_ref[v97].clone(); + let v137 = v130_ref[v98].clone(); + let v138 = v130_ref[v99].clone(); + let v139 = vec![v132, v135, v138, v134, v136, v133, v131, v137]; + v139 +} + +// use tfhe::boolean::prelude::*; + pub fn fn_under_test( v0: &ServerKey, v1: &Vec, @@ -121,4 +283,195 @@ pub fn fn_under_test( v61.extend(v32_ref); v61.extend(v27_ref); v61 +} + + +pub fn add_bool( + v0: &ServerKey, + v1: &Vec, + v2: &Vec, +) -> Vec { + let v3 = 7; + let v4 = 6; + let v5 = 5; + let v6 = 4; + let v7 = 3; + let v8 = 2; + let v9 = 1; + let v10 = 0; + let v11 = &v1[v10]; + let v12 = &v1[v9]; + let v13 = &v1[v8]; + let v14 = &v1[v7]; + let v15 = &v1[v6]; + let v16 = &v1[v5]; + let v17 = &v1[v4]; + let v18 = &v1[v3]; + let v19 = &v2[v10]; + let v20 = &v2[v9]; + let v21 = &v2[v8]; + let v22 = &v2[v7]; + let v23 = &v2[v6]; + let v24 = &v2[v5]; + let v25 = &v2[v4]; + let v26 = &v2[v3]; + let v27 = v0.and(v11, v19); + let v28 = v0.xor(v12, v20); + let v29 = vec![v12, &v28]; + let v30 = vec![v20, &v27]; + let v31_ref = v0.and_packed(&v29, &v30); + let v31: Vec<&Ciphertext> = v31_ref.iter().collect(); + let v32 = &v31[v10]; + let v33 = &v31[v9]; + let v34 = vec![v32, v13]; + let v35 = vec![v33, v21]; + let v36_ref = v0.and_packed(&v34, &v35); + let v36: Vec<&Ciphertext> = v36_ref.iter().collect(); + let v37 = &v36[v10]; + let v38 = &v36[v9]; + let v39 = vec![v13, v38]; + let v40 = vec![v21, v37]; + let v41_ref = v0.and_packed(&v39, &v40); + let v41: Vec<&Ciphertext> = v41_ref.iter().collect(); + let v42 = &v41[v10]; + let v43 = &v41[v9]; + let v44 = vec![v42, v14]; + let v45 = vec![v43, v22]; + let v46_ref = v0.and_packed(&v44, &v45); + let v46: Vec<&Ciphertext> = v46_ref.iter().collect(); + let v47 = &v46[v10]; + let v48 = &v46[v9]; + let v49 = vec![v48, v14]; + let v50 = vec![v47, v22]; + let v51_ref = v0.and_packed(&v49, &v50); + let v51: Vec<&Ciphertext> = v51_ref.iter().collect(); + let v52 = &v51[v10]; + let v53 = &v51[v9]; + let v54 = vec![v53, v15]; + let v55 = vec![v52, v23]; + let v56_ref = v0.and_packed(&v54, &v55); + let v56: Vec<&Ciphertext> = v56_ref.iter().collect(); + let v57 = &v56[v10]; + let v58 = &v56[v9]; + let v59 = vec![v58, v15]; + let v60 = vec![v57, v23]; + let v61_ref = v0.and_packed(&v59, &v60); + let v61: Vec<&Ciphertext> = v61_ref.iter().collect(); + let v62 = &v61[v10]; + let v63 = &v61[v9]; + let v64 = vec![v63, v16]; + let v65 = vec![v62, v24]; + let v66_ref = v0.and_packed(&v64, &v65); + let v66: Vec<&Ciphertext> = v66_ref.iter().collect(); + let v67 = &v66[v10]; + let v68 = &v66[v9]; + let v69 = vec![v16, v68]; + let v70 = vec![v24, v67]; + let v71_ref = v0.and_packed(&v69, &v70); + let v71: Vec<&Ciphertext> = v71_ref.iter().collect(); + let v72 = &v71[v10]; + let v73 = &v71[v9]; + let v74 = vec![v17, v72]; + let v75 = vec![v25, v73]; + let v76_ref = v0.and_packed(&v74, &v75); + let v76: Vec<&Ciphertext> = v76_ref.iter().collect(); + let v77 = &v76[v10]; + let v78 = &v76[v9]; + let v79 = vec![v77, v17]; + let v80 = vec![v78, v25]; + let v81_ref = v0.and_packed(&v79, &v80); + let v81: Vec<&Ciphertext> = v81_ref.iter().collect(); + let v82 = &v81[v10]; + let v83 = &v81[v9]; + let v84 = vec![v18, v83]; + let v85 = vec![v26, v82]; + let v86_ref = v0.and_packed(&v84, &v85); + let v86: Vec<&Ciphertext> = v86_ref.iter().collect(); + let v87 = &v86[v10]; + let v88 = &v86[v9]; + let v89 = vec![v18, v87]; + let v90 = vec![v26, v88]; + let v91_ref = v0.and_packed(&v89, &v90); + let v91: Vec<&Ciphertext> = v91_ref.iter().collect(); + let v92 = vec![v68, v58, v48, v38, v28, v87, v77, v11]; + let v93 = vec![v67, v57, v47, v37, v27, v88, v78, v19]; + let v94_ref = v0.and_packed(&v92, &v93); + let v94: Vec<&Ciphertext> = v94_ref.iter().collect(); + let v95 = &v94[v10]; + let v96 = &v94[v9]; + let v97 = &v94[v8]; + let v98 = &v94[v7]; + let v99 = &v94[v6]; + let v100 = &v94[v5]; + let v101 = &v94[v4]; + let v102 = &v94[v3]; + let v103 = vec![v100, v101, v95, v96, v97, v98, v99, v102]; + v103 +} + +pub fn fn_under_testtt( + v0: &ServerKey, + v1: &Vec, + v2: &Vec, +) -> Vec { + let v3 = 7; + let v4 = 6; + let v5 = 5; + let v6 = 4; + let v7 = 3; + let v8 = 2; + let v9 = 1; + let v10 = 0; + let v11 = v1[v10]; + let v12 = v1[v9]; + let v13 = v1[v8]; + let v14 = v1[v7]; + let v15 = v1[v6]; + let v16 = v1[v5]; + let v17 = v1[v4]; + let v18 = v1[v3]; + let v19 = v2[v10]; + let v20 = v2[v9]; + let v21 = v2[v8]; + let v22 = v2[v7]; + let v23 = v2[v6]; + let v24 = v2[v5]; + let v25 = v2[v4]; + let v26 = v2[v3]; + let v27 = v0.xor(&v11, &v19); + let v28 = v0.and(&v11, &v19); + let v29 = v0.xor(&v12, &v20); + let v30 = v0.and(&v12, &v20); + let v31 = v0.and(&v29, &v28); + let v32 = v0.xor(&v29, &v28); + let v33 = v0.xor(&v30, &v31); + let v34 = v0.xor(&v13, &v21); + let v35 = v0.and(&v13, &v21); + let v36 = v0.and(&v34, &v33); + let v37 = v0.xor(&v34, &v33); + let v38 = v0.xor(&v35, &v36); + let v39 = v0.xor(&v14, &v22); + let v40 = v0.and(&v14, &v22); + let v41 = v0.and(&v39, &v38); + let v42 = v0.xor(&v39, &v38); + let v43 = v0.xor(&v40, &v41); + let v44 = v0.xor(&v15, &v23); + let v45 = v0.and(&v15, &v23); + let v46 = v0.and(&v44, &v43); + let v47 = v0.xor(&v44, &v43); + let v48 = v0.xor(&v45, &v46); + let v49 = v0.xor(&v16, &v24); + let v50 = v0.and(&v16, &v24); + let v51 = v0.and(&v49, &v48); + let v52 = v0.xor(&v49, &v48); + let v53 = v0.xor(&v50, &v51); + let v54 = v0.xor(&v17, &v25); + let v55 = v0.and(&v17, &v25); + let v56 = v0.and(&v54, &v53); + let v57 = v0.xor(&v54, &v53); + let v58 = v0.xor(&v55, &v56); + let v59 = v0.xor(&v18, &v26); + let v60 = v0.xor(&v59, &v58); + let v61 = vec![v60, v57, v52, v47, v42, v37, v32, v27]; + v61 } \ No newline at end of file diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs index 1a10fadf0..a62295f40 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs @@ -75,7 +75,11 @@ fn main() { // let ct_1= ct_1.into_iter().collect(); // let ct_2= ct_2.into_iter().collect(); - let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + let t = Instant::now(); + let result = fn_under_test::fn_under_testtt(&server_key, &ct_1, &ct_2); + let run = t.elapsed().as_millis(); + + println!("{:?}", run); let output = decrypt(&result, &client_key); diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir index 74c7ee6d4..0e4f71bd3 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir @@ -68,7 +68,7 @@ func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x! %fa5_c = tfhe_rust_bool.xor_packed %bsks, %fa5_2, %fa5_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> %fa6_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_07, %extracted_17 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> %fa6_s = tfhe_rust_bool.xor_packed %bsks, %fa6_1, %fa5_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %from_elements = tensor.concat dim(0) %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s + %from_elements = tensor.concat dim(0) %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : (tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>) -> tensor<8x!eb> return %from_elements : tensor<8x!eb> } From 4dd34fad18fd1813174610848d5cb306830973f3 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Sun, 24 Mar 2024 22:30:25 +0100 Subject: [PATCH 09/11] Updated tfhe-rs-bool emitter with collaboration with straight-line-vectorize --- .../TfheRustBool/IR/TfheRustBoolOps.td | 35 -- .../Target/TfheRustBool/TfheRustBoolEmitter.h | 3 - .../TfheRustBool/TfheRustBoolEmitter.cpp | 165 +++--- tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir | 2 +- tests/tfhe_rust_bool/end_to_end_fpga/BUILD | 1 + .../tfhe_rust_bool/end_to_end_fpga/README.md | 6 + .../end_to_end_fpga/src/fn_under_test.rs | 477 ------------------ .../end_to_end_fpga/src/main.rs | 4 +- .../end_to_end_fpga/test_add_one_bool.mlir | 74 --- .../end_to_end_fpga/test_cggi_add_bool.mlir | 77 +++ .../end_to_end_fpga/test_packed_and.mlir | 2 +- tests/tfhe_rust_bool/ops.mlir | 2 +- 12 files changed, 146 insertions(+), 702 deletions(-) delete mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs delete mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index 17d24d4b4..f64bfced7 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -47,41 +47,6 @@ def NorOp : TfheRustBool_BinaryGateOp<"nor"> { let summary = "Logical NOR of t def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two TFHE-rs Bool ciphertexts."; } def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; } -// def TfheRustBool_Ops : -// AnyTypeOp<[ -// AndOp, -// NandOp, -// OrOp, -// NorOp, -// XorOp, -// XnorOp, -// ]>; - - -def AndPackedOp : TfheRustBool_Op<"and_packed", [ - Pure, - AllTypesMatch<["lhs", "rhs", "output"]> -]> { - let arguments = (ins - TfheRustBool_ServerKey:$serverKey, - TensorOf<[TfheRustBool_Encrypted]>:$lhs, - TensorOf<[TfheRustBool_Encrypted]>:$rhs - ); - let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); -} - -def XorPackedOp : TfheRustBool_Op<"xor_packed", [ - Pure, - AllTypesMatch<["lhs", "rhs", "output"]> -]> { - let arguments = (ins - TfheRustBool_ServerKey:$serverKey, - TensorOf<[TfheRustBool_Encrypted]>:$lhs, - TensorOf<[TfheRustBool_Encrypted]>:$rhs - ); - let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); -} - def NotOp : TfheRustBool_Op<"not", [ Pure, AllTypesMatch<["input", "output"]> diff --git a/include/Target/TfheRustBool/TfheRustBoolEmitter.h b/include/Target/TfheRustBool/TfheRustBoolEmitter.h index 8a60783e0..512dba039 100644 --- a/include/Target/TfheRustBool/TfheRustBoolEmitter.h +++ b/include/Target/TfheRustBool/TfheRustBoolEmitter.h @@ -61,9 +61,6 @@ class TfheRustBoolEmitter { LogicalResult printOperation(XorOp op); LogicalResult printOperation(XnorOp op); - LogicalResult printOperation(AndPackedOp op); - LogicalResult printOperation(XorPackedOp op); - // Helpers for above LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands, diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 3e2269f9c..3f0d3dbc1 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -64,8 +64,8 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) { // Arith ops .Case([&](auto op) { return printOperation(op); }) // TfheRustBool ops - .Case([&](auto op) { return printOperation(op); }) + .Case( + [&](auto op) { return printOperation(op); }) // Tensor ops .Case( @@ -149,6 +149,9 @@ LogicalResult TfheRustBoolEmitter::printOperation(func::ReturnOp op) { if (isa(value)) { cloneStr = ".clone()"; } + if (isa(value.getDefiningOp())) { + cloneStr = ".into_iter().cloned().collect()"; + } return variableNames->getNameForValue(value) + cloneStr; }; @@ -169,59 +172,71 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) { LogicalResult TfheRustBoolEmitter::printSksMethod( ::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands, std::string_view op, SmallVector operandTypes) { - if (isa(nonSksOperands[0].getType())) { - os << "let " << variableNames->getNameForValue(result) << "_ref = "; - std::string_view opName = "and_packed"; + mlir::Operation *opParent = nonSksOperands[0].getDefiningOp(); + if (!opParent) { + os << "let " << variableNames->getNameForValue(nonSksOperands[0]) + << "_ref = " << variableNames->getNameForValue(nonSksOperands[0]) + << ".clone();\n"; + os << "let " << variableNames->getNameForValue(nonSksOperands[0]) + << "_ref: Vec<&Ciphertext> = " + << variableNames->getNameForValue(nonSksOperands[0]) + << ".iter().collect();\n"; + os << "let " << variableNames->getNameForValue(nonSksOperands[1]) + << "_ref = " << variableNames->getNameForValue(nonSksOperands[1]) + << ".clone();\n"; + os << "let " << variableNames->getNameForValue(nonSksOperands[1]) + << "_ref: Vec<&Ciphertext> = " + << variableNames->getNameForValue(nonSksOperands[1]) + << ".iter().collect();\n"; + } - os << variableNames->getNameForValue(sks) << "." << opName << "("; + emitAssignPrefix(result); + + os << variableNames->getNameForValue(sks) << "." << op << "_packed("; os << commaSeparatedValues( {nonSksOperands[0], nonSksOperands[1]}, [&](Value value) { auto *prefix = value.getType().hasTrait() ? "&" : ""; + auto suffix = ""; // First check if a DefiningOp exists // if not: comes from function definition - - // getDefiningOp look for a gedefining op using a specific - // type mlir::Operation *opParent = value.getDefiningOp(); if (opParent) { - prefix = !isa(opParent) ? "" : prefix; + prefix = isa(opParent) ? prefix : ""; + prefix = + isa(value.getDefiningOp()) ? "&" : ""; } else { - prefix = ""; + prefix = "&"; + suffix = "_ref"; } - prefix = opName.find("packed") ? "&" : prefix; - - return prefix + variableNames->getNameForValue(value); + return prefix + variableNames->getNameForValue(value) + suffix; }); os << ");\n"; - - os << "let " << variableNames->getNameForValue(result) - << ": Vec<&Ciphertext> = " << variableNames->getNameForValue(result) - << "_ref.iter().collect();\n"; return success(); } else { - emitAssignPrefix(result); - - auto operandTypesIt = operandTypes.begin(); - os << variableNames->getNameForValue(sks) << "." << op << "("; - os << commaSeparatedValues(nonSksOperands, [&](Value value) { - auto *prefix = value.getType().hasTrait() ? "&" : ""; - // First check if a DefiningOp exists - // if not: comes from function definition - mlir::Operation *op = value.getDefiningOp(); - if (op) { - prefix = isa(op) ? "" : prefix; - } else { - prefix = ""; - } + emitAssignPrefix(result); + + auto operandTypesIt = operandTypes.begin(); + os << variableNames->getNameForValue(sks) << "." << op << "("; + os << commaSeparatedValues(nonSksOperands, [&](Value value) { + // ToDo: can be removed? + auto *prefix = value.getType().hasTrait() ? "&" : ""; + // First check if a DefiningOp exists + // if not: comes from function definition + mlir::Operation *op = value.getDefiningOp(); + if (op) { + prefix = isa(op) ? "" : prefix; + } else { + prefix = ""; + } - return prefix + variableNames->getNameForValue(value) + - (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); - }); - os << ");\n"; - return success(); + return prefix + variableNames->getNameForValue(value) + + (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); + }); + os << ");\n"; + return success(); } } @@ -274,11 +289,13 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { emitAssignPrefix(op.getResult()); os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) { // Check if block argument, if so, clone. - auto cloneStr = isa(value) ? ".clone()": ""; + auto cloneStr = isa(value) ? ".clone()" : ""; // Get the name of defining operation its dialect - auto tfhe_op = value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool"; + auto tfhe_op = + value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool"; auto prefix = tfhe_op ? "&" : ""; - return std::string(prefix) + variableNames->getNameForValue(value) + cloneStr; + return std::string(prefix) + variableNames->getNameForValue(value) + + cloneStr; }) << "];\n"; return success(); } @@ -325,74 +342,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) { {op.getLhs(), op.getRhs()}, "xnor"); } -LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { - os << "let " << variableNames->getNameForValue(op.getResult()) << "_ref = "; - std::string_view opName = "and_packed"; - - os << variableNames->getNameForValue(op.getServerKey()) << "." << opName - << "("; - os << commaSeparatedValues({op.getLhs(), op.getRhs()}, [&](Value value) { - auto *prefix = value.getType().hasTrait() ? "&" : ""; - // First check if a DefiningOp exists - // if not: comes from function definition - - mlir::Operation *opParent = value.getDefiningOp(); - if (opParent) { - prefix = isa(opParent) ? "" : prefix; - } else { - prefix = ""; - } - - prefix = opName.find("packed") ? "&" : prefix; - - return prefix + variableNames->getNameForValue(value); - }); - os << ");\n"; - - os << "let " << variableNames->getNameForValue(op.getResult()) - << ": Vec<&Ciphertext> = " - << variableNames->getNameForValue(op.getResult()) - << "_ref.iter().collect();\n"; - return success(); -} - -LogicalResult TfheRustBoolEmitter::printOperation(XorPackedOp op) { - // os << "let " << variableNames->getNameForValue(op.getLhs()) << " = " - // << variableNames->getNameForValue(op.getLhs()) << - // ".iter().collect();\n"; - // os << "let " << variableNames->getNameForValue(op.getRhs()) << " = " - // << variableNames->getNameForValue(op.getRhs()) << - // ".iter().collect();\n"; - os << "let " << variableNames->getNameForValue(op.getResult()) << "_ref = "; - std::string_view opName = "xor_packed"; - - os << variableNames->getNameForValue(op.getServerKey()) << "." << opName - << "("; - os << commaSeparatedValues({op.getLhs(), op.getRhs()}, [&](Value value) { - auto *prefix = value.getType().hasTrait() ? "&" : ""; - // First check if a DefiningOp exists - // if not: comes from function definition - - mlir::Operation *opParent = value.getDefiningOp(); - if (opParent) { - prefix = isa(opParent) ? "" : prefix; - } else { - prefix = ""; - } - - prefix = opName.find("packed") ? "&" : prefix; - - return prefix + variableNames->getNameForValue(value); - }); - os << ");\n"; - - os << "let " << variableNames->getNameForValue(op.getResult()) - << ": Vec<&Ciphertext> = " - << variableNames->getNameForValue(op.getResult()) - << "_ref.iter().collect();\n"; - return success(); -} - FailureOr TfheRustBoolEmitter::convertType(Type type) { // Note: these are probably not the right type names to use exactly, and they // will need to chance to the right values once we try to compile it against diff --git a/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir b/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir index e6dd85760..1a96d3b25 100644 --- a/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir +++ b/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir @@ -8,7 +8,7 @@ // CHECK-NEXT: [[input1:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: [[input2:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: ) -> Ciphertext { -// CHECK-NEXT: let [[v0:.*]] = [[bsks]].and(&[[input1]], &[[input2]]); +// CHECK-NEXT: let [[v0:.*]] = [[bsks]].and([[input1]], [[input2]]); // CHECK-NEXT: [[v0]] // CHECK-NEXT: } func.func @test_and(%bsks : !bsks, %input1 : !eb, %input2 : !eb) -> !eb { diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/BUILD b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD index a189be648..868d9ba3c 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/BUILD +++ b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD @@ -12,6 +12,7 @@ glob_lit_tests( data = [ "Cargo.toml", "src/main.rs", + "tfhe-rs", "@heir//tests:test_utilities", ], default_tags = [ diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md index 5d6021216..fde92ec51 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/README.md +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -23,6 +23,12 @@ bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@" ``` +Manually generate the Rust code fort the CGGI lowering: +```bash +bazel run //tools:heir-opt -- -cse --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse $(pwd)/tests/cggi_to_tfhe_rust_bool/add_bool.mlir | bazel run //tools:heir-translate -- --emit-tfhe-rust-bool +``` + + The `manual` tag is added to the targets in this directory to ensure that they are not run when someone runs a glob test like `bazel test //...`. diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs deleted file mode 100644 index 2cc7e1749..000000000 --- a/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs +++ /dev/null @@ -1,477 +0,0 @@ -use tfhe::boolean::prelude::*; - - -// pub fn fn_under_test( -// v0: &ServerKey, -// v1: &Vec, -// v2: &Vec, -// ) -> Vec { -// let v1 = v1.iter().collect(); -// let v2 = v2.iter().collect(); -// let v3 = v0.xor_packed(&v1, &v2); -// v3 -// } - -use tfhe::boolean::prelude::*; - - -pub fn fn_under_test_fpga( - v0: &ServerKey, - v1: &Vec, - v2: &Vec, -) -> Vec { - let v3 = 7; - let v4 = 6; - let v5 = 5; - let v6 = 4; - let v7 = 3; - let v8 = 2; - let v9 = 1; - let v10 = 0; - let v11 = &v1[v10]; - let v12 = &v1[v9]; - let v13 = &v1[v8]; - let v14 = &v1[v7]; - let v15 = &v1[v6]; - let v16 = &v1[v5]; - let v17 = &v1[v4]; - let v18 = &v1[v3]; - let v19 = &v2[v10]; - let v20 = &v2[v9]; - let v21 = &v2[v8]; - let v22 = &v2[v7]; - let v23 = &v2[v6]; - let v24 = &v2[v5]; - let v25 = &v2[v4]; - let v26 = &v2[v3]; - let v27 = v0.and(v11, v19); - let v28 = v0.xor(v12, v20); - let v29 = vec![&v28, v12]; - let v30 = vec![&v27, v20]; - let v31_ref = v0.and_packed(&v29, &v30); - let v31: Vec<&Ciphertext> = v31_ref.iter().collect(); - let v32 = 0; - let v33 = v31[v32]; - let v34 = 1; - let v35 = v31[v34]; - let v36 = vec![v13, v35]; - let v37 = vec![v21, v33]; - let v38_ref = v0.and_packed(&v36, &v37); - let v38: Vec<&Ciphertext> = v38_ref.iter().collect(); - let v39 = 0; - let v40 = v38[v39]; - let v41 = 1; - let v42 = v38[v41]; - let v43 = vec![v13, v40]; - let v44 = vec![v21, v42]; - let v45_ref = v0.and_packed(&v43, &v44); - let v45: Vec<&Ciphertext> = v45_ref.iter().collect(); - let v46 = 0; - let v47 = v45[v46]; - let v48 = 1; - let v49 = v45[v48]; - let v50 = vec![v47, v14]; - let v51 = vec![v49, v22]; - let v52_ref = v0.and_packed(&v50, &v51); - let v52: Vec<&Ciphertext> = v52_ref.iter().collect(); - let v53 = 0; - let v54 = v52[v53]; - let v55 = 1; - let v56 = v52[v55]; - let v57 = vec![v56, v14]; - let v58 = vec![v54, v22]; - let v59_ref = v0.and_packed(&v57, &v58); - let v59: Vec<&Ciphertext> = v59_ref.iter().collect(); - let v60 = 0; - let v61 = v59[v60]; - let v62 = 1; - let v63 = v59[v62]; - let v64 = vec![v15, v63]; - let v65 = vec![v23, v61]; - let v66_ref = v0.and_packed(&v64, &v65); - let v66: Vec<&Ciphertext> = v66_ref.iter().collect(); - let v67 = 0; - let v68 = v66[v67]; - let v69 = 1; - let v70 = v66[v69]; - let v71 = vec![v15, v68]; - let v72 = vec![v23, v70]; - let v73_ref = v0.and_packed(&v71, &v72); - let v73: Vec<&Ciphertext> = v73_ref.iter().collect(); - let v74 = 0; - let v75 = v73[v74]; - let v76 = 1; - let v77 = v73[v76]; - let v78 = vec![v16, v75]; - let v79 = vec![v24, v77]; - let v80_ref = v0.and_packed(&v78, &v79); - let v80: Vec<&Ciphertext> = v80_ref.iter().collect(); - let v81 = 0; - let v82 = v80[v81]; - let v83 = 1; - let v84 = v80[v83]; - let v85 = vec![v16, v82]; - let v86 = vec![v24, v84]; - let v87_ref = v0.and_packed(&v85, &v86); - let v87: Vec<&Ciphertext> = v87_ref.iter().collect(); - let v88 = 0; - let v89 = v87[v88]; - let v90 = 1; - let v91 = v87[v90]; - let v92 = 0; - let v93 = 1; - let v94 = 2; - let v95 = 3; - let v96 = 4; - let v97 = 5; - let v98 = 6; - let v99 = 7; - let v100 = vec![v89, v17]; - let v101 = vec![v91, v25]; - let v102_ref = v0.and_packed(&v100, &v101); - let v102: Vec<&Ciphertext> = v102_ref.iter().collect(); - let v103 = 0; - let v104 = v102[v103]; - let v105 = 1; - let v106 = v102[v105]; - let v107 = vec![v106, v17]; - let v108 = vec![v104, v25]; - let v109_ref = v0.and_packed(&v107, &v108); - let v109: Vec<&Ciphertext> = v109_ref.iter().collect(); - let v110 = 0; - let v111 = v109[v110]; - let v112 = 1; - let v113 = v109[v112]; - let v114 = vec![v18, v113]; - let v115 = vec![v26, v111]; - let v116_ref = v0.and_packed(&v114, &v115); - let v116: Vec<&Ciphertext> = v116_ref.iter().collect(); - let v117 = 0; - let v118 = v116[v117]; - let v119 = 1; - let v120 = v116[v119]; - let v121 = vec![v118, v18]; - let v122 = vec![v120, v26]; - let v123_ref = v0.and_packed(&v121, &v122); - let v123: Vec<&Ciphertext> = v123_ref.iter().collect(); - let v124 = 0; - let v125 = v123[v124]; - let v126 = 1; - let v127 = v123[v126]; - let v128 = vec![&v28, v118, v40, v68, v106, v56, v11, v82]; - let v129 = vec![&v27, v120, v42, v70, v104, v54, v19, v84]; - let v130_ref = v0.and_packed(&v128, &v129); - let v130: Vec<&Ciphertext> = v130_ref.iter().collect(); - let v131 = v130_ref[v92].clone(); - let v132 = v130_ref[v93].clone(); - let v133 = v130_ref[v94].clone(); - let v134 = v130_ref[v95].clone(); - let v135 = v130_ref[v96].clone(); - let v136 = v130_ref[v97].clone(); - let v137 = v130_ref[v98].clone(); - let v138 = v130_ref[v99].clone(); - let v139 = vec![v132, v135, v138, v134, v136, v133, v131, v137]; - v139 -} - -// use tfhe::boolean::prelude::*; - -pub fn fn_under_test( - v0: &ServerKey, - v1: &Vec, - v2: &Vec, -) -> Vec { - let v3 = 7; - let v4 = 6; - let v5 = 5; - let v6 = 4; - let v7 = 3; - let v8 = 2; - let v9 = 1; - let v10 = 0; - let v11 = vec![&v1[0]]; - let v12 = vec![&v1[1]]; - let v13 = vec![&v1[2]]; - let v14 = vec![&v1[3]]; - let v15 = vec![&v1[4]]; - let v16 = vec![&v1[5]]; - let v17 = vec![&v1[6]]; - let v18 = vec![&v1[7]]; - let v19 = vec![&v2[0]]; - let v20 = vec![&v2[1]]; - let v21 = vec![&v2[2]]; - let v22 = vec![&v2[3]]; - let v23 = vec![&v2[4]]; - let v24 = vec![&v2[5]]; - let v25 = vec![&v2[6]]; - let v26 = vec![&v2[7]]; - let v27_ref = v0.xor_packed(&v11, &v19); - let v27: Vec<&Ciphertext> = v27_ref.iter().collect(); - let v28_ref = v0.and_packed(&v11, &v19); - let v28: Vec<&Ciphertext> = v28_ref.iter().collect(); - let v29_ref = v0.xor_packed(&v12, &v20); - let v29: Vec<&Ciphertext> = v29_ref.iter().collect(); - let v30_ref = v0.and_packed(&v12, &v20); - let v30: Vec<&Ciphertext> = v30_ref.iter().collect(); - let v31_ref = v0.and_packed(&v29, &v28); - let v31: Vec<&Ciphertext> = v31_ref.iter().collect(); - let v32_ref = v0.xor_packed(&v29, &v28); - let v32: Vec<&Ciphertext> = v32_ref.iter().collect(); - let v33_ref = v0.xor_packed(&v30, &v31); - let v33: Vec<&Ciphertext> = v33_ref.iter().collect(); - let v34_ref = v0.xor_packed(&v13, &v21); - let v34: Vec<&Ciphertext> = v34_ref.iter().collect(); - let v35_ref = v0.and_packed(&v13, &v21); - let v35: Vec<&Ciphertext> = v35_ref.iter().collect(); - let v36_ref = v0.and_packed(&v34, &v33); - let v36: Vec<&Ciphertext> = v36_ref.iter().collect(); - let v37_ref = v0.xor_packed(&v34, &v33); - let v37: Vec<&Ciphertext> = v37_ref.iter().collect(); - let v38_ref = v0.xor_packed(&v35, &v36); - let v38: Vec<&Ciphertext> = v38_ref.iter().collect(); - let v39_ref = v0.xor_packed(&v14, &v22); - let v39: Vec<&Ciphertext> = v39_ref.iter().collect(); - let v40_ref = v0.and_packed(&v14, &v22); - let v40: Vec<&Ciphertext> = v40_ref.iter().collect(); - let v41_ref = v0.and_packed(&v39, &v38); - let v41: Vec<&Ciphertext> = v41_ref.iter().collect(); - let v42_ref = v0.xor_packed(&v39, &v38); - let v42: Vec<&Ciphertext> = v42_ref.iter().collect(); - let v43_ref = v0.xor_packed(&v40, &v41); - let v43: Vec<&Ciphertext> = v43_ref.iter().collect(); - let v44_ref = v0.xor_packed(&v15, &v23); - let v44: Vec<&Ciphertext> = v44_ref.iter().collect(); - let v45_ref = v0.and_packed(&v15, &v23); - let v45: Vec<&Ciphertext> = v45_ref.iter().collect(); - let v46_ref = v0.and_packed(&v44, &v43); - let v46: Vec<&Ciphertext> = v46_ref.iter().collect(); - let v47_ref = v0.xor_packed(&v44, &v43); - let v47: Vec<&Ciphertext> = v47_ref.iter().collect(); - let v48_ref = v0.xor_packed(&v45, &v46); - let v48: Vec<&Ciphertext> = v48_ref.iter().collect(); - let v49_ref = v0.xor_packed(&v16, &v24); - let v49: Vec<&Ciphertext> = v49_ref.iter().collect(); - let v50_ref = v0.and_packed(&v16, &v24); - let v50: Vec<&Ciphertext> = v50_ref.iter().collect(); - let v51_ref = v0.and_packed(&v49, &v48); - let v51: Vec<&Ciphertext> = v51_ref.iter().collect(); - let v52_ref = v0.xor_packed(&v49, &v48); - let v52: Vec<&Ciphertext> = v52_ref.iter().collect(); - let v53_ref = v0.xor_packed(&v50, &v51); - let v53: Vec<&Ciphertext> = v53_ref.iter().collect(); - let v54_ref = v0.xor_packed(&v17, &v25); - let v54: Vec<&Ciphertext> = v54_ref.iter().collect(); - let v55_ref = v0.and_packed(&v17, &v25); - let v55: Vec<&Ciphertext> = v55_ref.iter().collect(); - let v56_ref = v0.and_packed(&v54, &v53); - let v56: Vec<&Ciphertext> = v56_ref.iter().collect(); - let v57_ref = v0.xor_packed(&v54, &v53); - let v57: Vec<&Ciphertext> = v57_ref.iter().collect(); - let v58_ref = v0.xor_packed(&v55, &v56); - let v58: Vec<&Ciphertext> = v58_ref.iter().collect(); - let v59_ref = v0.xor_packed(&v18, &v26); - let v59: Vec<&Ciphertext> = v59_ref.iter().collect(); - let v60_ref = v0.xor_packed(&v59, &v58); - let v60: Vec<&Ciphertext> = v60_ref.iter().collect(); - let mut v61: Vec = vec![]; - v61.extend(v60_ref); - v61.extend(v57_ref); - v61.extend(v52_ref); - v61.extend(v47_ref); - v61.extend(v42_ref); - v61.extend(v37_ref); - v61.extend(v32_ref); - v61.extend(v27_ref); - v61 -} - - -pub fn add_bool( - v0: &ServerKey, - v1: &Vec, - v2: &Vec, -) -> Vec { - let v3 = 7; - let v4 = 6; - let v5 = 5; - let v6 = 4; - let v7 = 3; - let v8 = 2; - let v9 = 1; - let v10 = 0; - let v11 = &v1[v10]; - let v12 = &v1[v9]; - let v13 = &v1[v8]; - let v14 = &v1[v7]; - let v15 = &v1[v6]; - let v16 = &v1[v5]; - let v17 = &v1[v4]; - let v18 = &v1[v3]; - let v19 = &v2[v10]; - let v20 = &v2[v9]; - let v21 = &v2[v8]; - let v22 = &v2[v7]; - let v23 = &v2[v6]; - let v24 = &v2[v5]; - let v25 = &v2[v4]; - let v26 = &v2[v3]; - let v27 = v0.and(v11, v19); - let v28 = v0.xor(v12, v20); - let v29 = vec![v12, &v28]; - let v30 = vec![v20, &v27]; - let v31_ref = v0.and_packed(&v29, &v30); - let v31: Vec<&Ciphertext> = v31_ref.iter().collect(); - let v32 = &v31[v10]; - let v33 = &v31[v9]; - let v34 = vec![v32, v13]; - let v35 = vec![v33, v21]; - let v36_ref = v0.and_packed(&v34, &v35); - let v36: Vec<&Ciphertext> = v36_ref.iter().collect(); - let v37 = &v36[v10]; - let v38 = &v36[v9]; - let v39 = vec![v13, v38]; - let v40 = vec![v21, v37]; - let v41_ref = v0.and_packed(&v39, &v40); - let v41: Vec<&Ciphertext> = v41_ref.iter().collect(); - let v42 = &v41[v10]; - let v43 = &v41[v9]; - let v44 = vec![v42, v14]; - let v45 = vec![v43, v22]; - let v46_ref = v0.and_packed(&v44, &v45); - let v46: Vec<&Ciphertext> = v46_ref.iter().collect(); - let v47 = &v46[v10]; - let v48 = &v46[v9]; - let v49 = vec![v48, v14]; - let v50 = vec![v47, v22]; - let v51_ref = v0.and_packed(&v49, &v50); - let v51: Vec<&Ciphertext> = v51_ref.iter().collect(); - let v52 = &v51[v10]; - let v53 = &v51[v9]; - let v54 = vec![v53, v15]; - let v55 = vec![v52, v23]; - let v56_ref = v0.and_packed(&v54, &v55); - let v56: Vec<&Ciphertext> = v56_ref.iter().collect(); - let v57 = &v56[v10]; - let v58 = &v56[v9]; - let v59 = vec![v58, v15]; - let v60 = vec![v57, v23]; - let v61_ref = v0.and_packed(&v59, &v60); - let v61: Vec<&Ciphertext> = v61_ref.iter().collect(); - let v62 = &v61[v10]; - let v63 = &v61[v9]; - let v64 = vec![v63, v16]; - let v65 = vec![v62, v24]; - let v66_ref = v0.and_packed(&v64, &v65); - let v66: Vec<&Ciphertext> = v66_ref.iter().collect(); - let v67 = &v66[v10]; - let v68 = &v66[v9]; - let v69 = vec![v16, v68]; - let v70 = vec![v24, v67]; - let v71_ref = v0.and_packed(&v69, &v70); - let v71: Vec<&Ciphertext> = v71_ref.iter().collect(); - let v72 = &v71[v10]; - let v73 = &v71[v9]; - let v74 = vec![v17, v72]; - let v75 = vec![v25, v73]; - let v76_ref = v0.and_packed(&v74, &v75); - let v76: Vec<&Ciphertext> = v76_ref.iter().collect(); - let v77 = &v76[v10]; - let v78 = &v76[v9]; - let v79 = vec![v77, v17]; - let v80 = vec![v78, v25]; - let v81_ref = v0.and_packed(&v79, &v80); - let v81: Vec<&Ciphertext> = v81_ref.iter().collect(); - let v82 = &v81[v10]; - let v83 = &v81[v9]; - let v84 = vec![v18, v83]; - let v85 = vec![v26, v82]; - let v86_ref = v0.and_packed(&v84, &v85); - let v86: Vec<&Ciphertext> = v86_ref.iter().collect(); - let v87 = &v86[v10]; - let v88 = &v86[v9]; - let v89 = vec![v18, v87]; - let v90 = vec![v26, v88]; - let v91_ref = v0.and_packed(&v89, &v90); - let v91: Vec<&Ciphertext> = v91_ref.iter().collect(); - let v92 = vec![v68, v58, v48, v38, v28, v87, v77, v11]; - let v93 = vec![v67, v57, v47, v37, v27, v88, v78, v19]; - let v94_ref = v0.and_packed(&v92, &v93); - let v94: Vec<&Ciphertext> = v94_ref.iter().collect(); - let v95 = &v94[v10]; - let v96 = &v94[v9]; - let v97 = &v94[v8]; - let v98 = &v94[v7]; - let v99 = &v94[v6]; - let v100 = &v94[v5]; - let v101 = &v94[v4]; - let v102 = &v94[v3]; - let v103 = vec![v100, v101, v95, v96, v97, v98, v99, v102]; - v103 -} - -pub fn fn_under_testtt( - v0: &ServerKey, - v1: &Vec, - v2: &Vec, -) -> Vec { - let v3 = 7; - let v4 = 6; - let v5 = 5; - let v6 = 4; - let v7 = 3; - let v8 = 2; - let v9 = 1; - let v10 = 0; - let v11 = v1[v10]; - let v12 = v1[v9]; - let v13 = v1[v8]; - let v14 = v1[v7]; - let v15 = v1[v6]; - let v16 = v1[v5]; - let v17 = v1[v4]; - let v18 = v1[v3]; - let v19 = v2[v10]; - let v20 = v2[v9]; - let v21 = v2[v8]; - let v22 = v2[v7]; - let v23 = v2[v6]; - let v24 = v2[v5]; - let v25 = v2[v4]; - let v26 = v2[v3]; - let v27 = v0.xor(&v11, &v19); - let v28 = v0.and(&v11, &v19); - let v29 = v0.xor(&v12, &v20); - let v30 = v0.and(&v12, &v20); - let v31 = v0.and(&v29, &v28); - let v32 = v0.xor(&v29, &v28); - let v33 = v0.xor(&v30, &v31); - let v34 = v0.xor(&v13, &v21); - let v35 = v0.and(&v13, &v21); - let v36 = v0.and(&v34, &v33); - let v37 = v0.xor(&v34, &v33); - let v38 = v0.xor(&v35, &v36); - let v39 = v0.xor(&v14, &v22); - let v40 = v0.and(&v14, &v22); - let v41 = v0.and(&v39, &v38); - let v42 = v0.xor(&v39, &v38); - let v43 = v0.xor(&v40, &v41); - let v44 = v0.xor(&v15, &v23); - let v45 = v0.and(&v15, &v23); - let v46 = v0.and(&v44, &v43); - let v47 = v0.xor(&v44, &v43); - let v48 = v0.xor(&v45, &v46); - let v49 = v0.xor(&v16, &v24); - let v50 = v0.and(&v16, &v24); - let v51 = v0.and(&v49, &v48); - let v52 = v0.xor(&v49, &v48); - let v53 = v0.xor(&v50, &v51); - let v54 = v0.xor(&v17, &v25); - let v55 = v0.and(&v17, &v25); - let v56 = v0.and(&v54, &v53); - let v57 = v0.xor(&v54, &v53); - let v58 = v0.xor(&v55, &v56); - let v59 = v0.xor(&v18, &v26); - let v60 = v0.xor(&v59, &v58); - let v61 = vec![v60, v57, v52, v47, v42, v37, v32, v27]; - v61 -} \ No newline at end of file diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs index a62295f40..06b43cfa5 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs @@ -76,10 +76,10 @@ fn main() { // let ct_2= ct_2.into_iter().collect(); let t = Instant::now(); - let result = fn_under_test::fn_under_testtt(&server_key, &ct_1, &ct_2); + let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); let run = t.elapsed().as_millis(); - println!("{:?}", run); + // println!("{:?}", run); let output = decrypt(&result, &client_key); diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir deleted file mode 100644 index 0e4f71bd3..000000000 --- a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir +++ /dev/null @@ -1,74 +0,0 @@ -// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs -// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s - -!bsks = !tfhe_rust_bool.server_key -!eb = !tfhe_rust_bool.eb - -// CHECK-LABEL: pub fn fn_under_test( -// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, -// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, -// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, -// CHECK-NEXT: ) -> Vec { -func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { - %c7 = arith.constant 7 : index - %c6 = arith.constant 6 : index - %c5 = arith.constant 5 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %extracted_00 = tensor.extract_slice %arg0 [0][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_01 = tensor.extract_slice %arg0 [1][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_02 = tensor.extract_slice %arg0 [2][1][1]: tensor<8x!eb> to tensor<1x!eb> - %extracted_03 = tensor.extract_slice %arg0 [3][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_04 = tensor.extract_slice %arg0 [4][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_05 = tensor.extract_slice %arg0 [5][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_06 = tensor.extract_slice %arg0 [6][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_07 = tensor.extract_slice %arg0 [7][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_10 = tensor.extract_slice %arg1 [0][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_11 = tensor.extract_slice %arg1 [1][1][1]: tensor<8x!eb> to tensor<1x!eb> - %extracted_12 = tensor.extract_slice %arg1 [2][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_13 = tensor.extract_slice %arg1 [3][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_14 = tensor.extract_slice %arg1 [4][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_15 = tensor.extract_slice %arg1 [5][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_16 = tensor.extract_slice %arg1 [6][1][1] : tensor<8x!eb> to tensor<1x!eb> - %extracted_17 = tensor.extract_slice %arg1 [7][1][1]: tensor<8x!eb> to tensor<1x!eb> - %ha_s = tfhe_rust_bool.xor_packed %bsks, %extracted_00, %extracted_10 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %ha_c = tfhe_rust_bool.and_packed %bsks, %extracted_00, %extracted_10: (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa0_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_01, %extracted_11 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa0_2 = tfhe_rust_bool.and_packed %bsks, %extracted_01, %extracted_11 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa0_3 = tfhe_rust_bool.and_packed %bsks, %fa0_1, %ha_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa0_s = tfhe_rust_bool.xor_packed %bsks, %fa0_1, %ha_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa0_c = tfhe_rust_bool.xor_packed %bsks, %fa0_2, %fa0_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa1_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_02, %extracted_12 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa1_2 = tfhe_rust_bool.and_packed %bsks, %extracted_02, %extracted_12 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa1_3 = tfhe_rust_bool.and_packed %bsks, %fa1_1, %fa0_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa1_s = tfhe_rust_bool.xor_packed %bsks, %fa1_1, %fa0_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa1_c = tfhe_rust_bool.xor_packed %bsks, %fa1_2, %fa1_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa2_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_03, %extracted_13 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa2_2 = tfhe_rust_bool.and_packed %bsks, %extracted_03, %extracted_13 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa2_3 = tfhe_rust_bool.and_packed %bsks, %fa2_1, %fa1_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa2_s = tfhe_rust_bool.xor_packed %bsks, %fa2_1, %fa1_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa2_c = tfhe_rust_bool.xor_packed %bsks, %fa2_2, %fa2_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa3_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_04, %extracted_14 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa3_2 = tfhe_rust_bool.and_packed %bsks, %extracted_04, %extracted_14 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa3_3 = tfhe_rust_bool.and_packed %bsks, %fa3_1, %fa2_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa3_s = tfhe_rust_bool.xor_packed %bsks, %fa3_1, %fa2_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa3_c = tfhe_rust_bool.xor_packed %bsks, %fa3_2, %fa3_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa4_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_05, %extracted_15 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa4_2 = tfhe_rust_bool.and_packed %bsks, %extracted_05, %extracted_15 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa4_3 = tfhe_rust_bool.and_packed %bsks, %fa4_1, %fa3_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa4_s = tfhe_rust_bool.xor_packed %bsks, %fa4_1, %fa3_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa4_c = tfhe_rust_bool.xor_packed %bsks, %fa4_2, %fa4_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa5_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_06, %extracted_16 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa5_2 = tfhe_rust_bool.and_packed %bsks, %extracted_06, %extracted_16 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa5_3 = tfhe_rust_bool.and_packed %bsks, %fa5_1, %fa4_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa5_s = tfhe_rust_bool.xor_packed %bsks, %fa5_1, %fa4_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa5_c = tfhe_rust_bool.xor_packed %bsks, %fa5_2, %fa5_3 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa6_1 = tfhe_rust_bool.xor_packed %bsks, %extracted_07, %extracted_17 : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %fa6_s = tfhe_rust_bool.xor_packed %bsks, %fa6_1, %fa5_c : (!bsks, tensor<1x!eb>, tensor<1x!eb>) -> tensor<1x!eb> - %from_elements = tensor.concat dim(0) %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s - : (tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>, tensor<1x!eb>) -> tensor<8x!eb> - return %from_elements : tensor<8x!eb> -} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir new file mode 100644 index 000000000..6d612a0a1 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir @@ -0,0 +1,77 @@ +// This test ensures the testing harness is working properly with minimal codegen. + +// RUN: heir-opt --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | heir-translate --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s + +#encoding = #lwe.unspecified_bit_field_encoding +!ct_ty = !lwe.lwe_ciphertext +!pt_ty = !lwe.lwe_plaintext + +// CHECK: 01000000 +func.func @fn_under_test(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { + %true = arith.constant true + %false = arith.constant false + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!ct_ty> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!ct_ty> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!ct_ty> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!ct_ty> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!ct_ty> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!ct_ty> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!ct_ty> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!ct_ty> + %ha_s = cggi.xor %extracted_00, %extracted_10 : !ct_ty + %ha_c = cggi.and %extracted_00, %extracted_10 : !ct_ty + %fa0_1 = cggi.xor %extracted_01, %extracted_11 : !ct_ty + %fa0_2 = cggi.and %extracted_01, %extracted_11 : !ct_ty + %fa0_3 = cggi.and %fa0_1, %ha_c : !ct_ty + %fa0_s = cggi.xor %fa0_1, %ha_c : !ct_ty + %fa0_c = cggi.xor %fa0_2, %fa0_3 : !ct_ty + %fa1_1 = cggi.xor %extracted_02, %extracted_12 : !ct_ty + %fa1_2 = cggi.and %extracted_02, %extracted_12 : !ct_ty + %fa1_3 = cggi.and %fa1_1, %fa0_c : !ct_ty + %fa1_s = cggi.xor %fa1_1, %fa0_c : !ct_ty + %fa1_c = cggi.xor %fa1_2, %fa1_3 : !ct_ty + %fa2_1 = cggi.xor %extracted_03, %extracted_13 : !ct_ty + %fa2_2 = cggi.and %extracted_03, %extracted_13 : !ct_ty + %fa2_3 = cggi.and %fa2_1, %fa1_c : !ct_ty + %fa2_s = cggi.xor %fa2_1, %fa1_c : !ct_ty + %fa2_c = cggi.xor %fa2_2, %fa2_3 : !ct_ty + %fa3_1 = cggi.xor %extracted_04, %extracted_14 : !ct_ty + %fa3_2 = cggi.and %extracted_04, %extracted_14 : !ct_ty + %fa3_3 = cggi.and %fa3_1, %fa2_c : !ct_ty + %fa3_s = cggi.xor %fa3_1, %fa2_c : !ct_ty + %fa3_c = cggi.xor %fa3_2, %fa3_3 : !ct_ty + %fa4_1 = cggi.xor %extracted_05, %extracted_15 : !ct_ty + %fa4_2 = cggi.and %extracted_05, %extracted_15 : !ct_ty + %fa4_3 = cggi.and %fa4_1, %fa3_c : !ct_ty + %fa4_s = cggi.xor %fa4_1, %fa3_c : !ct_ty + %fa4_c = cggi.xor %fa4_2, %fa4_3 : !ct_ty + %fa5_1 = cggi.xor %extracted_06, %extracted_16 : !ct_ty + %fa5_2 = cggi.and %extracted_06, %extracted_16 : !ct_ty + %fa5_3 = cggi.and %fa5_1, %fa4_c : !ct_ty + %fa5_s = cggi.xor %fa5_1, %fa4_c : !ct_ty + %fa5_c = cggi.xor %fa5_2, %fa5_3 : !ct_ty + %fa6_1 = cggi.xor %extracted_07, %extracted_17 : !ct_ty + %fa6_2 = cggi.and %extracted_07, %extracted_17 : !ct_ty + %fa6_3 = cggi.and %fa6_1, %fa5_c : !ct_ty + %fa6_s = cggi.xor %fa6_1, %fa5_c : !ct_ty + %fa6_c = cggi.xor %fa6_2, %fa6_3 : !ct_ty + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!ct_ty> + return %from_elements : tensor<8x!ct_ty> +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir index 9a2e0e5b5..d22cadf5d 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir @@ -8,6 +8,6 @@ // CHECK: 1 func.func @fn_under_test(%bsks : !bsks, %a: tensor<8x!eb>, %b: tensor<8x!eb>) -> tensor<8x!eb> { - %res = tfhe_rust_bool.and_packed %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb> + %res = tfhe_rust_bool.and %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb> return %res : tensor<8x!eb> } diff --git a/tests/tfhe_rust_bool/ops.mlir b/tests/tfhe_rust_bool/ops.mlir index 1eb7f7353..b5ca17d68 100644 --- a/tests/tfhe_rust_bool/ops.mlir +++ b/tests/tfhe_rust_bool/ops.mlir @@ -29,7 +29,7 @@ module { // CHECK-LABEL: func @test_packed_and func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) { - %out = tfhe_rust_bool.and_packed %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> + %out = tfhe_rust_bool.and %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> return } } From 0e5f6db9e6200c3c525138514b6a63dcf4d8be8a Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Sun, 24 Mar 2024 23:03:41 +0100 Subject: [PATCH 10/11] cleanups --- .../TargetSlotAnalysis/TargetSlotAnalysis.h | 6 +-- .../Target/TfheRustBool/TfheRustBoolEmitter.h | 3 +- .../TfheRustBool/TfheRustBoolEmitter.cpp | 48 +++++-------------- .../test_vectorize.mlir | 47 ------------------ 4 files changed, 16 insertions(+), 88 deletions(-) delete mode 100644 tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir diff --git a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h index 984fb338c..cc7487a98 100644 --- a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h +++ b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h @@ -122,9 +122,9 @@ class TargetSlotAnalysis void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override; - void visitBranchOperand(OpOperand &operand) override{}; - void visitCallOperand(OpOperand &operand) override{}; - void setToExitState(TargetSlotLattice *lattice) override{}; + void visitBranchOperand(OpOperand &operand) override {}; + void visitCallOperand(OpOperand &operand) override {}; + void setToExitState(TargetSlotLattice *lattice) override {}; }; } // namespace target_slot_analysis diff --git a/include/Target/TfheRustBool/TfheRustBoolEmitter.h b/include/Target/TfheRustBool/TfheRustBoolEmitter.h index 512dba039..569b09448 100644 --- a/include/Target/TfheRustBool/TfheRustBoolEmitter.h +++ b/include/Target/TfheRustBool/TfheRustBoolEmitter.h @@ -51,9 +51,7 @@ class TfheRustBoolEmitter { LogicalResult printOperation(::mlir::func::ReturnOp op); LogicalResult printOperation(CreateTrivialOp op); LogicalResult printOperation(tensor::ExtractOp op); - LogicalResult printOperation(tensor::ExtractSliceOp op); LogicalResult printOperation(tensor::FromElementsOp op); - LogicalResult printOperation(tensor::ConcatOp op); LogicalResult printOperation(AndOp op); LogicalResult printOperation(NandOp op); LogicalResult printOperation(OrOp op); @@ -72,6 +70,7 @@ class TfheRustBoolEmitter { FailureOr convertType(Type type); void emitAssignPrefix(::mlir::Value result); + void emitReferenceConversion(::mlir::Value value); }; } // namespace tfhe_rust_bool diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 3f0d3dbc1..8c0c4b956 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -67,8 +67,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) { .Case( [&](auto op) { return printOperation(op); }) // Tensor ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) .Default([&](Operation &) { @@ -169,26 +168,22 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) { os << "let " << variableNames->getNameForValue(result) << " = "; } +void TfheRustBoolEmitter::emitReferenceConversion(Value value) { + auto varName = variableNames->getNameForValue(value); + os << "let " << varName << "_ref = " << varName << ".clone();\n"; + os << "let " << varName << "_ref: Vec<&Ciphertext> = " << varName + << ".iter().collect();\n"; +} + LogicalResult TfheRustBoolEmitter::printSksMethod( ::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands, std::string_view op, SmallVector operandTypes) { if (isa(nonSksOperands[0].getType())) { - mlir::Operation *opParent = nonSksOperands[0].getDefiningOp(); + auto *opParent = nonSksOperands[0].getDefiningOp(); if (!opParent) { - os << "let " << variableNames->getNameForValue(nonSksOperands[0]) - << "_ref = " << variableNames->getNameForValue(nonSksOperands[0]) - << ".clone();\n"; - os << "let " << variableNames->getNameForValue(nonSksOperands[0]) - << "_ref: Vec<&Ciphertext> = " - << variableNames->getNameForValue(nonSksOperands[0]) - << ".iter().collect();\n"; - os << "let " << variableNames->getNameForValue(nonSksOperands[1]) - << "_ref = " << variableNames->getNameForValue(nonSksOperands[1]) - << ".clone();\n"; - os << "let " << variableNames->getNameForValue(nonSksOperands[1]) - << "_ref: Vec<&Ciphertext> = " - << variableNames->getNameForValue(nonSksOperands[1]) - << ".iter().collect();\n"; + for (auto nonSksOperand : nonSksOperands) { + emitReferenceConversion(nonSksOperand); + } } emitAssignPrefix(result); @@ -277,13 +272,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) { return success(); } -LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractSliceOp op) { - emitAssignPrefix(op.getResult()); - os << "vec![&" << variableNames->getNameForValue(op.getSource()) << "[" - << op.getStaticOffsets()[0] << "]];\n"; - return success(); -} - // Need to produce a Vec<&Ciphertext> LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { emitAssignPrefix(op.getResult()); @@ -300,18 +288,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { return success(); } -LogicalResult TfheRustBoolEmitter::printOperation(tensor::ConcatOp op) { - auto varName = variableNames->getNameForValue(op.getResult()); - os << "let mut " << varName << ": Vec = vec![];\n"; - ValueRange values = op.getOperands(); - for (Value a : values) { - os << varName << ".extend(" << variableNames->getNameForValue(a) - << "_ref);\n"; - } - - return success(); -} - LogicalResult TfheRustBoolEmitter::printOperation(AndOp op) { return printSksMethod(op.getResult(), op.getServerKey(), {op.getLhs(), op.getRhs()}, "and"); diff --git a/tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir b/tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir deleted file mode 100644 index e4f3f7d61..000000000 --- a/tests/cggi_to_tfhe_rust_bool/test_vectorize.mlir +++ /dev/null @@ -1,47 +0,0 @@ -// RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - -#encoding = #lwe.unspecified_bit_field_encoding -!ct_ty = !lwe.lwe_ciphertext -!pt_ty = !lwe.lwe_plaintext - -// CHECK-LABEL: add_bool -// CHECK-NOT: cggi -// CHECK-NOT: lwe -func.func @add_bool(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { - %true = arith.constant true - %false = arith.constant false - %c7 = arith.constant 7 : index - %c6 = arith.constant 6 : index - %c5 = arith.constant 5 : index - %c4 = arith.constant 4 : index - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty> - %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty> - %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty> - %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty> - %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty> - %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty> - %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty> - %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty> - %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!ct_ty> - %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!ct_ty> - %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!ct_ty> - %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!ct_ty> - %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!ct_ty> - %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!ct_ty> - %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!ct_ty> - %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!ct_ty> - %0 = cggi.xor %extracted_00, %extracted_10 : !ct_ty - %1 = cggi.and %extracted_02, %extracted_12 : !ct_ty - %2 = cggi.xor %extracted_01, %extracted_11 : !ct_ty - %3 = cggi.and %extracted_03, %extracted_13 : !ct_ty - %4 = cggi.xor %extracted_05, %extracted_15 : !ct_ty - %5 = cggi.and %extracted_04, %extracted_14 : !ct_ty - %6 = cggi.xor %extracted_06, %extracted_16 : !ct_ty - %7 = cggi.and %extracted_07, %extracted_17 : !ct_ty - %from_elements = tensor.from_elements %0, %2, %4, %6, %1, %3, %5, %7 : tensor<8x!ct_ty> - return %from_elements : tensor<8x!ct_ty> -} From 1603295adb751c2452865ed81454516808ea8c78 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Tue, 26 Mar 2024 20:16:41 +0100 Subject: [PATCH 11/11] tfhe-rs-bool emitter from cggi to fpga --- lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp | 4 +--- tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs | 3 --- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 8c0c4b956..cdd709283 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -216,7 +216,6 @@ LogicalResult TfheRustBoolEmitter::printSksMethod( auto operandTypesIt = operandTypes.begin(); os << variableNames->getNameForValue(sks) << "." << op << "("; os << commaSeparatedValues(nonSksOperands, [&](Value value) { - // ToDo: can be removed? auto *prefix = value.getType().hasTrait() ? "&" : ""; // First check if a DefiningOp exists // if not: comes from function definition @@ -327,8 +326,7 @@ FailureOr TfheRustBoolEmitter::convertType(Type type) { // FIXME: why can't both types be FailureOr? auto elementTy = convertType(shapedType.getElementType()); if (failed(elementTy)) return failure(); - // auto refprefix = - // shapedType.getElementType().hasTrait() ? "&" : ""; + return std::string(std::string("Vec<") + elementTy.value() + ">"); } return llvm::TypeSwitch>(type) diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs index 06b43cfa5..3a86c6471 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs @@ -72,9 +72,6 @@ fn main() { let ct_1 = encrypt(flags.input1.into(), &client_key); let ct_2 = encrypt(flags.input2.into(), &client_key); - // let ct_1= ct_1.into_iter().collect(); - // let ct_2= ct_2.into_iter().collect(); - let t = Instant::now(); let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); let run = t.elapsed().as_millis();