diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index 92dd1033f..f64bfced7 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -24,16 +24,20 @@ def CreateTrivialOp : TfheRustBool_Op<"create_trivial", [Pure]> { let results = (outs TfheRustBool_Encrypted:$output); } +// --- Operations for a gate-bootstrapping API of a CGGI library --- +def TfheRustBoolLike : TypeOrContainer; + + class TfheRustBool_BinaryGateOp : TfheRustBool_Op ]> { let arguments = (ins TfheRustBool_ServerKey:$serverKey, - TfheRustBool_Encrypted:$lhs, - TfheRustBool_Encrypted:$rhs + TfheRustBoolLike:$lhs, + TfheRustBoolLike:$rhs ); - let results = (outs TfheRustBool_Encrypted:$output); + let results = (outs TfheRustBoolLike:$output); } def AndOp : TfheRustBool_BinaryGateOp<"and"> { let summary = "Logical AND of two TFHE-rs Bool ciphertexts."; } @@ -43,31 +47,6 @@ def NorOp : TfheRustBool_BinaryGateOp<"nor"> { let summary = "Logical NOR of t def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two TFHE-rs Bool ciphertexts."; } def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; } - -def AndPackedOp : TfheRustBool_Op<"and_packed", [ - Pure, - AllTypesMatch<["lhs", "rhs", "output"]> -]> { - let arguments = (ins - TfheRustBool_ServerKey:$serverKey, - TensorOf<[TfheRustBool_Encrypted]>:$lhs, - TensorOf<[TfheRustBool_Encrypted]>:$rhs - ); - let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); -} - -def XorPackedOp : TfheRustBool_Op<"xor_packed", [ - Pure, - AllTypesMatch<["lhs", "rhs", "output"]> -]> { - let arguments = (ins - TfheRustBool_ServerKey:$serverKey, - TensorOf<[TfheRustBool_Encrypted]>:$lhs, - TensorOf<[TfheRustBool_Encrypted]>:$rhs - ); - let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); -} - def NotOp : TfheRustBool_Op<"not", [ Pure, AllTypesMatch<["input", "output"]> diff --git a/include/Target/TfheRustBool/TfheRustBoolEmitter.h b/include/Target/TfheRustBool/TfheRustBoolEmitter.h index fca9fe291..569b09448 100644 --- a/include/Target/TfheRustBool/TfheRustBoolEmitter.h +++ b/include/Target/TfheRustBool/TfheRustBoolEmitter.h @@ -59,8 +59,6 @@ class TfheRustBoolEmitter { LogicalResult printOperation(XorOp op); LogicalResult printOperation(XnorOp op); - LogicalResult printOperation(AndPackedOp op); - // Helpers for above LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands, @@ -72,6 +70,7 @@ class TfheRustBoolEmitter { FailureOr convertType(Type type); void emitAssignPrefix(::mlir::Value result); + void emitReferenceConversion(::mlir::Value value); }; } // namespace tfhe_rust_bool diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 8894ce615..cdd709283 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -64,7 +64,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) { // Arith ops .Case([&](auto op) { return printOperation(op); }) // TfheRustBool ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) // Tensor ops .Case( @@ -148,6 +148,9 @@ LogicalResult TfheRustBoolEmitter::printOperation(func::ReturnOp op) { if (isa(value)) { cloneStr = ".clone()"; } + if (isa(value.getDefiningOp())) { + cloneStr = ".into_iter().cloned().collect()"; + } return variableNames->getNameForValue(value) + cloneStr; }; @@ -165,29 +168,70 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) { os << "let " << variableNames->getNameForValue(result) << " = "; } +void TfheRustBoolEmitter::emitReferenceConversion(Value value) { + auto varName = variableNames->getNameForValue(value); + os << "let " << varName << "_ref = " << varName << ".clone();\n"; + os << "let " << varName << "_ref: Vec<&Ciphertext> = " << varName + << ".iter().collect();\n"; +} + LogicalResult TfheRustBoolEmitter::printSksMethod( ::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands, std::string_view op, SmallVector operandTypes) { - emitAssignPrefix(result); - - auto operandTypesIt = operandTypes.begin(); - os << variableNames->getNameForValue(sks) << "." << op << "("; - os << commaSeparatedValues(nonSksOperands, [&](Value value) { - auto *prefix = value.getType().hasTrait() ? "&" : ""; - // First check if a DefiningOp exists - // if not: comes from function definition - mlir::Operation *op = value.getDefiningOp(); - if (op) { - prefix = isa(op) ? "" : prefix; - } else { - prefix = ""; + if (isa(nonSksOperands[0].getType())) { + auto *opParent = nonSksOperands[0].getDefiningOp(); + if (!opParent) { + for (auto nonSksOperand : nonSksOperands) { + emitReferenceConversion(nonSksOperand); + } } - return prefix + variableNames->getNameForValue(value) + - (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); - }); - os << ");\n"; - return success(); + emitAssignPrefix(result); + + os << variableNames->getNameForValue(sks) << "." << op << "_packed("; + os << commaSeparatedValues( + {nonSksOperands[0], nonSksOperands[1]}, [&](Value value) { + auto *prefix = value.getType().hasTrait() ? "&" : ""; + auto suffix = ""; + // First check if a DefiningOp exists + // if not: comes from function definition + mlir::Operation *opParent = value.getDefiningOp(); + if (opParent) { + prefix = isa(opParent) ? prefix : ""; + prefix = + isa(value.getDefiningOp()) ? "&" : ""; + } else { + prefix = "&"; + suffix = "_ref"; + } + + return prefix + variableNames->getNameForValue(value) + suffix; + }); + os << ");\n"; + return success(); + + } else { + emitAssignPrefix(result); + + auto operandTypesIt = operandTypes.begin(); + os << variableNames->getNameForValue(sks) << "." << op << "("; + os << commaSeparatedValues(nonSksOperands, [&](Value value) { + auto *prefix = value.getType().hasTrait() ? "&" : ""; + // First check if a DefiningOp exists + // if not: comes from function definition + mlir::Operation *op = value.getDefiningOp(); + if (op) { + prefix = isa(op) ? "" : prefix; + } else { + prefix = ""; + } + + return prefix + variableNames->getNameForValue(value) + + (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); + }); + os << ");\n"; + return success(); + } } LogicalResult TfheRustBoolEmitter::printOperation(CreateTrivialOp op) { @@ -215,6 +259,7 @@ LogicalResult TfheRustBoolEmitter::printOperation(arith::ConstantOp op) { return success(); } +// Produces a &Ciphertext LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) { // We assume here that the indices are SSA values (not integer attributes). emitAssignPrefix(op.getResult()); @@ -226,15 +271,18 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) { return success(); } +// Need to produce a Vec<&Ciphertext> LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { emitAssignPrefix(op.getResult()); os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) { // Check if block argument, if so, clone. - auto cloneStr = ""; - if (isa(value)) { - cloneStr = ".clone()"; - } - return variableNames->getNameForValue(value) + cloneStr; + auto cloneStr = isa(value) ? ".clone()" : ""; + // Get the name of defining operation its dialect + auto tfhe_op = + value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool"; + auto prefix = tfhe_op ? "&" : ""; + return std::string(prefix) + variableNames->getNameForValue(value) + + cloneStr; }) << "];\n"; return success(); } @@ -269,11 +317,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) { {op.getLhs(), op.getRhs()}, "xnor"); } -LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { - return printSksMethod(op.getResult(), op.getServerKey(), - {op.getLhs(), op.getRhs()}, "and_packed"); -} - FailureOr TfheRustBoolEmitter::convertType(Type type) { // Note: these are probably not the right type names to use exactly, and they // will need to chance to the right values once we try to compile it against @@ -283,7 +326,8 @@ FailureOr TfheRustBoolEmitter::convertType(Type type) { // FIXME: why can't both types be FailureOr? auto elementTy = convertType(shapedType.getElementType()); if (failed(elementTy)) return failure(); - return std::string("Vec<" + elementTy.value() + ">"); + + return std::string(std::string("Vec<") + elementTy.value() + ">"); } return llvm::TypeSwitch>(type) .Case( diff --git a/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs index 02f61b2ef..a0ccb8f64 100644 --- a/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs +++ b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs @@ -32,7 +32,6 @@ pub fn decrypt(ciphertexts: &Vec, client_key: &ClientKey) -> u8 { accum |= (bit as u8) << i; } accum.reverse_bits() - } fn main() { diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/BUILD b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD index a189be648..868d9ba3c 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/BUILD +++ b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD @@ -12,6 +12,7 @@ glob_lit_tests( data = [ "Cargo.toml", "src/main.rs", + "tfhe-rs", "@heir//tests:test_utilities", ], default_tags = [ diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md index 42be8b572..fde92ec51 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/README.md +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -19,10 +19,16 @@ Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`, if you overrode the default option when installing Cargo. ```bash -bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \ +bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" \ | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@" ``` +Manually generate the Rust code fort the CGGI lowering: +```bash +bazel run //tools:heir-opt -- -cse --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse $(pwd)/tests/cggi_to_tfhe_rust_bool/add_bool.mlir | bazel run //tools:heir-translate -- --emit-tfhe-rust-bool +``` + + The `manual` tag is added to the targets in this directory to ensure that they are not run when someone runs a glob test like `bazel test //...`. diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs index 47392eaa9..3a86c6471 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs @@ -72,10 +72,11 @@ fn main() { let ct_1 = encrypt(flags.input1.into(), &client_key); let ct_2 = encrypt(flags.input2.into(), &client_key); - let ct_1= ct_1.iter().collect(); - let ct_2= ct_2.iter().collect(); - + let t = Instant::now(); let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + let run = t.elapsed().as_millis(); + + // println!("{:?}", run); let output = decrypt(&result, &client_key); diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir new file mode 100644 index 000000000..6d612a0a1 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir @@ -0,0 +1,77 @@ +// This test ensures the testing harness is working properly with minimal codegen. + +// RUN: heir-opt --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | heir-translate --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s + +#encoding = #lwe.unspecified_bit_field_encoding +!ct_ty = !lwe.lwe_ciphertext +!pt_ty = !lwe.lwe_plaintext + +// CHECK: 01000000 +func.func @fn_under_test(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { + %true = arith.constant true + %false = arith.constant false + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!ct_ty> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!ct_ty> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!ct_ty> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!ct_ty> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!ct_ty> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!ct_ty> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!ct_ty> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!ct_ty> + %ha_s = cggi.xor %extracted_00, %extracted_10 : !ct_ty + %ha_c = cggi.and %extracted_00, %extracted_10 : !ct_ty + %fa0_1 = cggi.xor %extracted_01, %extracted_11 : !ct_ty + %fa0_2 = cggi.and %extracted_01, %extracted_11 : !ct_ty + %fa0_3 = cggi.and %fa0_1, %ha_c : !ct_ty + %fa0_s = cggi.xor %fa0_1, %ha_c : !ct_ty + %fa0_c = cggi.xor %fa0_2, %fa0_3 : !ct_ty + %fa1_1 = cggi.xor %extracted_02, %extracted_12 : !ct_ty + %fa1_2 = cggi.and %extracted_02, %extracted_12 : !ct_ty + %fa1_3 = cggi.and %fa1_1, %fa0_c : !ct_ty + %fa1_s = cggi.xor %fa1_1, %fa0_c : !ct_ty + %fa1_c = cggi.xor %fa1_2, %fa1_3 : !ct_ty + %fa2_1 = cggi.xor %extracted_03, %extracted_13 : !ct_ty + %fa2_2 = cggi.and %extracted_03, %extracted_13 : !ct_ty + %fa2_3 = cggi.and %fa2_1, %fa1_c : !ct_ty + %fa2_s = cggi.xor %fa2_1, %fa1_c : !ct_ty + %fa2_c = cggi.xor %fa2_2, %fa2_3 : !ct_ty + %fa3_1 = cggi.xor %extracted_04, %extracted_14 : !ct_ty + %fa3_2 = cggi.and %extracted_04, %extracted_14 : !ct_ty + %fa3_3 = cggi.and %fa3_1, %fa2_c : !ct_ty + %fa3_s = cggi.xor %fa3_1, %fa2_c : !ct_ty + %fa3_c = cggi.xor %fa3_2, %fa3_3 : !ct_ty + %fa4_1 = cggi.xor %extracted_05, %extracted_15 : !ct_ty + %fa4_2 = cggi.and %extracted_05, %extracted_15 : !ct_ty + %fa4_3 = cggi.and %fa4_1, %fa3_c : !ct_ty + %fa4_s = cggi.xor %fa4_1, %fa3_c : !ct_ty + %fa4_c = cggi.xor %fa4_2, %fa4_3 : !ct_ty + %fa5_1 = cggi.xor %extracted_06, %extracted_16 : !ct_ty + %fa5_2 = cggi.and %extracted_06, %extracted_16 : !ct_ty + %fa5_3 = cggi.and %fa5_1, %fa4_c : !ct_ty + %fa5_s = cggi.xor %fa5_1, %fa4_c : !ct_ty + %fa5_c = cggi.xor %fa5_2, %fa5_3 : !ct_ty + %fa6_1 = cggi.xor %extracted_07, %extracted_17 : !ct_ty + %fa6_2 = cggi.and %extracted_07, %extracted_17 : !ct_ty + %fa6_3 = cggi.and %fa6_1, %fa5_c : !ct_ty + %fa6_s = cggi.xor %fa6_1, %fa5_c : !ct_ty + %fa6_c = cggi.xor %fa6_2, %fa6_3 : !ct_ty + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!ct_ty> + return %from_elements : tensor<8x!ct_ty> +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir index 8ecca7cad..d22cadf5d 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir @@ -1,13 +1,13 @@ // This test ensures the testing harness is working properly with minimal codegen. // RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs -// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 1 | FileCheck %s +// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s !bsks = !tfhe_rust_bool.server_key !eb = !tfhe_rust_bool.eb // CHECK: 1 func.func @fn_under_test(%bsks : !bsks, %a: tensor<8x!eb>, %b: tensor<8x!eb>) -> tensor<8x!eb> { - %res = tfhe_rust_bool.and_packed %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb> + %res = tfhe_rust_bool.and %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb> return %res : tensor<8x!eb> } diff --git a/tests/tfhe_rust_bool/ops.mlir b/tests/tfhe_rust_bool/ops.mlir index 1eb7f7353..b5ca17d68 100644 --- a/tests/tfhe_rust_bool/ops.mlir +++ b/tests/tfhe_rust_bool/ops.mlir @@ -29,7 +29,7 @@ module { // CHECK-LABEL: func @test_packed_and func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) { - %out = tfhe_rust_bool.and_packed %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> + %out = tfhe_rust_bool.and %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> return } }