diff --git a/include/Analysis/SelectVariableNames/SelectVariableNames.h b/include/Analysis/SelectVariableNames/SelectVariableNames.h index e1ed7bf59..f97aa2dc3 100644 --- a/include/Analysis/SelectVariableNames/SelectVariableNames.h +++ b/include/Analysis/SelectVariableNames/SelectVariableNames.h @@ -1,6 +1,8 @@ #ifndef INCLUDE_ANALYSIS_SELECTVARIABLENAMES_SELECTVARIABLENAMES_H_ #define INCLUDE_ANALYSIS_SELECTVARIABLENAMES_SELECTVARIABLENAMES_H_ +#include + #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project @@ -17,12 +19,20 @@ class SelectVariableNames { /// value was not assigned a name (suggesting the value was not in the IR /// tree that this class was constructed with). std::string getNameForValue(Value value) const { + assert(variableNames.contains(value)); + return prefix + std::to_string(variableNames.lookup(value)); + } + + // Return the unique integer assigned to a given value. + int getIntForValue(Value value) const { assert(variableNames.contains(value)); return variableNames.lookup(value); } private: - llvm::DenseMap variableNames; + llvm::DenseMap variableNames; + + std::string prefix{"v"}; }; } // namespace heir diff --git a/include/Target/TfheRust/BUILD b/include/Target/TfheRust/BUILD index d9e200f80..ba612be37 100644 --- a/include/Target/TfheRust/BUILD +++ b/include/Target/TfheRust/BUILD @@ -17,6 +17,7 @@ cc_library( "@heir//lib/Analysis/SelectVariableNames", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/include/Target/TfheRust/TfheRustEmitter.h b/include/Target/TfheRust/TfheRustEmitter.h index a459d4146..d6a9c0ebd 100644 --- a/include/Target/TfheRust/TfheRustEmitter.h +++ b/include/Target/TfheRust/TfheRustEmitter.h @@ -7,7 +7,8 @@ #include "include/Analysis/SelectVariableNames/SelectVariableNames.h" #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRust/IR/TfheRustOps.h" -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project @@ -51,6 +52,7 @@ class TfheRustEmitter { // Functions for printing individual ops LogicalResult printOperation(::mlir::ModuleOp op); LogicalResult printOperation(::mlir::arith::ConstantOp op); + LogicalResult printOperation(::mlir::arith::IndexCastOp op); LogicalResult printOperation(::mlir::arith::ShLIOp op); LogicalResult printOperation(::mlir::arith::AndIOp op); LogicalResult printOperation(::mlir::arith::ShRSIOp op); @@ -60,6 +62,7 @@ class TfheRustEmitter { LogicalResult printOperation(AddOp op); LogicalResult printOperation(BitAndOp op); LogicalResult printOperation(CreateTrivialOp op); + LogicalResult printOperation(affine::AffineForOp op); LogicalResult printOperation(tensor::ExtractOp op); LogicalResult printOperation(tensor::FromElementsOp op); LogicalResult printOperation(memref::AllocOp op); @@ -77,6 +80,9 @@ class TfheRustEmitter { SmallVector operandTypes = {}); LogicalResult printBinaryOp(::mlir::Value result, ::mlir::Value lhs, ::mlir::Value rhs, std::string_view op); + void printStoreOp(memref::StoreOp op, std::string valueToStore); + void printLoadOp(memref::LoadOp op); + std::string operationType(Operation *op); // Emit a TfheRust type LogicalResult emitType(Type type); diff --git a/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp b/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp index 369d54dca..13b2493c0 100644 --- a/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp +++ b/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp @@ -12,16 +12,15 @@ namespace heir { SelectVariableNames::SelectVariableNames(Operation *op) { int i = 0; - std::string prefix = "v"; op->walk([&](Operation *op) { for (Value result : op->getResults()) { - variableNames.try_emplace(result, prefix + std::to_string(i++)); + variableNames.try_emplace(result, i++); } for (Region ®ion : op->getRegions()) { for (Block &block : region) { for (Value arg : block.getArguments()) { - variableNames.try_emplace(arg, prefix + std::to_string(i++)); + variableNames.try_emplace(arg, i++); } } } diff --git a/lib/Target/TfheRust/BUILD b/lib/Target/TfheRust/BUILD index e558a924f..864f11255 100644 --- a/lib/Target/TfheRust/BUILD +++ b/lib/Target/TfheRust/BUILD @@ -13,10 +13,13 @@ cc_library( "@heir//include/Target/TfheRust:TfheRustEmitter.h", ], deps = [ + "@heir//include/Graph", "@heir//lib/Analysis/SelectVariableNames", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Target:Utils", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/lib/Target/TfheRust/TfheRustEmitter.cpp b/lib/Target/TfheRust/TfheRustEmitter.cpp index db0934bff..dcf9314a1 100644 --- a/lib/Target/TfheRust/TfheRustEmitter.cpp +++ b/lib/Target/TfheRust/TfheRustEmitter.cpp @@ -1,5 +1,7 @@ #include "include/Target/TfheRust/TfheRustEmitter.h" +#include +#include #include #include #include @@ -10,11 +12,18 @@ #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRust/IR/TfheRustOps.h" #include "include/Dialect/TfheRust/IR/TfheRustTypes.h" +#include "include/Graph/Graph.h" #include "lib/Target/TfheRust/TfheRustTemplates.h" #include "lib/Target/Utils.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project @@ -31,12 +40,55 @@ #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project +#define DEBUG_TYPE "tfhe-rust-emitter" + namespace mlir { namespace heir { namespace tfhe_rust { namespace { +graph::Graph getGraph(affine::AffineForOp forOp) { + graph::Graph graph; + auto block = forOp.getBody(); + + // Skip if there isn't any apply_lookup_table + if (llvm::none_of(block->getOperations(), [&](Operation &op) { + return isa(op); + })) { + return graph; + } + + for (auto &op : block->getOperations()) { + if (!isa(op)) { + continue; + } + graph.addVertex(&op); + for (auto operand : op.getOperands()) { + auto *definingOp = operand.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != block || + !isa(definingOp)) { + continue; + } + graph.addEdge(definingOp, &op); + } + } + + return graph; +} + +SmallVector getCiphertextOperands(ValueRange inputs) { + SmallVector vals; + for (Value val : inputs) { + // TODO(#474): Generalize to any encrypted uint. + if (isa(val.getType())) { + vals.push_back(val); + } + } + + return vals; +} + // getRustIntegerType returns the width of the closest builtin integer type. FailureOr getRustIntegerType(int width) { for (int candidate : {8, 16, 32, 64, 128}) { @@ -74,7 +126,7 @@ void registerToTfheRustTranslation() { [](DialectRegistry ®istry) { registry.insert(); + memref::MemRefDialect, affine::AffineDialect>(); }); } @@ -93,9 +145,19 @@ LogicalResult TfheRustEmitter::translate(Operation &op) { // Func ops .Case( [&](auto op) { return printOperation(op); }) + // Affine ops + .Case( + [&](auto op) { return printOperation(op); }) + .Case([&](auto op) -> LogicalResult { + if (op->getNumResults() != 0) { + return op.emitOpError() + << "AffineYieldOp has non-zero number of results"; + } + return success(); + }) // Arith ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) // TfheRust ops .Case( // todo subview & copy - [&](auto op) { return printOperation(op); }) + memref::StoreOp>([&](auto op) { return printOperation(op); }) .Default([&](Operation &) { return op.emitOpError("unable to find printer for op"); }); @@ -178,6 +239,13 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) { os << " {\n"; os.indent(); + // Create a global temp_nodes hashmap for any created SSA values. + // TODO(#462): Insert block argument that are encrypted ints into temp_nodes. + os << "let mut temp_nodes : HashMap = HashMap::new();\n"; + os << "let mut luts : HashMap<&str, LookupTableOwned> = HashMap::new();\n"; + + os << kRunLevelDefn << "\n"; + for (Block &block : funcOp.getBlocks()) { for (Operation &op : block.getOperations()) { if (failed(translate(op))) { @@ -193,11 +261,28 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) { LogicalResult TfheRustEmitter::printOperation(func::ReturnOp op) { std::function valueOrClonedValue = [&](Value value) { - auto cloneStr = ""; if (isa(value)) { - cloneStr = ".clone()"; + // Function arguments used as outputs must be cloned. + return variableNames->getNameForValue(value) + ".clone()"; + } else if (MemRefType memRefType = dyn_cast(value.getType())) { + auto shape = memRefType.getShape(); + // Internally allocated memrefs that are treated as hashmaps must be + // converted to arrays. + unsigned int i = 0; + std::string res = + variableNames->getNameForValue(value) + std::string(".get(&(") + + std::accumulate(std::next(shape.begin()), shape.end(), + std::string("i0"), + [&](std::string a, int64_t value) { + return a + ", i" + std::to_string(++i); + }) + + std::string(")).unwrap().clone()"); + for (unsigned _ : shape) { + res = llvm::formatv("core::array::from_fn(|i{0}| {1})", i--, res); + } + return res; } - return variableNames->getNameForValue(value) + cloneStr; + return variableNames->getNameForValue(value); }; if (op.getNumOperands() == 1) { @@ -221,14 +306,36 @@ LogicalResult TfheRustEmitter::printSksMethod( std::string_view op, SmallVector operandTypes) { emitAssignPrefix(result); - auto operandTypesIt = operandTypes.begin(); + if (!operandTypes.empty()) { + assert(operandTypes.size() == nonSksOperands.size() && + "invalid sizes of operandTypes"); + operandTypes = + llvm::to_vector(llvm::map_range(operandTypes, [&](std::string value) { + return value.empty() ? "" : " as " + value; + })); + } + auto *operandTypesIt = operandTypes.begin(); os << variableNames->getNameForValue(sks) << "." << op << "("; os << commaSeparatedValues(nonSksOperands, [&](Value value) { - const auto *prefix = value.getType().hasTrait() ? "&" : ""; - return prefix + variableNames->getNameForValue(value) + - (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); + auto valueStr = variableNames->getNameForValue(value); + if (isa(value.getType())) { + valueStr = "luts[\"" + variableNames->getNameForValue(value) + "\"]"; + } + std::string prefix = value.getType().hasTrait() ? "&" : ""; + std::string suffix = operandTypes.empty() ? "" : *operandTypesIt++; + return prefix + valueStr + suffix; }); + os << ");\n"; + + // Insert ciphertext results into temp_nodes so that the levelled ops can + // reference them. + // TODO(#474): Generalize to any encrypted uint. + if (isa(result.getType())) { + os << llvm::formatv("temp_nodes.insert({0}, {1}.clone());\n", + variableNames->getIntForValue(result), + variableNames->getNameForValue(result)); + } return success(); } @@ -258,17 +365,124 @@ LogicalResult TfheRustEmitter::printOperation(GenerateLookupTableOp op) { uint64_t truthTable = op.getTruthTable().getUInt(); auto result = op.getResult(); - emitAssignPrefix(result); + os << "luts.insert(\"" << variableNames->getNameForValue(result) << "\", "; + os << variableNames->getNameForValue(sks) << ".generate_lookup_table("; os << "|x| (" << std::to_string(truthTable) << " >> x) & 1"; - os << ");\n"; + + os << "));\n"; + return success(); +} + +std::string TfheRustEmitter::operationType(Operation *op) { + return llvm::TypeSwitch(op) + .Case([&](ApplyLookupTableOp op) { + return "LUT3(\"" + variableNames->getNameForValue(op.getLookupTable()) + + "\")"; + }) + .Case([&](ScalarLeftShiftOp op) { + auto constantShift = + cast(op.getShiftAmount().getDefiningOp()); + return "LSH(" + + std::to_string( + cast(constantShift.getValue()).getInt()) + + ")"; + }) + .Case([&](Operation *) { return "ADD"; }); +} + +LogicalResult TfheRustEmitter::printOperation(affine::AffineForOp forOp) { + os << "for " << variableNames->getNameForValue(forOp.getInductionVar()) + << " in " << forOp.getConstantLowerBound() << ".." + << forOp.getConstantUpperBound() << " {\n"; + os.indent(); + + auto graph = getGraph(forOp); + if (!graph.empty()) { + auto sortedGraph = graph.sortGraphByLevels(); + if (failed(sortedGraph)) { + llvm_unreachable("Only possible failure is a cycle in the SSA graph!"); + } + auto levels = sortedGraph.value(); + // Print lists of operations per level. + for (int level = 0; level < levels.size(); ++level) { + os << "static LEVEL_" << level << " : [((OpType, usize), &[GateInput]); " + << levels[level].size() << "] = ["; + for (auto &op : levels[level]) { + // Print the operation type and its ciphertext args + os << llvm::formatv( + "(({0}, {1}), &[{2}]), ", operationType(op), + variableNames->getIntForValue(op->getResult(0)), + commaSeparatedValues( + getCiphertextOperands(op->getOperands()), [&](Value value) { + // TODO(#462): This assumes that all ciphertexts are loaded + // into temp_nodes. Currently, block arguments are not + // supported. + return "Tv(" + + std::to_string(variableNames->getIntForValue(value)) + + ")"; + })); + } + os << "];\n"; + } + + // Walk operations of the for loop body until we hit an op besides + // GenerateLookupTable, CreateTrivial, or a memref::LoadOp. + forOp.getBody()->walk([&](Operation *op) -> WalkResult { + return llvm::TypeSwitch(op) + .Case( + [&](Operation *op) { + if (failed(translate(*op))) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }) + .Case([&](memref::LoadOp op) { + // Insert the result into the temp_nodes hashmap. + os << llvm::formatv("temp_nodes.insert({0}, ", + variableNames->getIntForValue(op.getResult())); + printLoadOp(op); + os << ".clone());\n"; + + return WalkResult::advance(); + }) + // Note: if these ops are hoisted before the add, shift, and + // apply_lookup_table ops, we could interrupt and stop here. + .Default([](Operation *) { return WalkResult::advance(); }); + }); + + // Execute each task in the level. + for (int level = 0; level < levels.size(); ++level) { + os << llvm::formatv( + "run_level({1}, &mut temp_nodes, &mut luts, &LEVEL_{0});\n", level, + serverKeyArg_); + } + + // Store into memrefs by taking the values from temp_nodes. + for (memref::StoreOp op : forOp.getBody()->getOps()) { + std::string valueToStore = + llvm::formatv("temp_nodes[&{0}].clone()", + variableNames->getIntForValue(op.getValueToStore())); + printStoreOp(op, valueToStore); + } + } else { + // Without levelled execution, trasnslate the body operations. + for (Operation &op : forOp.getBody()->getOperations()) { + if (failed(translate(op))) { + return failure(); + } + } + } + + os.unindent(); + os << "}\n"; return success(); } LogicalResult TfheRustEmitter::printOperation(ScalarLeftShiftOp op) { return printSksMethod(op.getResult(), op.getServerKey(), {op.getCiphertext(), op.getShiftAmount()}, - "scalar_left_shift"); + "scalar_left_shift", {"", "u8"}); } LogicalResult TfheRustEmitter::printOperation(CreateTrivialOp op) { @@ -296,6 +510,17 @@ LogicalResult TfheRustEmitter::printOperation(arith::ConstantOp op) { return success(); } +LogicalResult TfheRustEmitter::printOperation(arith::IndexCastOp op) { + emitAssignPrefix(op.getOut()); + os << variableNames->getNameForValue(op.getIn()) << " as "; + if (failed(emitType(op.getOut().getType()))) { + return op.emitOpError() + << "Failed to emit index cast type " << op.getOut().getType(); + } + os << ";\n"; + return success(); +} + LogicalResult TfheRustEmitter::printOperation(::mlir::arith::ShLIOp op) { return printBinaryOp(op.getResult(), op.getLhs(), op.getRhs(), "<<"); } @@ -351,24 +576,18 @@ LogicalResult TfheRustEmitter::printOperation(tensor::FromElementsOp op) { } LogicalResult TfheRustEmitter::printOperation(memref::AllocOp op) { - // Uses an iterator to create an array with a default value - MemRefType memRefType = op.getMemref().getType(); - auto typeStr = convertType(memRefType); - if (failed(typeStr)) { - op.emitOpError() << "failed to emit memref type " << memRefType; + os << "let mut " << variableNames->getNameForValue(op.getMemref()) + << " : HashMap<(" + << std::accumulate( + std::next(op.getMemref().getType().getShape().begin()), + op.getMemref().getType().getShape().end(), std::string("usize"), + [&](std::string a, int64_t value) { return a + ", usize"; }) + << "), "; + if (failed(emitType(op.getMemref().getType().getElementType()))) { + return op.emitOpError() << "Failed to get memref element type"; } - emitAssignPrefix(op.getResult(), true, typeStr.value()); - auto defaultOr = defaultValue(memRefType.getElementType()); - if (failed(defaultOr)) { - return op.emitOpError() - << "Failed to emit default memref element type " << memRefType; - } - std::string res = defaultOr.value(); - for ([[maybe_unused]] unsigned dim : memRefType.getShape()) { - res = llvm::formatv("core::array::from_fn(|_| {0})", res); - } - os << res << ";\n"; + os << "> = HashMap::new();\n"; return success(); } @@ -411,46 +630,68 @@ LogicalResult TfheRustEmitter::printOperation(memref::GetGlobalOp op) { return success(); } +void TfheRustEmitter::printStoreOp(memref::StoreOp op, + std::string valueToStore) { + os << variableNames->getNameForValue(op.getMemref()); + os << ".insert((" + << commaSeparatedValues(op.getIndices(), + [&](Value value) { + return variableNames->getNameForValue(value) + + std::string(" as usize"); + }) + << "), " << valueToStore << ");\n"; +} + LogicalResult TfheRustEmitter::printOperation(memref::StoreOp op) { - os << variableNames->getNameForValue(op.getMemref()) << "[" - << commaSeparatedValues( - op.getIndices(), - [&](Value value) { return variableNames->getNameForValue(value); }) - << "] = " << variableNames->getNameForValue(op.getValueToStore()) << ";\n"; + printStoreOp(op, variableNames->getNameForValue(op.getValueToStore())); return success(); } -LogicalResult TfheRustEmitter::printOperation(memref::LoadOp op) { - emitAssignPrefix(op.getResult()); - bool isRef = isa(op.getResult().getType()); - os << (isRef ? "&" : "") << variableNames->getNameForValue(op.getMemref()) - << "["; - +void TfheRustEmitter::printLoadOp(memref::LoadOp op) { + os << variableNames->getNameForValue(op.getMemref()); if (dyn_cast_or_null(op.getMemRef().getDefiningOp())) { - // Global arrays are 1-dimensional, so flatten the index. + // Global arrays are 1-dimensional, so flatten the index // TODO(#449): Share with Verilog Emitter. const auto [strides, offset] = getStridesAndOffset(cast(op.getMemRefType())); - os << std::to_string(offset); + os << "[" << std::to_string(offset); for (int i = 0; i < strides.size(); ++i) { os << llvm::formatv(" + {0} * {1}", variableNames->getNameForValue(op.getIndices()[i]), strides[i]); } - } else { - os << commaSeparatedValues(op.getIndices(), [&](Value value) { + os << "]"; + } else if (isa(op.getMemRef())) { + // This is a block argument array. + os << bracketEnclosedValues(op.getIndices(), [&](Value value) { return variableNames->getNameForValue(value); }); + } else { + // Otherwise, this must be an internally allocated memref, treated as a + // hashmap. + os << ".get(&(" << commaSeparatedValues(op.getIndices(), [&](Value value) { + return variableNames->getNameForValue(value) + " as usize"; + }) << ")).unwrap()"; } +} - os << "];\n"; +LogicalResult TfheRustEmitter::printOperation(memref::LoadOp op) { + emitAssignPrefix(op.getResult()); + + // TODO(#474): Generalize to any encrypted uint. + bool isRef = isa(op.getResult().getType()); + os << (isRef ? "&" : ""); + + printLoadOp(op); + + os << ";\n"; return success(); } FailureOr TfheRustEmitter::convertType(Type type) { - // Note: these are probably not the right type names to use exactly, and they - // will need to chance to the right values once we try to compile it against - // a specific API version. + // Note: these are probably not the right type names to use exactly, and + // they will need to chance to the right values once we try to compile it + // against a specific API version. return llvm::TypeSwitch>(type) .Case( [&](RankedTensorType type) -> FailureOr { @@ -464,7 +705,7 @@ FailureOr TfheRustEmitter::convertType(Type type) { auto elementTy = convertType(type.getElementType()); if (failed(elementTy)) return failure(); std::string res = elementTy.value(); - for (unsigned dim : type.getShape()) { + for (unsigned dim : llvm::reverse(type.getShape())) { res = llvm::formatv("[{0}; {1}]", res, dim); } return res; @@ -478,6 +719,7 @@ FailureOr TfheRustEmitter::convertType(Type type) { return (type.isUnsigned() ? std::string("u") : "") + "i" + std::to_string(width.value()); }) + // TODO(#474): Generalize to any encrypted uint. .Case( [&](auto type) { return std::string("Ciphertext"); }) .Case([&](auto type) { return std::string("ServerKey"); }) @@ -489,6 +731,7 @@ FailureOr TfheRustEmitter::convertType(Type type) { FailureOr TfheRustEmitter::defaultValue(Type type) { return llvm::TypeSwitch>(type) .Case([&](IntegerType type) { return std::string("0"); }) + // TODO(#474): Generalize to any encrypted uint. .Case([&](auto type) -> FailureOr { if (serverKeyArg_.empty()) return failure(); return std::string( diff --git a/lib/Target/TfheRust/TfheRustTemplates.h b/lib/Target/TfheRust/TfheRustTemplates.h index b46c0da49..3d6e5443e 100644 --- a/lib/Target/TfheRust/TfheRustTemplates.h +++ b/lib/Target/TfheRust/TfheRustTemplates.h @@ -8,7 +8,66 @@ namespace heir { namespace tfhe_rust { constexpr std::string_view kModulePrelude = R"rust( +use rayon::prelude::*; +use std::collections::HashMap; use tfhe::shortint::prelude::*; + +use tfhe::shortint::server_key::LookupTableOwned; + +enum GateInput { + Tv(usize), // key in a global hashmap +} + +use GateInput::*; + +enum OpType<'a> { + LUT3(&'a str), // key in a global hashmap + ADD, + LSH(u8), // shift value +} + +use OpType::*; +)rust"; + +constexpr std::string_view kRunLevelDefn = R"rust( +let lut3 = |args: &[&Ciphertext], lut: &LookupTableOwned, server_key: &ServerKey| -> Ciphertext { + return server_key.apply_lookup_table(args[0], lut); +}; + +let add = |args: &[&Ciphertext], server_key: &ServerKey| -> Ciphertext { + return server_key.unchecked_add(args[0], args[1]); +}; + +let left_shift = |args: &[&Ciphertext], shift: u8, server_key: &ServerKey| -> Ciphertext { + return server_key.scalar_left_shift(args[0], shift); +}; + +let mut run_level = | + server_key: &ServerKey, + temp_nodes: &mut HashMap, + luts: &mut HashMap<&str, LookupTableOwned>, + tasks: &[((OpType, usize), &[GateInput])] +| { + let updates = tasks + .into_par_iter() + .map(|(k, task_args)| { + let (op_type, result) = k; + let task_args = task_args.into_iter() + .map(|arg| match arg { + Tv(ndx) => &temp_nodes[&ndx], + }).collect::>(); + let op = |args: &[&Ciphertext]| match op_type { + LUT3(lut) => lut3(args, &luts[lut], server_key), + ADD => add(args, server_key), + LSH(shift) => left_shift(args, *shift, server_key) + }; + ((result), op(&task_args)) + }) + .collect::>(); + updates.into_iter().for_each(|(id, v)| { + temp_nodes.insert(*id, v); + }); +}; )rust"; } // namespace tfhe_rust diff --git a/lib/Target/Utils.cpp b/lib/Target/Utils.cpp index b74457217..a3e6e01a5 100644 --- a/lib/Target/Utils.cpp +++ b/lib/Target/Utils.cpp @@ -34,5 +34,13 @@ FailureOr commaSeparatedTypes( }); } +std::string bracketEnclosedValues( + ValueRange values, std::function valueToString) { + return std::accumulate( + std::next(values.begin()), values.end(), + "[" + valueToString(values[0]) + "]", + [&](std::string a, Value b) { return a + "[" + valueToString(b) + "]"; }); +} + } // namespace heir } // namespace mlir diff --git a/lib/Target/Utils.h b/lib/Target/Utils.h index ae574a02e..ccb60cebd 100644 --- a/lib/Target/Utils.h +++ b/lib/Target/Utils.h @@ -21,6 +21,12 @@ std::string commaSeparatedValues( FailureOr commaSeparatedTypes( TypeRange types, std::function(Type)> typeToString); +// Return a string containing the values in the given +// ValueRange enclosed in square brackets, with each value being converted to a +// string by the given mapping function, for example [1][2]. +std::string bracketEnclosedValues( + ValueRange values, std::function valueToString); + } // namespace heir } // namespace mlir diff --git a/tests/tfhe_rust/emit_tfhe_rust.mlir b/tests/tfhe_rust/emit_tfhe_rust.mlir index 7a4f3e3e8..5c4933290 100644 --- a/tests/tfhe_rust/emit_tfhe_rust.mlir +++ b/tests/tfhe_rust/emit_tfhe_rust.mlir @@ -10,7 +10,9 @@ // CHECK-NEXT: [[input1:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: [[input2:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: ) -> Ciphertext { -// CHECK-NEXT: let [[v0:.*]] = [[sks]].bitand(&[[input1]], &[[input2]]); +// CHECK: let [[v0:.*]] = [[sks]].bitand(&[[input1]], &[[input2]]); +// CHECK-NEXT: temp_nodes.insert +// CHECK-SAME: [[v0]] // CHECK-NEXT: [[v0]] // CHECK-NEXT: } func.func @test_bitand(%sks : !sks, %input1 : !eui3, %input2 : !eui3) -> !eui3 { @@ -23,7 +25,9 @@ func.func @test_bitand(%sks : !sks, %input1 : !eui3, %input2 : !eui3) -> !eui3 { // CHECK-NEXT: [[lut:v[0-9]+]]: &LookupTableOwned, // CHECK-NEXT: [[input:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: ) -> Ciphertext { -// CHECK-NEXT: let [[v0:.*]] = [[sks]].apply_lookup_table(&[[input]], &[[lut]]); +// CHECK: let [[v0:.*]] = [[sks]].apply_lookup_table(&[[input]], &luts["[[lut]]"]); +// CHECK-NEXT: temp_nodes.insert +// CHECK-SAME: [[v0]] // CHECK-NEXT: [[v0]] // CHECK-NEXT: } func.func @test_apply_lookup_table(%sks : !sks, %lut: !lut, %input : !eui3) -> !eui3 { @@ -36,11 +40,13 @@ func.func @test_apply_lookup_table(%sks : !sks, %lut: !lut, %input : !eui3) -> ! // CHECK-NEXT: [[lut:v[0-9]+]]: &LookupTableOwned, // CHECK-NEXT: [[input:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: ) -> Ciphertext { -// CHECK-NEXT: let [[v1:.*]] = [[sks]].apply_lookup_table(&[[input]], &[[lut]]); -// CHECK-NEXT: let [[v2:.*]] = [[sks]].unchecked_add(&[[input]], &[[v1]]); -// CHECK-NEXT: let [[c1:.*]] = 1; -// CHECK-NEXT: let [[v3:.*]] = [[sks]].scalar_left_shift(&[[v2]], [[c1]]); -// CHECK-NEXT: let [[v4:.*]] = [[sks]].apply_lookup_table(&[[v3]], &[[lut]]); +// CHECK: let [[v1:.*]] = [[sks]].apply_lookup_table(&[[input]], &luts["[[lut]]"]); +// CHECK: let [[v2:.*]] = [[sks]].unchecked_add(&[[input]], &[[v1]]); +// CHECK: let [[c1:.*]] = 1; +// CHECK: let [[v3:.*]] = [[sks]].scalar_left_shift(&[[v2]], [[c1]] as u8); +// CHECK: let [[v4:.*]] = [[sks]].apply_lookup_table(&[[v3]], &luts["[[lut]]"]); +// CHECK-NEXT: temp_nodes.insert +// CHECK-SAME: [[v4]] // CHECK-NEXT: [[v4]] // CHECK-NEXT: } func.func @test_apply_lookup_table2(%sks : !sks, %lut: !lut, %input : !eui3) -> !eui3 { @@ -55,7 +61,7 @@ func.func @test_apply_lookup_table2(%sks : !sks, %lut: !lut, %input : !eui3) -> // CHECK-LABEL: pub fn test_return_multiple_values( // CHECK-NEXT: [[input:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: ) -> (Ciphertext, Ciphertext) { -// CHECK-NEXT: ([[input]].clone(), [[input]].clone()) +// CHECK: ([[input]].clone(), [[input]].clone()) // CHECK-NEXT: } func.func @test_return_multiple_values(%input : !eui3) -> (!eui3, !eui3) { return %input, %input : !eui3, !eui3 @@ -65,15 +71,15 @@ func.func @test_return_multiple_values(%input : !eui3) -> (!eui3, !eui3) { // CHECK-NEXT: [[sks:v[0-9]+]]: &ServerKey, // CHECK-NEXT: [[input:v[0-9]+]]: &[Ciphertext; 1], // CHECK-NEXT: ) -> [Ciphertext; 1] { - // CHECK-NEXT: let [[v1:.*]] = 0; + // CHECK: let [[v1:.*]] = 0; // CHECK-NEXT: let [[v2:.*]] = &[[input]][[[v1]]]; // CHECK-NEXT: static [[v3:.*]] : [bool; 1] = [-1]; // CHECK-NEXT: let [[v4:.*]] = [[v3]][0 + [[v1]] * 1 + [[v1]] * 1]; // CHECK-NEXT: let [[v5:.*]] = [[sks]].create_trivial([[v4]] as u64); - // CHECK-NEXT: let [[v6:.*]] = [[sks]].bitand(&[[v2]], &[[v5]]); - // CHECK-NEXT: let mut [[v7:.*]] : [Ciphertext; 1] = core::array::from_fn(|_| [[sks]].create_trivial(0 as u64)); - // CHECK-NEXT: [[v7]][[[v1]]] = [[v6]]; - // CHECK-NEXT: [[v7]] + // CHECK: let [[v6:.*]] = [[sks]].bitand(&[[v2]], &[[v5]]); + // CHECK: let mut [[v7:.*]] : HashMap<(usize), Ciphertext> = HashMap::new(); + // CHECK-NEXT: [[v7]].insert(([[v1]] as usize), [[v6]]); + // CHECK-NEXT: core::array::from_fn(|i0| [[v7]].get memref.global constant @__constant_1x1xi1 : memref<1x1xi1> = dense<[[1]]> {alignment = 64 : i64} func.func @test_memref(%sks : !sks, %input : memref<1x!eui3>) -> (memref<1x!eui3>) { %c0 = arith.constant 0 : index @@ -92,7 +98,7 @@ func.func @test_memref(%sks : !sks, %input : memref<1x!eui3>) -> (memref<1x!eui3 // CHECK-NEXT: [[sks:v[0-9]+]]: &ServerKey, // CHECK-NEXT: [[input:v[0-9]+]]: i64, // CHECK-NEXT: ) -> Ciphertext { - // CHECK-NEXT: let [[v1:.*]] = 1; + // CHECK: let [[v1:.*]] = 1; // CHECK-NEXT: let [[v2:.*]] = 429; // CHECK-NEXT: let [[v0:.*]] = [[input]] as i32; // CHECK-NEXT: let [[v3:.*]] = [[v1]] << [[v0]]; @@ -100,7 +106,9 @@ func.func @test_memref(%sks : !sks, %input : memref<1x!eui3>) -> (memref<1x!eui3 // CHECK-NEXT: let [[v5:.*]] = [[v4]] >> [[v0]]; // CHECK-NEXT: let [[v6:.*]] = [[v5]] != 0; // CHECK-NEXT: let [[v7:.*]] = [[sks]].create_trivial([[v6]] as u64); - // CHECK-NEXT: [[v7]] + // CHECK-NEXT: temp_nodes.insert + // CHECK-SAME: [[v7]] + // CHECK: [[v7]] // CHECK-NEXT: } func.func @test_plaintext_arith_ops(%sks : !sks, %input : i64) -> (!eui3) { %c1_i32 = arith.constant 1 : i32 diff --git a/tests/tfhe_rust/end_to_end/BUILD b/tests/tfhe_rust/end_to_end/BUILD index 6587735de..53239ab74 100644 --- a/tests/tfhe_rust/end_to_end/BUILD +++ b/tests/tfhe_rust/end_to_end/BUILD @@ -10,6 +10,7 @@ glob_lit_tests( "Cargo.toml", "src/main.rs", "src/main_add_one.rs", + "src/main_fully_connected.rs", "@heir//tests:test_utilities", ], default_tags = [ @@ -22,6 +23,7 @@ glob_lit_tests( "test_add_one.mlir": "large", "test_bitand.mlir": "large", "test_simple_lut.mlir": "large", + "test_fully_connected.mlir": "large", }, test_file_exts = ["mlir"], ) diff --git a/tests/tfhe_rust/end_to_end/Cargo.toml b/tests/tfhe_rust/end_to_end/Cargo.toml index de1741e82..3fb1e972f 100644 --- a/tests/tfhe_rust/end_to_end/Cargo.toml +++ b/tests/tfhe_rust/end_to_end/Cargo.toml @@ -16,3 +16,7 @@ path = "src/main.rs" [[bin]] name = "main_add_one" path = "src/main_add_one.rs" + +[[bin]] +name = "main_fully_connected" +path = "src/main_fully_connected.rs" diff --git a/tests/tfhe_rust/end_to_end/src/main_fully_connected.rs b/tests/tfhe_rust/end_to_end/src/main_fully_connected.rs new file mode 100644 index 000000000..9b93fee76 --- /dev/null +++ b/tests/tfhe_rust/end_to_end/src/main_fully_connected.rs @@ -0,0 +1,48 @@ +use clap::Parser; +use tfhe::shortint::parameters::get_parameters_from_message_and_carry; +use tfhe::shortint::*; + +mod fn_under_test; + +// TODO(#235): improve generality +#[derive(Parser, Debug)] +struct Args { + #[arg(id = "message_bits", long)] + message_bits: usize, + #[arg(id = "carry_bits", long, default_value = "2")] + carry_bits: usize, + /// arguments to forward to function under test + #[arg(id = "input_1", index = 1)] + input1: u8, +} +// Encrypt a u8 +pub fn encrypt(value: u8, client_key: &ClientKey) -> [[[Ciphertext; 8]; 1]; 1] { + core::array::from_fn(|_| { + core::array::from_fn(|_| { + core::array::from_fn(|shift| { + let bit = (value >> shift) & 1; + client_key.encrypt(if bit != 0 { 1 } else { 0 }) + }) + }) + }) +} + +// Decrypt a u8 +pub fn decrypt(ciphertexts: &[Ciphertext], client_key: &ClientKey) -> u8 { + let mut accum = 0u8; + for (i, ct) in ciphertexts.iter().enumerate() { + let bit = client_key.decrypt(ct); + accum |= (bit as u8) << i; + } + accum +} +fn main() { + let flags = Args::parse(); + let parameters = + get_parameters_from_message_and_carry((1 << flags.message_bits) - 1, flags.carry_bits); + let (client_key, server_key) = tfhe::shortint::gen_keys(parameters); + let ct_1 = encrypt(flags.input1.into(), &client_key); + let result = fn_under_test::fn_under_test(&server_key, &ct_1); + let output = decrypt(&result[0][0][0..8], &client_key); + println!("{:08b}", output); +} diff --git a/tests/tfhe_rust/end_to_end/test_fully_connected.mlir b/tests/tfhe_rust/end_to_end/test_fully_connected.mlir new file mode 100644 index 000000000..eb0299b9d --- /dev/null +++ b/tests/tfhe_rust/end_to_end/test_fully_connected.mlir @@ -0,0 +1,13 @@ +// RUN: heir-opt --tosa-to-boolean-tfhe="abc-fast=true entry-function=fn_under_test" %s | heir-translate --emit-tfhe-rust > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_fully_connected -- 2 --message_bits=3 | FileCheck %s + +// This takes takes the input x and outputs 2 \cdot x + 1. +// CHECK: 00000101 +module attributes {tf_saved_model.semantics} { + func.func @fn_under_test(%11: tensor<1x1xi8>) -> tensor<1x1xi32> { + %0 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tosa.const"() {value = dense<[[2]]> : tensor<1x1xi8>} : () -> tensor<1x1xi8> + %2 = "tosa.fully_connected"(%11, %1, %0) {quantization_info = #tosa.conv_quant} : (tensor<1x1xi8>, tensor<1x1xi8>, tensor<1xi32>) -> tensor<1x1xi32> + return %2 : tensor<1x1xi32> + } +} diff --git a/tests/tfhe_rust/iterators.mlir b/tests/tfhe_rust/iterators.mlir new file mode 100644 index 000000000..25464d6ed --- /dev/null +++ b/tests/tfhe_rust/iterators.mlir @@ -0,0 +1,110 @@ +// RUN: heir-translate %s --emit-tfhe-rust | FileCheck %s + +// CHECK: enum GateInput + +// CHECK: enum OpType + +module { + // CHECK-LABEL: pub fn generate_cleartext_ops + // CHECK-NEXT: [[sks:v[0-9]+]]: &ServerKey, + // CHECK-NEXT: [[input:v[0-9]+]]: &[i16; 2], + // CHECK-NEXT: ) -> {{[[][[]}}i8; 8]; 2] { + // CHECK-NEXT: let mut temp_nodes : HashMap = HashMap::new(); + // CHECK-NEXT: let mut luts : HashMap<&str, LookupTableOwned> = HashMap::new(); + // CHECK: let mut run_level + + // CHECK: let mut [[v1:.*]] : HashMap<(usize, usize), i8> = HashMap::new(); + // CHECK-NEXT: for [[v2:.*]] in 0..2 { + // CHECK-NEXT: let [[v3:.*]] = [[input]][[[v2]]]; + // CHECK-NEXT: for [[v4:.*]] in 0..8 { + // CHECK-NEXT: let [[v5:.*]] = [[v4]] as i16; + // CHECK-NEXT: let [[v6:.*]] = [[v5]] & [[v3]]; + // CHECK-NEXT: let [[v7:.*]] = [[v6]] as i8; + // CHECK-NEXT: [[v1]].insert(([[v2]] as usize, [[v4]] as usize), [[v7]]); + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: core::array::from_fn + // CHECK-SAME: [[v1]].get(&([[i0:.*]], [[i1:.*]])).unwrap().clone() + func.func @generate_cleartext_ops(%sks : !tfhe_rust.server_key, %arg0 : memref<2xi16>) -> (memref<2x8xi8>) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x8xi8> + affine.for %arg1 = 0 to 2 { + %0 = memref.load %arg0[%arg1] : memref<2xi16> + affine.for %arg2 = 0 to 8 { + %1 = arith.index_cast %arg2 : index to i16 + %2 = arith.andi %1, %0 : i16 + %3 = arith.trunci %2 : i16 to i8 + memref.store %3, %alloc[%arg1, %arg2] :memref<2x8xi8> + } + } + return %alloc : memref<2x8xi8> + } + + // A memref is stored with an initial value and then iteratively summed + // CHECK-LABEL: pub fn iterative + // CHECK-NEXT: [[input1:v[0-9]+]]: &{{[[][[]}}i8; 16]; 1], + // CHECK-NEXT: [[input2:v[0-9]+]]: &{{[[][[]}}i8; 1]; 16], + // CHECK-NEXT: ) -> {{[[][[]}}i8; 1]; 1] { + // CHECK-NEXT: let mut temp_nodes : HashMap = HashMap::new(); + // CHECK-NEXT: let mut luts : HashMap<&str, LookupTableOwned> = HashMap::new(); + // CHECK: let mut run_level + // + // CHECK: let [[v0:.*]] = 29; + // CHECK-NEXT: let mut [[v1:.*]] : HashMap<(usize, usize), i8> = + // CHECK-NEXT: for [[v2:.*]] in 0..1 { + // CHECK-NEXT: for [[v3:.*]] in 0..1 { + // CHECK-NEXT: [[v1]].insert(([[v2]] as usize, [[v3]] as usize), [[v0]]); + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: for [[v4:.*]] in 0..1 { + // CHECK-NEXT: for [[v5:.*]] in 0..1 { + // CHECK-NEXT: for [[v6:.*]] in 0..16 { + // CHECK-NEXT: let [[v7:.*]] = [[input1]][[[v4]]][[[v6]]]; + // CHECK-NEXT: let [[v8:.*]] = [[input2]][[[v6]]][[[v5]]]; + // CHECK-NEXT: let [[v9:.*]] = [[v1]].get(&([[v4]] as usize, [[v5]] as usize)).unwrap(); + // CHECK-NEXT: let [[v10:.*]] = [[v7]] & [[v8]]; + // CHECK-NEXT: let [[v11:.*]] = [[v10]] & [[v9]]; + // CHECK-NEXT: [[v1]].insert(([[v4]] as usize, [[v5]] as usize), [[v11]]); + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: let mut [[v12:.*]] : HashMap<(usize, usize), i8> = HashMap::new(); + // CHECK-NEXT: for [[v13:.*]] in 0..1 { + // CHECK-NEXT: for [[v14:.*]] in 0..16 { + // CHECK-NEXT: let [[v15:.*]] = [[v1]].get(&([[v13]] as usize, [[v14]] as usize)).unwrap(); + // CHECK-NEXT: let [[v16:.*]] = [[v15]] as i8; + // CHECK-NEXT: [[v12]].insert(([[v13]] as usize, [[v14]] as usize), [[v16]]); + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: core::array::from_fn + // CHECK-SAME: [[v12]].get(&([[i0:.*]], [[i1:.*]])).unwrap().clone() + func.func @iterative(%alloc_6: memref<1x16xi8>, %alloc_7 : memref<16x1xi8>) -> memref<1x1xi4> { + %c29_i8 = arith.constant 29 : i8 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi8> + affine.for %arg1 = 0 to 1 { + affine.for %arg2 = 0 to 1 { + memref.store %c29_i8, %alloc[%arg1, %arg2] : memref<1x1xi8> + } + } + affine.for %arg1 = 0 to 1 { + affine.for %arg2 = 0 to 1 { + affine.for %arg3 = 0 to 16 { + %5 = memref.load %alloc_6[%arg1, %arg3] : memref<1x16xi8> + %6 = memref.load %alloc_7[%arg3, %arg2] : memref<16x1xi8> + %7 = memref.load %alloc[%arg1, %arg2] : memref<1x1xi8> + %8 = arith.andi %5, %6 : i8 + %13 = arith.andi %8, %7 : i8 + memref.store %13, %alloc[%arg1, %arg2] : memref<1x1xi8> + } + } + } + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x1xi4> + affine.for %arg1 = 0 to 1 { + affine.for %arg2 = 0 to 16 { + %5 = memref.load %alloc[%arg1, %arg2] : memref<1x1xi8> + %6 = arith.trunci %5 : i8 to i4 + memref.store %6, %alloc_1[%arg1, %arg2] : memref<1x1xi4> + } + } + return %alloc_1 : memref<1x1xi4> + } +}