Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix issues with void return in rust emission #480

Merged
merged 1 commit into from Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 34 additions & 23 deletions lib/Target/TfheRust/TfheRustEmitter.cpp
Expand Up @@ -211,38 +211,44 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) {
}
}
os.unindent();
os << ") -> ";
os << ")";

if (serverKeyArg_.empty()) {
return funcOp.emitWarning() << "expected server key function argument to "
"create default ciphertexts";
}

if (funcOp.getNumResults() == 1) {
Type result = funcOp.getResultTypes()[0];
if (failed(emitType(result))) {
return funcOp.emitOpError() << "Failed to emit tfhe-rs type " << result;
if (funcOp.getNumResults() > 0) {
os << " -> ";
if (funcOp.getNumResults() == 1) {
Type result = funcOp.getResultTypes()[0];
if (failed(emitType(result))) {
return funcOp.emitOpError() << "Failed to emit tfhe-rs type " << result;
}
} else {
auto result = commaSeparatedTypes(
funcOp.getResultTypes(), [&](Type type) -> FailureOr<std::string> {
auto result = convertType(type);
if (failed(result)) {
return funcOp.emitOpError()
<< "Failed to emit tfhe-rs type " << type;
}
return result;
});
os << "(" << result.value() << ")";
}
} else {
auto result = commaSeparatedTypes(
funcOp.getResultTypes(), [&](Type type) -> FailureOr<std::string> {
auto result = convertType(type);
if (failed(result)) {
return funcOp.emitOpError()
<< "Failed to emit tfhe-rs type " << type;
}
return result;
});
os << "(" << result.value() << ")";
}

os << " {\n";
os.indent();

// Create a global temp_nodes hashmap for any created SSA values.
// TODO(#462): Insert block argument that are encrypted ints into temp_nodes.
os << "let mut temp_nodes : HashMap<usize, Ciphertext> = HashMap::new();\n";
os << "let mut luts : HashMap<&str, LookupTableOwned> = HashMap::new();\n";
// TODO(#462): Insert block argument that are encrypted ints into
// temp_nodes.
os << "let mut temp_nodes : HashMap<usize, Ciphertext> = "
"HashMap::new();\n";
os << "let mut luts : HashMap<&str, LookupTableOwned> = "
"HashMap::new();\n";

os << kRunLevelDefn << "\n";

Expand Down Expand Up @@ -285,6 +291,10 @@ LogicalResult TfheRustEmitter::printOperation(func::ReturnOp op) {
return variableNames->getNameForValue(value);
};

if (op.getNumOperands() == 0) {
return success();
}

if (op.getNumOperands() == 1) {
os << valueOrClonedValue(op.getOperands()[0]) << "\n";
return success();
Expand Down Expand Up @@ -415,9 +425,9 @@ LogicalResult TfheRustEmitter::printOperation(affine::AffineForOp forOp) {
variableNames->getIntForValue(op->getResult(0)),
commaSeparatedValues(
getCiphertextOperands(op->getOperands()), [&](Value value) {
// TODO(#462): This assumes that all ciphertexts are loaded
// into temp_nodes. Currently, block arguments are not
// supported.
// TODO(#462): This assumes that all ciphertexts are
// loaded into temp_nodes. Currently, block arguments are
// not supported.
return "Tv(" +
std::to_string(variableNames->getIntForValue(value)) +
")";
Expand Down Expand Up @@ -552,7 +562,8 @@ LogicalResult TfheRustEmitter::printOperation(::mlir::arith::TruncIOp op) {
}

LogicalResult TfheRustEmitter::printOperation(tensor::ExtractOp op) {
// We assume here that the indices are SSA values (not integer attributes).
// We assume here that the indices are SSA values (not integer
// attributes).
emitAssignPrefix(op.getResult());
os << "&" << variableNames->getNameForValue(op.getTensor()) << "["
<< commaSeparatedValues(
Expand Down
9 changes: 9 additions & 0 deletions lib/Target/Utils.cpp
Expand Up @@ -16,13 +16,19 @@ namespace heir {

std::string commaSeparatedValues(
ValueRange values, std::function<std::string(Value)> valueToString) {
if (values.empty()) {
return std::string();
}
return std::accumulate(
std::next(values.begin()), values.end(), valueToString(values[0]),
[&](std::string a, Value b) { return a + ", " + valueToString(b); });
}

FailureOr<std::string> commaSeparatedTypes(
TypeRange types, std::function<FailureOr<std::string>(Type)> typeToString) {
if (types.empty()) {
return std::string();
}
return std::accumulate(
std::next(types.begin()), types.end(), typeToString(types[0]),
[&](FailureOr<std::string> a, Type b) -> FailureOr<std::string> {
Expand All @@ -36,6 +42,9 @@ FailureOr<std::string> commaSeparatedTypes(

std::string bracketEnclosedValues(
ValueRange values, std::function<std::string(Value)> valueToString) {
if (values.empty()) {
return std::string();
}
return std::accumulate(
std::next(values.begin()), values.end(),
"[" + valueToString(values[0]) + "]",
Expand Down
4 changes: 4 additions & 0 deletions tests/tfhe_rust/ops.mlir
@@ -1,10 +1,12 @@
// RUN: heir-opt %s | FileCheck %s
// RUN: heir-translate --emit-tfhe-rust %s | FileCheck --check-prefix=RS %s

// This simply tests for syntax.
!sks = !tfhe_rust.server_key

module {
// CHECK-LABEL: func @test_create_trivial
// RS-LABEL: pub fn test_create_trivial
func.func @test_create_trivial(%sks : !sks) {
%0 = arith.constant 1 : i8
%1 = arith.constant 1 : i3
Expand All @@ -16,6 +18,7 @@ module {
}

// CHECK-LABEL: func @test_bitand
// RS-LABEL: pub fn test_bitand
func.func @test_bitand(%sks : !sks) {
%0 = arith.constant 1 : i1
%1 = arith.constant 1 : i1
Expand All @@ -28,6 +31,7 @@ module {


// CHECK-LABEL: func @test_apply_lookup_table
// RS-LABEL: pub fn test_apply_lookup_table
func.func @test_apply_lookup_table(%sks : !sks, %lut: !tfhe_rust.lookup_table) {
%0 = arith.constant 1 : i3
%1 = arith.constant 2 : i3
Expand Down