Skip to content

Commit

Permalink
Merge pull request #569 from WoutLegiest:try3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 620288502
  • Loading branch information
Copybara-Service committed Mar 29, 2024
2 parents cdc3dee + 3d84825 commit ef52e8d
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 77 deletions.
35 changes: 7 additions & 28 deletions include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td
Expand Up @@ -24,16 +24,20 @@ def CreateTrivialOp : TfheRustBool_Op<"create_trivial", [Pure]> {
let results = (outs TfheRustBool_Encrypted:$output);
}

// --- Operations for a gate-bootstrapping API of a CGGI library ---
def TfheRustBoolLike : TypeOrContainer<TfheRustBool_Encrypted, "eb-like">;


class TfheRustBool_BinaryGateOp<string mnemonic>
: TfheRustBool_Op<mnemonic, [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins TfheRustBool_ServerKey:$serverKey,
TfheRustBool_Encrypted:$lhs,
TfheRustBool_Encrypted:$rhs
TfheRustBoolLike:$lhs,
TfheRustBoolLike:$rhs
);
let results = (outs TfheRustBool_Encrypted:$output);
let results = (outs TfheRustBoolLike:$output);
}

def AndOp : TfheRustBool_BinaryGateOp<"and"> { let summary = "Logical AND of two TFHE-rs Bool ciphertexts."; }
Expand All @@ -43,31 +47,6 @@ def NorOp : TfheRustBool_BinaryGateOp<"nor"> { let summary = "Logical NOR of t
def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two TFHE-rs Bool ciphertexts."; }
def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; }


def AndPackedOp : TfheRustBool_Op<"and_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def XorPackedOp : TfheRustBool_Op<"xor_packed", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRustBool_ServerKey:$serverKey,
TensorOf<[TfheRustBool_Encrypted]>:$lhs,
TensorOf<[TfheRustBool_Encrypted]>:$rhs
);
let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output);
}

def NotOp : TfheRustBool_Op<"not", [
Pure,
AllTypesMatch<["input", "output"]>
Expand Down
3 changes: 1 addition & 2 deletions include/Target/TfheRustBool/TfheRustBoolEmitter.h
Expand Up @@ -59,8 +59,6 @@ class TfheRustBoolEmitter {
LogicalResult printOperation(XorOp op);
LogicalResult printOperation(XnorOp op);

LogicalResult printOperation(AndPackedOp op);

// Helpers for above
LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks,
::mlir::ValueRange nonSksOperands,
Expand All @@ -72,6 +70,7 @@ class TfheRustBoolEmitter {
FailureOr<std::string> convertType(Type type);

void emitAssignPrefix(::mlir::Value result);
void emitReferenceConversion(::mlir::Value value);
};

} // namespace tfhe_rust_bool
Expand Down
117 changes: 84 additions & 33 deletions lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp
Expand Up @@ -64,7 +64,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) {
// Arith ops
.Case<arith::ConstantOp>([&](auto op) { return printOperation(op); })
// TfheRustBool ops
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp, AndPackedOp>(
.Case<AndOp, NandOp, OrOp, NorOp, XorOp, XnorOp>(
[&](auto op) { return printOperation(op); })
// Tensor ops
.Case<tensor::ExtractOp, tensor::FromElementsOp>(
Expand Down Expand Up @@ -144,11 +144,14 @@ LogicalResult TfheRustBoolEmitter::printOperation(func::FuncOp funcOp) {

LogicalResult TfheRustBoolEmitter::printOperation(func::ReturnOp op) {
std::function<std::string(Value)> valueOrClonedValue = [&](Value value) {
auto cloneStr = "";
auto suffix = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
suffix = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
if (isa<tensor::FromElementsOp>(value.getDefiningOp())) {
suffix = ".into_iter().cloned().collect()";
}
return variableNames->getNameForValue(value) + suffix;
};

if (op.getNumOperands() == 1) {
Expand All @@ -165,29 +168,77 @@ void TfheRustBoolEmitter::emitAssignPrefix(Value result) {
os << "let " << variableNames->getNameForValue(result) << " = ";
}

void TfheRustBoolEmitter::emitReferenceConversion(Value value) {
auto tensorType = dyn_cast<TensorType>(value.getType());

if (isa<EncryptedBoolType>(tensorType.getElementType())) {
auto varName = variableNames->getNameForValue(value);

os << "let " << varName << "_ref = " << varName << ".clone();\n";
os << "let " << varName << "_ref: Vec<&Ciphertext> = " << varName
<< ".iter().collect();\n";
}
}

LogicalResult TfheRustBoolEmitter::printSksMethod(
::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands,
std::string_view op, SmallVector<std::string> operandTypes) {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
if (isa<TensorType>(nonSksOperands[0].getType())) {
auto *opParent = nonSksOperands[0].getDefiningOp();

if (!opParent) {
for (auto nonSksOperand : nonSksOperands) {
emitReferenceConversion(nonSksOperand);
}
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
emitAssignPrefix(result);

os << variableNames->getNameForValue(sks) << "." << op << "_packed(";
os << commaSeparatedValues(
{nonSksOperands[0], nonSksOperands[1]}, [&](Value value) {
auto *prefix = "&";
auto suffix = "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *opParent = value.getDefiningOp();
if (opParent) {
if (!isa<tensor::FromElementsOp>(value.getDefiningOp()) &
!isa<tensor::ExtractOp>(opParent))
prefix = "";

} else {
prefix = "&";
suffix = "_ref";
}

return prefix + variableNames->getNameForValue(value) + suffix;
});
os << ");\n";
return success();

} else {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
// First check if a DefiningOp exists
// if not: comes from function definition
mlir::Operation *op = value.getDefiningOp();
if (op) {
prefix = isa<tensor::ExtractOp>(op) ? "" : prefix;
} else {
prefix = "";
}

return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
}
}

LogicalResult TfheRustBoolEmitter::printOperation(CreateTrivialOp op) {
Expand Down Expand Up @@ -215,6 +266,7 @@ LogicalResult TfheRustBoolEmitter::printOperation(arith::ConstantOp op) {
return success();
}

// Produces a &Ciphertext
LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
// We assume here that the indices are SSA values (not integer attributes).
emitAssignPrefix(op.getResult());
Expand All @@ -226,15 +278,18 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::ExtractOp op) {
return success();
}

// Need to produce a Vec<&Ciphertext>
LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) {
emitAssignPrefix(op.getResult());
os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) {
// Check if block argument, if so, clone.
auto cloneStr = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
auto cloneStr = isa<BlockArgument>(value) ? ".clone()" : "";
// Get the name of defining operation its dialect
auto tfhe_op =
value.getDefiningOp()->getDialect()->getNamespace() == "tfhe_rust_bool";
auto prefix = tfhe_op ? "&" : "";
return std::string(prefix) + variableNames->getNameForValue(value) +
cloneStr;
}) << "];\n";
return success();
}
Expand Down Expand Up @@ -269,11 +324,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) {
{op.getLhs(), op.getRhs()}, "xnor");
}

LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) {
return printSksMethod(op.getResult(), op.getServerKey(),
{op.getLhs(), op.getRhs()}, "and_packed");
}

FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// Note: these are probably not the right type names to use exactly, and they
// will need to chance to the right values once we try to compile it against
Expand All @@ -283,7 +333,8 @@ FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// FIXME: why can't both types be FailureOr<std::string>?
auto elementTy = convertType(shapedType.getElementType());
if (failed(elementTy)) return failure();
return std::string("Vec<" + elementTy.value() + ">");

return std::string(std::string("Vec<") + elementTy.value() + ">");
}
return llvm::TypeSwitch<Type &, FailureOr<std::string>>(type)
.Case<EncryptedBoolType>(
Expand Down
1 change: 1 addition & 0 deletions tests/tfhe_rust_bool/end_to_end_fpga/BUILD
Expand Up @@ -12,6 +12,7 @@ glob_lit_tests(
data = [
"Cargo.toml",
"src/main.rs",
"tfhe-rs",
"@heir//tests:test_utilities",
],
default_tags = [
Expand Down
8 changes: 7 additions & 1 deletion tests/tfhe_rust_bool/end_to_end_fpga/README.md
Expand Up @@ -19,10 +19,16 @@ Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`,
if you overrode the default option when installing Cargo.

```bash
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end/...)" \
bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" \
| xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@"
```

Manually generate the Rust code for the CGGI lowering:
```bash
bazel run //tools:heir-opt -- -cse --straight-line-vectorize --cggi-to-tfhe-rust-bool -cse $(pwd)/tests/cggi_to_tfhe_rust_bool/add_bool.mlir | bazel run //tools:heir-translate -- --emit-tfhe-rust-bool
```


The `manual` tag is added to the targets in this directory to ensure that they
are not run when someone runs a glob test like `bazel test //...`.

Expand Down
12 changes: 9 additions & 3 deletions tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
@@ -1,9 +1,11 @@
#[allow(unused_imports)]
use std::time::Instant;

use clap::Parser;
use tfhe::boolean::prelude::*;

use tfhe::boolean::engine::BooleanEngine;
use tfhe::boolean::prelude::*;
use std::time::Instant;

#[cfg(feature = "fpga")]
use tfhe::boolean::server_key::FpgaGates;
Expand Down Expand Up @@ -72,11 +74,15 @@ fn main() {
let ct_1 = encrypt(flags.input1.into(), &client_key);
let ct_2 = encrypt(flags.input2.into(), &client_key);

let ct_1= ct_1.iter().collect();
let ct_2= ct_2.iter().collect();
// timing placeholders to quickly obtain the measurements of the generated function
// let t = Instant::now();

let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);

// let run = t.elapsed().as_millis();

// println!("{:?}", run);

let output = decrypt(&result, &client_key);

println!("{:08b}", output);
Expand Down
8 changes: 2 additions & 6 deletions tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir
@@ -1,14 +1,10 @@
// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_add_one -- 1 1 | FileCheck %s
// RUN: cargo run --release --manifest-path %S/Cargo.toml -- 1 1 | FileCheck %s

!bsks = !tfhe_rust_bool.server_key
!eb = !tfhe_rust_bool.eb

// CHECK-LABEL: pub fn fn_under_test(
// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey,
// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec<Ciphertext>,
// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec<Ciphertext>,
// CHECK-NEXT: ) -> Vec<Ciphertext> {
// CHECK: 01000000
func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> {
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
Expand Down

0 comments on commit ef52e8d

Please sign in to comment.