Skip to content

Commit

Permalink
fix: fix issues with void return in rust emission
Browse files Browse the repository at this point in the history
Signed-off-by: Asra <asraa@google.com>
  • Loading branch information
asraa committed Mar 5, 2024
1 parent e574416 commit 09f256a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 23 deletions.
57 changes: 34 additions & 23 deletions lib/Target/TfheRust/TfheRustEmitter.cpp
Expand Up @@ -211,38 +211,44 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) {
}
}
os.unindent();
os << ") -> ";
os << ")";

if (serverKeyArg_.empty()) {
return funcOp.emitWarning() << "expected server key function argument to "
"create default ciphertexts";
}

if (funcOp.getNumResults() == 1) {
Type result = funcOp.getResultTypes()[0];
if (failed(emitType(result))) {
return funcOp.emitOpError() << "Failed to emit tfhe-rs type " << result;
if (funcOp.getNumResults() > 0) {
os << " -> ";
if (funcOp.getNumResults() == 1) {
Type result = funcOp.getResultTypes()[0];
if (failed(emitType(result))) {
return funcOp.emitOpError() << "Failed to emit tfhe-rs type " << result;
}
} else {
auto result = commaSeparatedTypes(
funcOp.getResultTypes(), [&](Type type) -> FailureOr<std::string> {
auto result = convertType(type);
if (failed(result)) {
return funcOp.emitOpError()
<< "Failed to emit tfhe-rs type " << type;
}
return result;
});
os << "(" << result.value() << ")";
}
} else {
auto result = commaSeparatedTypes(
funcOp.getResultTypes(), [&](Type type) -> FailureOr<std::string> {
auto result = convertType(type);
if (failed(result)) {
return funcOp.emitOpError()
<< "Failed to emit tfhe-rs type " << type;
}
return result;
});
os << "(" << result.value() << ")";
}

os << " {\n";
os.indent();

// Create a global temp_nodes hashmap for any created SSA values.
// TODO(#462): Insert block argument that are encrypted ints into temp_nodes.
os << "let mut temp_nodes : HashMap<usize, Ciphertext> = HashMap::new();\n";
os << "let mut luts : HashMap<&str, LookupTableOwned> = HashMap::new();\n";
// TODO(#462): Insert block argument that are encrypted ints into
// temp_nodes.
os << "let mut temp_nodes : HashMap<usize, Ciphertext> = "
"HashMap::new();\n";
os << "let mut luts : HashMap<&str, LookupTableOwned> = "
"HashMap::new();\n";

os << kRunLevelDefn << "\n";

Expand Down Expand Up @@ -285,6 +291,10 @@ LogicalResult TfheRustEmitter::printOperation(func::ReturnOp op) {
return variableNames->getNameForValue(value);
};

if (op.getNumOperands() == 0) {
return success();
}

if (op.getNumOperands() == 1) {
os << valueOrClonedValue(op.getOperands()[0]) << "\n";
return success();
Expand Down Expand Up @@ -415,9 +425,9 @@ LogicalResult TfheRustEmitter::printOperation(affine::AffineForOp forOp) {
variableNames->getIntForValue(op->getResult(0)),
commaSeparatedValues(
getCiphertextOperands(op->getOperands()), [&](Value value) {
// TODO(#462): This assumes that all ciphertexts are loaded
// into temp_nodes. Currently, block arguments are not
// supported.
// TODO(#462): This assumes that all ciphertexts are
// loaded into temp_nodes. Currently, block arguments are
// not supported.
return "Tv(" +
std::to_string(variableNames->getIntForValue(value)) +
")";
Expand Down Expand Up @@ -552,7 +562,8 @@ LogicalResult TfheRustEmitter::printOperation(::mlir::arith::TruncIOp op) {
}

LogicalResult TfheRustEmitter::printOperation(tensor::ExtractOp op) {
// We assume here that the indices are SSA values (not integer attributes).
// We assume here that the indices are SSA values (not integer
// attributes).
emitAssignPrefix(op.getResult());
os << "&" << variableNames->getNameForValue(op.getTensor()) << "["
<< commaSeparatedValues(
Expand Down
9 changes: 9 additions & 0 deletions lib/Target/Utils.cpp
Expand Up @@ -16,13 +16,19 @@ namespace heir {

std::string commaSeparatedValues(
ValueRange values, std::function<std::string(Value)> valueToString) {
if (values.empty()) {
return std::string();
}
return std::accumulate(
std::next(values.begin()), values.end(), valueToString(values[0]),
[&](std::string a, Value b) { return a + ", " + valueToString(b); });
}

FailureOr<std::string> commaSeparatedTypes(
TypeRange types, std::function<FailureOr<std::string>(Type)> typeToString) {
if (types.empty()) {
return std::string();
}
return std::accumulate(
std::next(types.begin()), types.end(), typeToString(types[0]),
[&](FailureOr<std::string> a, Type b) -> FailureOr<std::string> {
Expand All @@ -36,6 +42,9 @@ FailureOr<std::string> commaSeparatedTypes(

std::string bracketEnclosedValues(
ValueRange values, std::function<std::string(Value)> valueToString) {
if (values.empty()) {
return std::string();
}
return std::accumulate(
std::next(values.begin()), values.end(),
"[" + valueToString(values[0]) + "]",
Expand Down
4 changes: 4 additions & 0 deletions tests/tfhe_rust/ops.mlir
@@ -1,10 +1,12 @@
// RUN: heir-opt %s | FileCheck %s
// RUN: heir-translate --emit-tfhe-rust %s | FileCheck --check-prefix=RS %s

// This simply tests for syntax.
!sks = !tfhe_rust.server_key

module {
// CHECK-LABEL: func @test_create_trivial
// RS-LABEL: pub fn test_create_trivial
func.func @test_create_trivial(%sks : !sks) {
%0 = arith.constant 1 : i8
%1 = arith.constant 1 : i3
Expand All @@ -16,6 +18,7 @@ module {
}

// CHECK-LABEL: func @test_bitand
// RS-LABEL: pub fn test_bitand
func.func @test_bitand(%sks : !sks) {
%0 = arith.constant 1 : i1
%1 = arith.constant 1 : i1
Expand All @@ -28,6 +31,7 @@ module {


// CHECK-LABEL: func @test_apply_lookup_table
// RS-LABEL: pub fn test_apply_lookup_table
func.func @test_apply_lookup_table(%sks : !sks, %lut: !tfhe_rust.lookup_table) {
%0 = arith.constant 1 : i3
%1 = arith.constant 2 : i3
Expand Down

0 comments on commit 09f256a

Please sign in to comment.