From 09f256a0b368ac997beb536f9bd2f6aed9d8f966 Mon Sep 17 00:00:00 2001 From: Asra Date: Tue, 5 Mar 2024 17:32:23 +0000 Subject: [PATCH] fix: fix issues with void return in rust emission Signed-off-by: Asra --- lib/Target/TfheRust/TfheRustEmitter.cpp | 57 +++++++++++++++---------- lib/Target/Utils.cpp | 9 ++++ tests/tfhe_rust/ops.mlir | 4 ++ 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/lib/Target/TfheRust/TfheRustEmitter.cpp b/lib/Target/TfheRust/TfheRustEmitter.cpp index dcf9314a1..206b764a4 100644 --- a/lib/Target/TfheRust/TfheRustEmitter.cpp +++ b/lib/Target/TfheRust/TfheRustEmitter.cpp @@ -211,38 +211,44 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) { } } os.unindent(); - os << ") -> "; + os << ")"; if (serverKeyArg_.empty()) { return funcOp.emitWarning() << "expected server key function argument to " "create default ciphertexts"; } - if (funcOp.getNumResults() == 1) { - Type result = funcOp.getResultTypes()[0]; - if (failed(emitType(result))) { - return funcOp.emitOpError() << "Failed to emit tfhe-rs type " << result; + if (funcOp.getNumResults() > 0) { + os << " -> "; + if (funcOp.getNumResults() == 1) { + Type result = funcOp.getResultTypes()[0]; + if (failed(emitType(result))) { + return funcOp.emitOpError() << "Failed to emit tfhe-rs type " << result; + } + } else { + auto result = commaSeparatedTypes( + funcOp.getResultTypes(), [&](Type type) -> FailureOr { + auto result = convertType(type); + if (failed(result)) { + return funcOp.emitOpError() + << "Failed to emit tfhe-rs type " << type; + } + return result; + }); + os << "(" << result.value() << ")"; } - } else { - auto result = commaSeparatedTypes( - funcOp.getResultTypes(), [&](Type type) -> FailureOr { - auto result = convertType(type); - if (failed(result)) { - return funcOp.emitOpError() - << "Failed to emit tfhe-rs type " << type; - } - return result; - }); - os << "(" << result.value() << ")"; } os << " {\n"; os.indent(); // Create a global temp_nodes hashmap for any created SSA values. - // TODO(#462): Insert block argument that are encrypted ints into temp_nodes. - os << "let mut temp_nodes : HashMap = HashMap::new();\n"; - os << "let mut luts : HashMap<&str, LookupTableOwned> = HashMap::new();\n"; + // TODO(#462): Insert block argument that are encrypted ints into + // temp_nodes. + os << "let mut temp_nodes : HashMap = " + "HashMap::new();\n"; + os << "let mut luts : HashMap<&str, LookupTableOwned> = " + "HashMap::new();\n"; os << kRunLevelDefn << "\n"; @@ -285,6 +291,10 @@ LogicalResult TfheRustEmitter::printOperation(func::ReturnOp op) { return variableNames->getNameForValue(value); }; + if (op.getNumOperands() == 0) { + return success(); + } + if (op.getNumOperands() == 1) { os << valueOrClonedValue(op.getOperands()[0]) << "\n"; return success(); @@ -415,9 +425,9 @@ LogicalResult TfheRustEmitter::printOperation(affine::AffineForOp forOp) { variableNames->getIntForValue(op->getResult(0)), commaSeparatedValues( getCiphertextOperands(op->getOperands()), [&](Value value) { - // TODO(#462): This assumes that all ciphertexts are loaded - // into temp_nodes. Currently, block arguments are not - // supported. + // TODO(#462): This assumes that all ciphertexts are + // loaded into temp_nodes. Currently, block arguments are + // not supported. return "Tv(" + std::to_string(variableNames->getIntForValue(value)) + ")"; @@ -552,7 +562,8 @@ LogicalResult TfheRustEmitter::printOperation(::mlir::arith::TruncIOp op) { } LogicalResult TfheRustEmitter::printOperation(tensor::ExtractOp op) { - // We assume here that the indices are SSA values (not integer attributes). + // We assume here that the indices are SSA values (not integer + // attributes). emitAssignPrefix(op.getResult()); os << "&" << variableNames->getNameForValue(op.getTensor()) << "[" << commaSeparatedValues( diff --git a/lib/Target/Utils.cpp b/lib/Target/Utils.cpp index a3e6e01a5..17bb179d3 100644 --- a/lib/Target/Utils.cpp +++ b/lib/Target/Utils.cpp @@ -16,6 +16,9 @@ namespace heir { std::string commaSeparatedValues( ValueRange values, std::function valueToString) { + if (values.empty()) { + return std::string(); + } return std::accumulate( std::next(values.begin()), values.end(), valueToString(values[0]), [&](std::string a, Value b) { return a + ", " + valueToString(b); }); @@ -23,6 +26,9 @@ std::string commaSeparatedValues( FailureOr commaSeparatedTypes( TypeRange types, std::function(Type)> typeToString) { + if (types.empty()) { + return std::string(); + } return std::accumulate( std::next(types.begin()), types.end(), typeToString(types[0]), [&](FailureOr a, Type b) -> FailureOr { @@ -36,6 +42,9 @@ FailureOr commaSeparatedTypes( std::string bracketEnclosedValues( ValueRange values, std::function valueToString) { + if (values.empty()) { + return std::string(); + } return std::accumulate( std::next(values.begin()), values.end(), "[" + valueToString(values[0]) + "]", diff --git a/tests/tfhe_rust/ops.mlir b/tests/tfhe_rust/ops.mlir index ae135b9d7..a7284621b 100644 --- a/tests/tfhe_rust/ops.mlir +++ b/tests/tfhe_rust/ops.mlir @@ -1,10 +1,12 @@ // RUN: heir-opt %s | FileCheck %s +// RUN: heir-translate --emit-tfhe-rust %s | FileCheck --check-prefix=RS %s // This simply tests for syntax. !sks = !tfhe_rust.server_key module { // CHECK-LABEL: func @test_create_trivial + // RS-LABEL: pub fn test_create_trivial func.func @test_create_trivial(%sks : !sks) { %0 = arith.constant 1 : i8 %1 = arith.constant 1 : i3 @@ -16,6 +18,7 @@ module { } // CHECK-LABEL: func @test_bitand + // RS-LABEL: pub fn test_bitand func.func @test_bitand(%sks : !sks) { %0 = arith.constant 1 : i1 %1 = arith.constant 1 : i1 @@ -28,6 +31,7 @@ module { // CHECK-LABEL: func @test_apply_lookup_table + // RS-LABEL: pub fn test_apply_lookup_table func.func @test_apply_lookup_table(%sks : !sks, %lut: !tfhe_rust.lookup_table) { %0 = arith.constant 1 : i3 %1 = arith.constant 2 : i3