forked from google/heir
/
main.rs
83 lines (61 loc) · 2.09 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use clap::Parser;
use tfhe::boolean::prelude::*;
use tfhe::boolean::engine::BooleanEngine;
use tfhe::boolean::prelude::*;
use std::time::Instant;
#[cfg(feature = "fpga")]
use tfhe::boolean::server_key::FpgaGates;
mod fn_under_test;
// TODO(https://github.com/google/heir/issues/235): improve generality
/// Encrypts two 8-bit inputs, runs the function under test on them
/// homomorphically, and prints the decrypted 8-bit result in binary.
#[derive(Parser, Debug)]
struct Args {
    /// First 8-bit argument to forward to the function under test
    #[arg(id = "input_1", index = 1, action)]
    input1: u8,
    /// Second 8-bit argument to forward to the function under test
    #[arg(id = "input_2", index = 2, action)]
    input2: u8,
}
/// Encrypts a `u8` bit by bit under `client_key`.
///
/// Returns exactly `u8::BITS` (8) ciphertexts, least-significant bit first:
/// element `i` encrypts bit `i` of `value`. This is the layout [`decrypt`]
/// reassembles.
pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec<Ciphertext> {
    (0..u8::BITS)
        .map(|shift| client_key.encrypt(((value >> shift) & 1) != 0))
        .collect()
}
/// Decrypts an LSB-first sequence of bit ciphertexts back into a `u8`.
///
/// Ciphertext `i` supplies bit `i` of the result, matching the layout
/// produced by [`encrypt`]. An empty slice decrypts to `0`.
///
/// Only the first `u8::BITS` (8) ciphertexts are decrypted. Any further bits
/// cannot be represented in a `u8`. Shifting by 8 or more would panic in
/// debug builds and, in release builds, wrap the shift amount and corrupt
/// the low bits. Extra ciphertexts are therefore ignored, which truncates
/// the value like an `as u8` cast would.
pub fn decrypt(ciphertexts: &[Ciphertext], client_key: &ClientKey) -> u8 {
    ciphertexts
        .iter()
        .take(u8::BITS as usize)
        .enumerate()
        .fold(0u8, |accum, (i, ct)| {
            let bit = client_key.decrypt(ct);
            accum | ((bit as u8) << i)
        })
}
/// Encrypts the two `u8` command-line inputs, runs
/// `fn_under_test::fn_under_test` on the ciphertexts with the server key,
/// decrypts the result, and prints it as 8 binary digits.
///
/// With the `fpga` feature, keys are built from the FPGA parameter set and
/// the server key is switched to FPGA execution via `enable_fpga`. Otherwise
/// the default software parameters are used.
fn main() {
    let flags = Args::parse();

    // `params` and `client_key` are initialised in exactly one of the two
    // mutually exclusive `cfg` blocks below.
    let params;
    let client_key;
    let mut boolean_engine = BooleanEngine::new();
    #[cfg(feature = "fpga")]
    {
        // NOTE(review): the FPGA constant is dereferenced for key creation but
        // passed as-is to `enable_fpga` below, so it is presumably a reference
        // (or pointer-like wrapper) to the parameters — confirm in the fork.
        params = tfhe::boolean::engine::fpga::parameters::DEFAULT_PARAMETERS_KS_PBS;
        client_key = boolean_engine.create_client_key(*params);
    }
    #[cfg(not(feature = "fpga"))]
    {
        params = tfhe::boolean::parameters::DEFAULT_PARAMETERS_KS_PBS;
        client_key = boolean_engine.create_client_key(params);
    }

    // generate the server key, only the SW needs this
    let server_key = boolean_engine.create_server_key(&client_key);
    #[cfg(feature = "fpga")]
    server_key.enable_fpga(params);

    // The CLI inputs are already `u8`, which is exactly what `encrypt` takes.
    let ct_1 = encrypt(flags.input1, &client_key);
    let ct_2 = encrypt(flags.input2, &client_key);

    // Borrow each ciphertext; the target collection type is inferred from
    // `fn_under_test`'s signature.
    let ct_1 = ct_1.iter().collect();
    let ct_2 = ct_2.iter().collect();

    let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);
    let output = decrypt(&result, &client_key);
    println!("{:08b}", output);
}