Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tfhe-rs-bool emitter for (vectorized) cggi to fpga #557

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 7 additions & 28 deletions include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td
Expand Up @@ -24,16 +24,20 @@ def CreateTrivialOp : TfheRustBool_Op<"create_trivial", [Pure]> {
let results = (outs TfheRustBool_Encrypted:$output);
}

// --- Operations for a gate-bootstrapping API of a CGGI library ---
def TfheRustBoolLike : TypeOrContainer<TfheRustBool_Encrypted, "eb-like">;


class TfheRustBool_BinaryGateOp<string mnemonic>
: TfheRustBool_Op<mnemonic, [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins TfheRustBool_ServerKey:$serverKey,
TfheRustBool_Encrypted:$lhs,
TfheRustBool_Encrypted:$rhs
TfheRustBoolLike:$lhs,
TfheRustBoolLike:$rhs
);
let results = (outs TfheRustBool_Encrypted:$output);
let results = (outs TfheRustBoolLike:$output);
}

def AndOp : TfheRustBool_BinaryGateOp<"and"> { let summary = "Logical AND of two TFHE-rs Bool ciphertexts."; }
Expand All @@ -43,31 +47,6 @@ def NorOp : TfheRustBool_BinaryGateOp<"nor"> { let summary = "Logical NOR of t
def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two TFHE-rs Bool ciphertexts."; }
def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; }


def AndPackedOp : TfheRustBool_Op<"and_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def XorPackedOp : TfheRustBool_Op<"xor_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def NotOp : TfheRustBool_Op<"not", [
Pure,
AllTypesMatch<["input", "output"]>
Expand Down
3 changes: 1 addition & 2 deletions include/Target/TfheRustBool/TfheRustBoolEmitter.h
Expand Up @@ -59,8 +59,6 @@ class TfheRustBoolEmitter {
LogicalResult printOperation(XorOp op);
LogicalResult printOperation(XnorOp op);

LogicalResult printOperation(AndPackedOp op);

// Helpers for above
LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks,
::mlir::ValueRange nonSksOperands,
Expand All @@ -72,6 +70,7 @@ class TfheRustBoolEmitter {
FailureOr<std::string> convertType(Type type);

void emitAssignPrefix(::mlir::Value result);
void emitReferenceConversion(::mlir::Value value);
};

} // namespace tfhe_rust_bool
Expand Down
104 changes: 74 additions & 30 deletions lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp
Expand Up @@ -64,7 +64,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) {
// Arith ops
.Case<arith::ConstantOp>([&](auto op) { return printOperation(op); })
// TfheRustBool ops
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp, AndPackedOp>(
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp>(
[&](auto op) { return printOperation(op); })
// Tensor ops
.Case<tensor::ExtractOp, tensor::FromElementsOp>(
Expand Down Expand Up @@ -148,6 +148,9 @@ LogicalResult TfheRustBoolEmitter::printOperation(func::ReturnOp op) {
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
if (isa<tensor::FromElementsOp>(value.getDefiningOp())) {
cloneStr = ".into_iter().cloned().collect()";
}
return variableNames->getNameForValue(value) + cloneStr;
};

Expand All @@ -165,29 +168,70 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) {
os << "let " << variableNames->getNameForValue(result) << " = ";
}

void TfheRustBoolEmitter::emitReferenceConversion(Value value) {
auto varName = variableNames->getNameForValue(value);
os << "let " << varName << "_ref = " << varName << ".clone();\n";
os << "let " << varName << "_ref: Vec<&Ciphertext> = " << varName
<< ".iter().collect();\n";
}

LogicalResult TfheRustBoolEmitter::printSksMethod(
::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands,
std::string_view op, SmallVector<std::string> operandTypes) {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
if (isa<TensorType>(nonSksOperands[0].getType())) {
auto *opParent = nonSksOperands[0].getDefiningOp();
if (!opParent) {
for (auto nonSksOperand : nonSksOperands) {
emitReferenceConversion(nonSksOperand);
}
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
emitAssignPrefix(result);

os << variableNames->getNameForValue(sks) << "." << op << "_packed(";
os << commaSeparatedValues(
{nonSksOperands[0], nonSksOperands[1]}, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
auto suffix = "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *opParent = value.getDefiningOp();
if (opParent) {
prefix = isa<tensor::ExtractOp>(opParent) ? prefix : "";
prefix =
isa<tensor::FromElementsOp>(value.getDefiningOp()) ? "&" : "";
} else {
prefix = "&";
suffix = "_ref";
}

return prefix + variableNames->getNameForValue(value) + suffix;
});
os << ");\n";
return success();

} else {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
}
}

LogicalResult TfheRustBoolEmitter::printOperation(CreateTrivialOp op) {
Expand Down Expand Up @@ -215,6 +259,7 @@ LogicalResult TfheRustBoolEmitter::printOperation(arith::ConstantOp op) {
return success();
}

// Produces a &Ciphertext
LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
// We assume here that the indices are SSA values (not integer attributes).
emitAssignPrefix(op.getResult());
Expand All @@ -226,15 +271,18 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
return success();
}

// Need to produce a Vec<&Ciphertext>
LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) {
emitAssignPrefix(op.getResult());
os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) {
// Check if block argument, if so, clone.
auto cloneStr = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
auto cloneStr = isa<BlockArgument>(value) ? ".clone()" : "";
// Get the name of defining operation its dialect
auto tfhe_op =
value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could do something like this:
llvm::isa<tfhe_rust_bool::TfheRustBoolDialect>(op.getDialect())

auto prefix = tfhe_op ? "&" : "";
return std::string(prefix) + variableNames->getNameForValue(value) +
cloneStr;
}) << "];\n";
return success();
}
Expand Down Expand Up @@ -269,11 +317,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) {
{op.getLhs(), op.getRhs()}, "xnor");
}

LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) {
return printSksMethod(op.getResult(), op.getServerKey(),
{op.getLhs(), op.getRhs()}, "and_packed");
}

FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// Note: these are probably not the right type names to use exactly, and they
// will need to chance to the right values once we try to compile it against
Expand All @@ -283,7 +326,8 @@ FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// FIXME: why can't both types be FailureOr<std::string>?
auto elementTy = convertType(shapedType.getElementType());
if (failed(elementTy)) return failure();
return std::string("Vec<" + elementTy.value() + ">");

return std::string(std::string("Vec<") + elementTy.value() + ">");
}
return llvm::TypeSwitch<Type &, FailureOr<std::string>>(type)
.Case<EncryptedBoolType>(
Expand Down
1 change: 0 additions & 1 deletion tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs
Expand Up @@ -32,7 +32,6 @@ pub fn decrypt(ciphertexts: &Vec<Ciphertext>, client_key: &ClientKey) -> u8 {
accum |= (bit as u8) << i;
}
accum.reverse_bits()

}

fn main() {
Expand Down
1 change: 1 addition & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/BUILD
Expand Up @@ -12,6 +12,7 @@ glob_lit_tests(
data = [
"Cargo.toml",
"src/main.rs",
"tfhe-rs",
"@heir//tests:test_utilities",
],
default_tags = [
Expand Down
8 changes: 7 additions & 1 deletion tests/tfhe_rust_bool/end_to_end_fpga/README.md
Expand Up @@ -19,10 +19,16 @@ Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`,
if you overrode the default option when installing Cargo.

```bash
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" \
| xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@"
```

Manually generate the Rust code fort the CGGI lowering:
```bash
bazel run //tools:heir-opt -- -cse --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse $(pwd)/tests/cggi_to_tfhe_rust_bool/add_bool.mlir | bazel run //tools:heir-translate -- --emit-tfhe-rust-bool
```


The `manual` tag is added to the targets in this directory to ensure that they
are not run when someone runs a glob test like `bazel test //...`.

Expand Down
7 changes: 4 additions & 3 deletions tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
Expand Up @@ -72,10 +72,11 @@ fn main() {
let ct_1 = encrypt(flags.input1.into(), &client_key);
let ct_2 = encrypt(flags.input2.into(), &client_key);

let ct_1= ct_1.iter().collect();
let ct_2= ct_2.iter().collect();

let t = Instant::now();
let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);
let run = t.elapsed().as_millis();

// println!("{:?}", run);

let output = decrypt(&result, &client_key);

Expand Down
77 changes: 77 additions & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir
@@ -0,0 +1,77 @@
// This test ensures the testing harness is working properly with minimal codegen.

// RUN: heir-opt --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | heir-translate --emit-tfhe-rust-bool > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s

#encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 1>
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>

// CHECK: 01000000
func.func @fn_under_test(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
%true = arith.constant true
%false = arith.constant false
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty>
%extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
%extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>
%extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty>
%extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty>
%extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty>
%extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty>
%extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty>
%extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!ct_ty>
%extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!ct_ty>
%extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!ct_ty>
%extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!ct_ty>
%extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!ct_ty>
%extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!ct_ty>
%extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!ct_ty>
%extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!ct_ty>
%ha_s = cggi.xor %extracted_00, %extracted_10 : !ct_ty
%ha_c = cggi.and %extracted_00, %extracted_10 : !ct_ty
%fa0_1 = cggi.xor %extracted_01, %extracted_11 : !ct_ty
%fa0_2 = cggi.and %extracted_01, %extracted_11 : !ct_ty
%fa0_3 = cggi.and %fa0_1, %ha_c : !ct_ty
%fa0_s = cggi.xor %fa0_1, %ha_c : !ct_ty
%fa0_c = cggi.xor %fa0_2, %fa0_3 : !ct_ty
%fa1_1 = cggi.xor %extracted_02, %extracted_12 : !ct_ty
%fa1_2 = cggi.and %extracted_02, %extracted_12 : !ct_ty
%fa1_3 = cggi.and %fa1_1, %fa0_c : !ct_ty
%fa1_s = cggi.xor %fa1_1, %fa0_c : !ct_ty
%fa1_c = cggi.xor %fa1_2, %fa1_3 : !ct_ty
%fa2_1 = cggi.xor %extracted_03, %extracted_13 : !ct_ty
%fa2_2 = cggi.and %extracted_03, %extracted_13 : !ct_ty
%fa2_3 = cggi.and %fa2_1, %fa1_c : !ct_ty
%fa2_s = cggi.xor %fa2_1, %fa1_c : !ct_ty
%fa2_c = cggi.xor %fa2_2, %fa2_3 : !ct_ty
%fa3_1 = cggi.xor %extracted_04, %extracted_14 : !ct_ty
%fa3_2 = cggi.and %extracted_04, %extracted_14 : !ct_ty
%fa3_3 = cggi.and %fa3_1, %fa2_c : !ct_ty
%fa3_s = cggi.xor %fa3_1, %fa2_c : !ct_ty
%fa3_c = cggi.xor %fa3_2, %fa3_3 : !ct_ty
%fa4_1 = cggi.xor %extracted_05, %extracted_15 : !ct_ty
%fa4_2 = cggi.and %extracted_05, %extracted_15 : !ct_ty
%fa4_3 = cggi.and %fa4_1, %fa3_c : !ct_ty
%fa4_s = cggi.xor %fa4_1, %fa3_c : !ct_ty
%fa4_c = cggi.xor %fa4_2, %fa4_3 : !ct_ty
%fa5_1 = cggi.xor %extracted_06, %extracted_16 : !ct_ty
%fa5_2 = cggi.and %extracted_06, %extracted_16 : !ct_ty
%fa5_3 = cggi.and %fa5_1, %fa4_c : !ct_ty
%fa5_s = cggi.xor %fa5_1, %fa4_c : !ct_ty
%fa5_c = cggi.xor %fa5_2, %fa5_3 : !ct_ty
%fa6_1 = cggi.xor %extracted_07, %extracted_17 : !ct_ty
%fa6_2 = cggi.and %extracted_07, %extracted_17 : !ct_ty
%fa6_3 = cggi.and %fa6_1, %fa5_c : !ct_ty
%fa6_s = cggi.xor %fa6_1, %fa5_c : !ct_ty
%fa6_c = cggi.xor %fa6_2, %fa6_3 : !ct_ty
%from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!ct_ty>
return %from_elements : tensor<8x!ct_ty>
}
4 changes: 2 additions & 2 deletions tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir
@@ -1,13 +1,13 @@
// This test ensures the testing harness is working properly with minimal codegen.

// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 1 | FileCheck %s
// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s

!bsks = !tfhe_rust_bool.server_key
!eb = !tfhe_rust_bool.eb

// CHECK: 1
func.func @fn_under_test(%bsks : !bsks, %a: tensor<8x!eb>, %b: tensor<8x!eb>) -> tensor<8x!eb> {
%res = tfhe_rust_bool.and_packed %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb>
%res = tfhe_rust_bool.and %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb>
return %res : tensor<8x!eb>
}
2 changes: 1 addition & 1 deletion tests/tfhe_rust_bool/ops.mlir
Expand Up @@ -29,7 +29,7 @@ module {

// CHECK-LABEL: func @test_packed_and
func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) {
%out = tfhe_rust_bool.and_packed %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb>
%out = tfhe_rust_bool.and %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb>
return
}
}