Skip to content

Commit

Permalink
Update tfhe-rs-emitter: cggi to tfhe-rs-bool fpga ready code
Browse files Browse the repository at this point in the history
  • Loading branch information
WoutLegiest committed Mar 27, 2024
1 parent c67f0a0 commit ad9b7b1
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 73 deletions.
37 changes: 8 additions & 29 deletions include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td
Expand Up @@ -24,16 +24,20 @@ def CreateTrivialOp : TfheRustBool_Op<"create_trivial", [Pure]> {
let results = (outs TfheRustBool_Encrypted:$output);
}

// --- Operations for a gate-bootstrapping API of a CGGI library ---
def TfheRustBoolLike : TypeOrContainer<TfheRustBool_Encrypted, "eb-like">;


class TfheRustBool_BinaryGateOp<string mnemonic>
: TfheRustBool_Op<mnemonic, [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins TfheRustBool_ServerKey:$serverKey,
TfheRustBool_Encrypted:$lhs,
TfheRustBool_Encrypted:$rhs
TfheRustBoolLike:$lhs,
TfheRustBoolLike:$rhs
);
let results = (outs TfheRustBool_Encrypted:$output);
let results = (outs TfheRustBoolLike:$output);
}

def AndOp : TfheRustBool_BinaryGateOp<"and"> { let summary = "Logical AND of two TFHE-rs Bool ciphertexts."; }
Expand All @@ -43,31 +47,6 @@ def NorOp : TfheRustBool_BinaryGateOp<"nor"> { let summary = "Logical NOR of t
def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two TFHE-rs Bool ciphertexts."; }
def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; }


def AndPackedOp : TfheRustBool_Op<"and_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def XorPackedOp : TfheRustBool_Op<"xor_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def NotOp : TfheRustBool_Op<"not", [
Pure,
AllTypesMatch<["input", "output"]>
Expand All @@ -94,4 +73,4 @@ def MuxOp : TfheRustBool_Op<"mux", [



#endif // INCLUDE_DIALECT_TFHERUSTBOOL_IR_TFHERUSTBOOLOPS_TD_
#endif // INCLUDE_DIALECT_TFHERUSTBOOL_IR_TFHERUSTBOOLOPS_TD_
5 changes: 2 additions & 3 deletions include/Target/TfheRustBool/TfheRustBoolEmitter.h
Expand Up @@ -59,8 +59,6 @@ class TfheRustBoolEmitter {
LogicalResult printOperation(XorOp op);
LogicalResult printOperation(XnorOp op);

LogicalResult printOperation(AndPackedOp op);

// Helpers for above
LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks,
::mlir::ValueRange nonSksOperands,
Expand All @@ -72,10 +70,11 @@ class TfheRustBoolEmitter {
FailureOr<std::string> convertType(Type type);

void emitAssignPrefix(::mlir::Value result);
void emitReferenceConversion(::mlir::Value value);
};

} // namespace tfhe_rust_bool
} // namespace heir
} // namespace mlir

#endif // INCLUDE_TARGET_TFHERUSTBOOL_TFHERUSTBOOLEMITTER_H_
#endif // INCLUDE_TARGET_TFHERUSTBOOL_TFHERUSTBOOLEMITTER_H_
106 changes: 75 additions & 31 deletions lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp
Expand Up @@ -64,7 +64,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) {
// Arith ops
.Case<arith::ConstantOp>([&](auto op) { return printOperation(op); })
// TfheRustBool ops
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp, AndPackedOp>(
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp>(
[&](auto op) { return printOperation(op); })
// Tensor ops
.Case<tensor::ExtractOp, tensor::FromElementsOp>(
Expand Down Expand Up @@ -148,6 +148,9 @@ LogicalResult TfheRustBoolEmitter::printOperation(func::ReturnOp op) {
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
if (isa<tensor::FromElementsOp>(value.getDefiningOp())) {
cloneStr = ".into_iter().cloned().collect()";
}
return variableNames->getNameForValue(value) + cloneStr;
};

Expand All @@ -165,29 +168,70 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) {
os << "let " << variableNames->getNameForValue(result) << " = ";
}

void TfheRustBoolEmitter::emitReferenceConversion(Value value) {
auto varName = variableNames->getNameForValue(value);
os << "let " << varName << "_ref = " << varName << ".clone();\n";
os << "let " << varName << "_ref: Vec<&Ciphertext> = " << varName
<< ".iter().collect();\n";
}

LogicalResult TfheRustBoolEmitter::printSksMethod(
::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands,
std::string_view op, SmallVector<std::string> operandTypes) {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
if (isa<TensorType>(nonSksOperands[0].getType())) {
auto *opParent = nonSksOperands[0].getDefiningOp();
if (!opParent) {
for (auto nonSksOperand : nonSksOperands) {
emitReferenceConversion(nonSksOperand);
}
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
emitAssignPrefix(result);

os << variableNames->getNameForValue(sks) << "." << op << "_packed(";
os << commaSeparatedValues(
{nonSksOperands[0], nonSksOperands[1]}, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
auto suffix = "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *opParent = value.getDefiningOp();
if (opParent) {
prefix = isa<tensor::ExtractOp>(opParent) ? prefix : "";
prefix =
isa<tensor::FromElementsOp>(value.getDefiningOp()) ? "&" : "";
} else {
prefix = "&";
suffix = "_ref";
}

return prefix + variableNames->getNameForValue(value) + suffix;
});
os << ");\n";
return success();

} else {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
}
}

LogicalResult TfheRustBoolEmitter::printOperation(CreateTrivialOp op) {
Expand Down Expand Up @@ -215,6 +259,7 @@ LogicalResult TfheRustBoolEmitter::printOperation(arith::ConstantOp op) {
return success();
}

// Produces a &Ciphertext
LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
// We assume here that the indices are SSA values (not integer attributes).
emitAssignPrefix(op.getResult());
Expand All @@ -226,15 +271,18 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
return success();
}

// Need to produce a Vec<&Ciphertext>
LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) {
emitAssignPrefix(op.getResult());
os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) {
// Check if block argument, if so, clone.
auto cloneStr = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
auto cloneStr = isa<BlockArgument>(value) ? ".clone()" : "";
// Get the name of defining operation its dialect
auto tfhe_op =
value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool";
auto prefix = tfhe_op ? "&" : "";
return std::string(prefix) + variableNames->getNameForValue(value) +
cloneStr;
}) << "];\n";
return success();
}
Expand Down Expand Up @@ -269,11 +317,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) {
{op.getLhs(), op.getRhs()}, "xnor");
}

LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) {
return printSksMethod(op.getResult(), op.getServerKey(),
{op.getLhs(), op.getRhs()}, "and_packed");
}

FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// Note: these are probably not the right type names to use exactly, and they
// will need to chance to the right values once we try to compile it against
Expand All @@ -283,7 +326,8 @@ FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// FIXME: why can't both types be FailureOr<std::string>?
auto elementTy = convertType(shapedType.getElementType());
if (failed(elementTy)) return failure();
return std::string("Vec<" + elementTy.value() + ">");

return std::string(std::string("Vec<") + elementTy.value() + ">");
}
return llvm::TypeSwitch<Type &, FailureOr<std::string>>(type)
.Case<EncryptedBoolType>(
Expand All @@ -306,4 +350,4 @@ TfheRustBoolEmitter::TfheRustBoolEmitter(raw_ostream &os,
: os(os), variableNames(variableNames) {}
} // namespace tfhe_rust_bool
} // namespace heir
} // namespace mlir
} // namespace mlir
1 change: 1 addition & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/BUILD
Expand Up @@ -12,6 +12,7 @@ glob_lit_tests(
data = [
"Cargo.toml",
"src/main.rs",
"tfhe-rs",
"@heir//tests:test_utilities",
],
default_tags = [
Expand Down
10 changes: 8 additions & 2 deletions tests/tfhe_rust_bool/end_to_end_fpga/README.md
Expand Up @@ -19,10 +19,16 @@ Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`,
if you overrode the default option when installing Cargo.

```bash
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" \
| xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@"
```

Manually generate the Rust code fort the CGGI lowering:
```bash
bazel run //tools:heir-opt -- -cse --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse $(pwd)/tests/cggi_to_tfhe_rust_bool/add_bool.mlir | bazel run //tools:heir-translate -- --emit-tfhe-rust-bool
```


The `manual` tag is added to the targets in this directory to ensure that they
are not run when someone runs a glob test like `bazel test //...`.

Expand All @@ -41,4 +47,4 @@ If you don't do this correctly, you will see an error like this:
# | Caused by:
# | Read-only file system (os error 30)
# `-----------------------------
```
```
9 changes: 5 additions & 4 deletions tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
Expand Up @@ -72,12 +72,13 @@ fn main() {
let ct_1 = encrypt(flags.input1.into(), &client_key);
let ct_2 = encrypt(flags.input2.into(), &client_key);

let ct_1= ct_1.iter().collect();
let ct_2= ct_2.iter().collect();

let t = Instant::now();
let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);
let run = t.elapsed().as_millis();

// println!("{:?}", run);

let output = decrypt(&result, &client_key);

println!("{:08b}", output);
}
}
77 changes: 77 additions & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/test_cggi_add_bool.mlir
@@ -0,0 +1,77 @@
// This test ensures the testing harness is working properly with minimal codegen.

// RUN: heir-opt --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | heir-translate --emit-tfhe-rust-bool > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s

#encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 1>
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>

// CHECK: 01000000
func.func @fn_under_test(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
%true = arith.constant true
%false = arith.constant false
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty>
%extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
%extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>
%extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty>
%extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty>
%extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty>
%extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty>
%extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty>
%extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!ct_ty>
%extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!ct_ty>
%extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!ct_ty>
%extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!ct_ty>
%extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!ct_ty>
%extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!ct_ty>
%extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!ct_ty>
%extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!ct_ty>
%ha_s = cggi.xor %extracted_00, %extracted_10 : !ct_ty
%ha_c = cggi.and %extracted_00, %extracted_10 : !ct_ty
%fa0_1 = cggi.xor %extracted_01, %extracted_11 : !ct_ty
%fa0_2 = cggi.and %extracted_01, %extracted_11 : !ct_ty
%fa0_3 = cggi.and %fa0_1, %ha_c : !ct_ty
%fa0_s = cggi.xor %fa0_1, %ha_c : !ct_ty
%fa0_c = cggi.xor %fa0_2, %fa0_3 : !ct_ty
%fa1_1 = cggi.xor %extracted_02, %extracted_12 : !ct_ty
%fa1_2 = cggi.and %extracted_02, %extracted_12 : !ct_ty
%fa1_3 = cggi.and %fa1_1, %fa0_c : !ct_ty
%fa1_s = cggi.xor %fa1_1, %fa0_c : !ct_ty
%fa1_c = cggi.xor %fa1_2, %fa1_3 : !ct_ty
%fa2_1 = cggi.xor %extracted_03, %extracted_13 : !ct_ty
%fa2_2 = cggi.and %extracted_03, %extracted_13 : !ct_ty
%fa2_3 = cggi.and %fa2_1, %fa1_c : !ct_ty
%fa2_s = cggi.xor %fa2_1, %fa1_c : !ct_ty
%fa2_c = cggi.xor %fa2_2, %fa2_3 : !ct_ty
%fa3_1 = cggi.xor %extracted_04, %extracted_14 : !ct_ty
%fa3_2 = cggi.and %extracted_04, %extracted_14 : !ct_ty
%fa3_3 = cggi.and %fa3_1, %fa2_c : !ct_ty
%fa3_s = cggi.xor %fa3_1, %fa2_c : !ct_ty
%fa3_c = cggi.xor %fa3_2, %fa3_3 : !ct_ty
%fa4_1 = cggi.xor %extracted_05, %extracted_15 : !ct_ty
%fa4_2 = cggi.and %extracted_05, %extracted_15 : !ct_ty
%fa4_3 = cggi.and %fa4_1, %fa3_c : !ct_ty
%fa4_s = cggi.xor %fa4_1, %fa3_c : !ct_ty
%fa4_c = cggi.xor %fa4_2, %fa4_3 : !ct_ty
%fa5_1 = cggi.xor %extracted_06, %extracted_16 : !ct_ty
%fa5_2 = cggi.and %extracted_06, %extracted_16 : !ct_ty
%fa5_3 = cggi.and %fa5_1, %fa4_c : !ct_ty
%fa5_s = cggi.xor %fa5_1, %fa4_c : !ct_ty
%fa5_c = cggi.xor %fa5_2, %fa5_3 : !ct_ty
%fa6_1 = cggi.xor %extracted_07, %extracted_17 : !ct_ty
%fa6_2 = cggi.and %extracted_07, %extracted_17 : !ct_ty
%fa6_3 = cggi.and %fa6_1, %fa5_c : !ct_ty
%fa6_s = cggi.xor %fa6_1, %fa5_c : !ct_ty
%fa6_c = cggi.xor %fa6_2, %fa6_3 : !ct_ty
%from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!ct_ty>
return %from_elements : tensor<8x!ct_ty>
}

0 comments on commit ad9b7b1

Please sign in to comment.