Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tfhe-rs-bool-emitter: cggi to tfhe-rs-bool fpga ready code #569

Merged
merged 1 commit into from Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 7 additions & 28 deletions include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td
Expand Up @@ -24,16 +24,20 @@ def CreateTrivialOp : TfheRustBool_Op<"create_trivial", [Pure]> {
let results = (outs TfheRustBool_Encrypted:$output);
}

// --- Operations for a gate-bootstrapping API of a CGGI library ---
def TfheRustBoolLike : TypeOrContainer<TfheRustBool_Encrypted, "eb-like">;


class TfheRustBool_BinaryGateOp<string mnemonic>
: TfheRustBool_Op<mnemonic, [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins TfheRustBool_ServerKey:$serverKey,
TfheRustBool_Encrypted:$lhs,
TfheRustBool_Encrypted:$rhs
TfheRustBoolLike:$lhs,
TfheRustBoolLike:$rhs
);
let results = (outs TfheRustBool_Encrypted:$output);
let results = (outs TfheRustBoolLike:$output);
}

def AndOp : TfheRustBool_BinaryGateOp<"and"> { let summary = "Logical AND of two TFHE-rs Bool ciphertexts."; }
Expand All @@ -43,31 +47,6 @@ def NorOp : TfheRustBool_BinaryGateOp<"nor"> { let summary = "Logical NOR of t
def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two TFHE-rs Bool ciphertexts."; }
def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; }


def AndPackedOp : TfheRustBool_Op<"and_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def XorPackedOp : TfheRustBool_Op<"xor_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def NotOp : TfheRustBool_Op<"not", [
Pure,
AllTypesMatch<["input", "output"]>
Expand Down
3 changes: 1 addition & 2 deletions include/Target/TfheRustBool/TfheRustBoolEmitter.h
Expand Up @@ -59,8 +59,6 @@ class TfheRustBoolEmitter {
LogicalResult printOperation(XorOp op);
LogicalResult printOperation(XnorOp op);

LogicalResult printOperation(AndPackedOp op);

// Helpers for above
LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks,
::mlir::ValueRange nonSksOperands,
Expand All @@ -72,6 +70,7 @@ class TfheRustBoolEmitter {
FailureOr<std::string> convertType(Type type);

void emitAssignPrefix(::mlir::Value result);
void emitReferenceConversion(::mlir::Value value);
};

} // namespace tfhe_rust_bool
Expand Down
117 changes: 84 additions & 33 deletions lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp
Expand Up @@ -64,7 +64,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) {
// Arith ops
.Case<arith::ConstantOp>([&](auto op) { return printOperation(op); })
// TfheRustBool ops
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp, AndPackedOp>(
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp>(
[&](auto op) { return printOperation(op); })
// Tensor ops
.Case<tensor::ExtractOp, tensor::FromElementsOp>(
Expand Down Expand Up @@ -144,11 +144,14 @@ LogicalResult TfheRustBoolEmitter::printOperation(func::FuncOp funcOp) {

LogicalResult TfheRustBoolEmitter::printOperation(func::ReturnOp op) {
std::function<std::string(Value)> valueOrClonedValue = [&](Value value) {
auto cloneStr = "";
auto suffix = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
suffix = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
if (isa<tensor::FromElementsOp>(value.getDefiningOp())) {
suffix = ".into_iter().cloned().collect()";
}
return variableNames->getNameForValue(value) + suffix;
};

if (op.getNumOperands() == 1) {
Expand All @@ -165,29 +168,77 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) {
os << "let " << variableNames->getNameForValue(result) << " = ";
}

void TfheRustBoolEmitter::emitReferenceConversion(Value value) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you also add a quick assertion here that the value element type is a EncryptedUint (or maybe add this to the doc-string in the header file that it's expected to run on shaped types of encrypted values?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tfhe-rs-bool emitter only produces EncryptedBools
I added an additional check

auto tensorType = dyn_cast<TensorType>(value.getType());

if (isa<EncryptedBoolType>(tensorType.getElementType())) {
auto varName = variableNames->getNameForValue(value);

os << "let " << varName << "_ref = " << varName << ".clone();\n";
os << "let " << varName << "_ref: Vec<&Ciphertext> = " << varName
<< ".iter().collect();\n";
}
}

LogicalResult TfheRustBoolEmitter::printSksMethod(
::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands,
std::string_view op, SmallVector<std::string> operandTypes) {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
if (isa<TensorType>(nonSksOperands[0].getType())) {
auto *opParent = nonSksOperands[0].getDefiningOp();

if (!opParent) {
for (auto nonSksOperand : nonSksOperands) {
emitReferenceConversion(nonSksOperand);
}
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
emitAssignPrefix(result);

os << variableNames->getNameForValue(sks) << "." << op << "_packed(";
os << commaSeparatedValues(
{nonSksOperands[0], nonSksOperands[1]}, [&](Value value) {
auto *prefix = "&";
auto suffix = "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *opParent = value.getDefiningOp();
if (opParent) {
if (!isa<tensor::FromElementsOp>(value.getDefiningOp()) &
!isa<tensor::ExtractOp>(opParent))
prefix = "";

} else {
prefix = "&";
suffix = "_ref";
}

return prefix + variableNames->getNameForValue(value) + suffix;
});
os << ");\n";
return success();

} else {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
WoutLegiest marked this conversation as resolved.
Show resolved Hide resolved
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
}
}

LogicalResult TfheRustBoolEmitter::printOperation(CreateTrivialOp op) {
Expand Down Expand Up @@ -215,6 +266,7 @@ LogicalResult TfheRustBoolEmitter::printOperation(arith::ConstantOp op) {
return success();
}

// Produces a &Ciphertext
LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
// We assume here that the indices are SSA values (not integer attributes).
emitAssignPrefix(op.getResult());
Expand All @@ -226,15 +278,18 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
return success();
}

// Need to produce a Vec<&Ciphertext>
LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) {
emitAssignPrefix(op.getResult());
os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) {
// Check if block argument, if so, clone.
auto cloneStr = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
auto cloneStr = isa<BlockArgument>(value) ? ".clone()" : "";
// Get the name of defining operation its dialect
auto tfhe_op =
value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool";
auto prefix = tfhe_op ? "&" : "";
return std::string(prefix) + variableNames->getNameForValue(value) +
cloneStr;
}) << "];\n";
return success();
}
Expand Down Expand Up @@ -269,11 +324,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) {
{op.getLhs(), op.getRhs()}, "xnor");
}

LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) {
return printSksMethod(op.getResult(), op.getServerKey(),
{op.getLhs(), op.getRhs()}, "and_packed");
}

FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// Note: these are probably not the right type names to use exactly, and they
// will need to chance to the right values once we try to compile it against
Expand All @@ -283,7 +333,8 @@ FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// FIXME: why can't both types be FailureOr<std::string>?
auto elementTy = convertType(shapedType.getElementType());
if (failed(elementTy)) return failure();
return std::string("Vec<" + elementTy.value() + ">");

return std::string(std::string("Vec<") + elementTy.value() + ">");
}
return llvm::TypeSwitch<Type &, FailureOr<std::string>>(type)
.Case<EncryptedBoolType>(
Expand Down
1 change: 1 addition & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/BUILD
Expand Up @@ -12,6 +12,7 @@ glob_lit_tests(
data = [
"Cargo.toml",
"src/main.rs",
"tfhe-rs",
"@heir//tests:test_utilities",
],
default_tags = [
Expand Down
8 changes: 7 additions & 1 deletion tests/tfhe_rust_bool/end_to_end_fpga/README.md
Expand Up @@ -19,10 +19,16 @@ Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`,
if you overrode the default option when installing Cargo.

```bash
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" \
| xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@"
```

Manually generate the Rust code for the CGGI lowering:
```bash
bazel run //tools:heir-opt -- -cse --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse $(pwd)/tests/cggi_to_tfhe_rust_bool/add_bool.mlir | bazel run //tools:heir-translate -- --emit-tfhe-rust-bool
```


The `manual` tag is added to the targets in this directory to ensure that they
are not run when someone runs a glob test like `bazel test //...`.

Expand Down
12 changes: 9 additions & 3 deletions tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
@@ -1,9 +1,11 @@
#[allow(unused_imports)]
use std::time::Instant;

use clap::Parser;
use tfhe::boolean::prelude::*;

use tfhe::boolean::engine::BooleanEngine;
use tfhe::boolean::prelude::*;
use std::time::Instant;

#[cfg(feature = "fpga")]
use tfhe::boolean::server_key::FpgaGates;
Expand Down Expand Up @@ -72,11 +74,15 @@ fn main() {
let ct_1 = encrypt(flags.input1.into(), &client_key);
let ct_2 = encrypt(flags.input2.into(), &client_key);

let ct_1= ct_1.iter().collect();
let ct_2= ct_2.iter().collect();
// timing placeholders to quickly obtain the measurements of the generated function
// let t = Instant::now();

let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);

// let run = t.elapsed().as_millis();

// println!("{:?}", run);

let output = decrypt(&result, &client_key);

println!("{:08b}", output);
Expand Down
8 changes: 2 additions & 6 deletions tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir
@@ -1,14 +1,10 @@
// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_add_one -- 1 1 | FileCheck %s
// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s

!bsks = !tfhe_rust_bool.server_key
!eb = !tfhe_rust_bool.eb

// CHECK-LABEL: pub fn fn_under_test(
// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey,
// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec<Ciphertext>,
// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec<Ciphertext>,
// CHECK-NEXT: ) -> Vec<Ciphertext> {
// CHECK: 01000000
func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> {
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
Expand Down