Skip to content

Commit

Permalink
Updating the tests for e2e tfhe-rs-bool and starting with the tfhe-rs…
Browse files Browse the repository at this point in the history
…-bool fpga tests
  • Loading branch information
Wouter Legiest authored and Wouter Legiest committed Mar 24, 2024
1 parent cb54c78 commit e9703d9
Show file tree
Hide file tree
Showing 14 changed files with 394 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Expand Up @@ -17,7 +17,7 @@ venv
# for rust codegen tests
**/Cargo.lock
tests/**/**/target/
tests/tfhe_rust_bool/end_to_end_fpga/
tests/tfhe_rust_bool/end_to_end_fpga/tfhe-rs

# vscode
.vscode/**
11 changes: 11 additions & 0 deletions include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td
Expand Up @@ -56,6 +56,17 @@ def AndPackedOp : TfheRustBool_Op<"and_packed", [
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def XorPackedOp : TfheRustBool_Op<"xor_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def NotOp : TfheRustBool_Op<"not", [
Pure,
Expand Down
2 changes: 0 additions & 2 deletions tests/cggi_to_tfhe_rust_bool/add_bool.mlir
@@ -1,11 +1,9 @@
// RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s


#encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 1>
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>


// CHECK-LABEL: add_bool
// CHECK-NOT: cggi
// CHECK-NOT: lwe
Expand Down
2 changes: 0 additions & 2 deletions tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir
@@ -1,11 +1,9 @@
// RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s


#encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 1>
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>


// CHECK-LABEL: add_one_bool
// CHECK-NOT: cggi
// CHECK-NOT: lwe
Expand Down
1 change: 1 addition & 0 deletions tests/tfhe_rust_bool/end_to_end/BUILD
Expand Up @@ -12,6 +12,7 @@ glob_lit_tests(
data = [
"Cargo.toml",
"src/main.rs",
"src/main_bool_add.rs",
"@heir//tests:test_utilities",
],
default_tags = [
Expand Down
4 changes: 4 additions & 0 deletions tests/tfhe_rust_bool/end_to_end/Cargo.toml
Expand Up @@ -12,3 +12,7 @@ tfhe = { version = "0.4.1", features = ["boolean", "x86_64-unix"] }
[[bin]]
name = "main"
path = "src/main.rs"

[[bin]]
name = "main_bool_add"
path = "src/main_bool_add.rs"
51 changes: 51 additions & 0 deletions tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs
@@ -0,0 +1,51 @@
use clap::Parser;
use tfhe::boolean::prelude::*;

mod fn_under_test;

// TODO(https://github.com/google/heir/issues/235): improve generality
#[derive(Parser, Debug)]
struct Args {
/// arguments to forward to function under test
#[arg(id = "input_1", index = 1, action)]
input1: u8,

#[arg(id = "input_2", index = 2, action)]
input2: u8,
}

// Encrypt a u8
pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec<Ciphertext> {
let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 );

let res: Vec<Ciphertext> = arr.iter()
.map(|bit| client_key.encrypt(if *bit != 0u8 { true } else { false }))
.collect();
res
}

// Decrypt a u8
pub fn decrypt(ciphertexts: &Vec<Ciphertext>, client_key: &ClientKey) -> u8 {
let mut accum = 0u8;
for (i, ct) in ciphertexts.iter().enumerate() {
let bit = client_key.decrypt(ct);
accum |= (bit as u8) << i;
}
accum.reverse_bits()

}

fn main() {
let flags = Args::parse();
let (client_key, server_key) = tfhe::boolean::gen_keys();

let ct_1 = encrypt(flags.input1.into(), &client_key);
let ct_2 = encrypt(flags.input2.into(), &client_key);


let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);

let output = decrypt(&result, &client_key);

println!("{:08b}", output);
}
69 changes: 69 additions & 0 deletions tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir
@@ -0,0 +1,69 @@
// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_bool_add -- 15 3 | FileCheck %s

!bsks = !tfhe_rust_bool.server_key
!eb = !tfhe_rust_bool.eb

// CHECK: 00010010
func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> {
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb>
%extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb>
%extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb>
%extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb>
%extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb>
%extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb>
%extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb>
%extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb>
%extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb>
%extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb>
%extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb>
%extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb>
%extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb>
%extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb>
%extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb>
%extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb>
%ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb
%ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb
%fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb
%fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb
%fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb
%fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb
%fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb
%fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb
%fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb
%fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb
%fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb
%fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb
%fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb
%fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb
%fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb
%fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb
%fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb
%fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb
%fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb
%fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb
%fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb
%fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb
%fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb
%fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb
%fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb
%fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb
%fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb
%fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb
%fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb
%fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb
%fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb
%fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb
%fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb
%fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb
%from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb>
return %from_elements : tensor<8x!eb>
}
23 changes: 23 additions & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/BUILD
@@ -0,0 +1,23 @@
# See README.md for setup required to run these tests

load("//bazel:lit.bzl", "glob_lit_tests")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

glob_lit_tests(
name = "all_tests",
data = [
"Cargo.toml",
"src/main.rs",
"@heir//tests:test_utilities",
],
default_tags = [
"manual",
"notap",
],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
21 changes: 21 additions & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml
@@ -0,0 +1,21 @@
[package]
name = "heir-tfhe-rust-integration-test"
version = "0.1.0"
edition = "2021"
default-run = "main"

[dependencies]
clap = { version = "4.1.8", features = ["derive"] }
rayon = "1.6.1"
serde = { version = "1.0.152", features = ["derive"] }
tfhe = { path = "tfhe-rs/tfhe", features = [
"boolean",
"x86_64-unix",
] }

[features]
fpga = ["tfhe/fpga"]

[[bin]]
name = "main"
path = "src/main.rs"
44 changes: 44 additions & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/README.md
@@ -0,0 +1,44 @@
# End to end Rust codegen tests - Boolean FPGA

These tests exercise Rust codegen for the
[tfhe-rs](https://github.com/zama-ai/tfhe-rs) backend library, including
compiling the generated Rust source and running the resulting binary. This sets
tests are specifically of the boolean plaintexts, accompanying COSIC-KU Leuven version of the library,
and the [FPT-accelerator](https://eprint.iacr.org/2022/1635).

> :warning: Not possible to run these tests without the COSIC extension of TFHE-rs and FPT-accelerator
To avoid introducing these large dependencies into the entire project, these
tests are manual, and require the system they're running on to have
[Cargo](https://doc.rust-lang.org/cargo/index.html) installed. During the test,
cargo will fetch and build the required dependencies, and `Cargo.toml` in this
directory effectively pins the version of `tfhe` supported.

Use the following command to run the tests in this directory, where the default
Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`,
if you overrode the default option when installing Cargo.

```bash
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \
| xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@"
```

The `manual` tag is added to the targets in this directory to ensure that they
are not run when someone runs a glob test like `bazel test //...`.

If you don't do this correctly, you will see an error like this:

```
# .---command stderr------------
# | Updating crates.io index
# | Downloading crates ...
# | Downloaded memoffset v0.9.0
# | error: failed to download replaced source registry `crates-io`
# |
# | Caused by:
# | failed to open `/home/you/.cargo/registry/cache/index.crates.io-6f17d22bba15001f/memoffset-0.9.0.crate`
# |
# | Caused by:
# | Read-only file system (os error 30)
# `-----------------------------
```
83 changes: 83 additions & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
@@ -0,0 +1,83 @@
use clap::Parser;
use tfhe::boolean::prelude::*;

use tfhe::boolean::engine::BooleanEngine;
use tfhe::boolean::prelude::*;
use std::time::Instant;

#[cfg(feature = "fpga")]
use tfhe::boolean::server_key::FpgaGates;


mod fn_under_test;

// TODO(https://github.com/google/heir/issues/235): improve generality
#[derive(Parser, Debug)]
struct Args {
/// arguments to forward to function under test
#[arg(id = "input_1", index = 1, action)]
input1: u8,

#[arg(id = "input_2", index = 2, action)]
input2: u8,
}

// Encrypt a u8
pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec<Ciphertext> {
let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 );

let res: Vec<Ciphertext> = arr.iter()
.map(|bit| client_key.encrypt(if *bit != 0u8 { true } else { false }))
.collect();
res
}

// Decrypt a u8
pub fn decrypt(ciphertexts: &Vec<Ciphertext>, client_key: &ClientKey) -> u8 {
let mut accum = 0u8;
for (i, ct) in ciphertexts.iter().enumerate() {
let bit = client_key.decrypt(ct);
accum |= (bit as u8) << i;
}
accum

}

fn main() {
let flags = Args::parse();

let params;
let client_key;

let mut boolean_engine = BooleanEngine::new();

#[cfg(feature = "fpga")]
{
params = tfhe::boolean::engine::fpga::parameters::DEFAULT_PARAMETERS_KS_PBS;
client_key = boolean_engine.create_client_key(*params);
}

#[cfg(not(feature = "fpga"))]
{
params = tfhe::boolean::parameters::DEFAULT_PARAMETERS_KS_PBS;
client_key = boolean_engine.create_client_key(params);
}

// generate the server key, only the SW needs this
let server_key = boolean_engine.create_server_key(&client_key);

#[cfg(feature = "fpga")]
server_key.enable_fpga(params);

let ct_1 = encrypt(flags.input1.into(), &client_key);
let ct_2 = encrypt(flags.input2.into(), &client_key);

let ct_1= ct_1.iter().collect();
let ct_2= ct_2.iter().collect();

let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);

let output = decrypt(&result, &client_key);

println!("{:08b}", output);
}

0 comments on commit e9703d9

Please sign in to comment.