Skip to content

Commit

Permalink
Add support for affine for loops in tfhe-rs emitter
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610509708
  • Loading branch information
asraa authored and Copybara-Service committed Feb 29, 2024
1 parent 43d8f65 commit 5a60d78
Show file tree
Hide file tree
Showing 15 changed files with 589 additions and 74 deletions.
12 changes: 11 additions & 1 deletion include/Analysis/SelectVariableNames/SelectVariableNames.h
@@ -1,6 +1,8 @@
#ifndef INCLUDE_ANALYSIS_SELECTVARIABLENAMES_SELECTVARIABLENAMES_H_
#define INCLUDE_ANALYSIS_SELECTVARIABLENAMES_SELECTVARIABLENAMES_H_

#include <string>

#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
Expand All @@ -17,12 +19,20 @@ class SelectVariableNames {
/// value was not assigned a name (suggesting the value was not in the IR
/// tree that this class was constructed with).
std::string getNameForValue(Value value) const {
assert(variableNames.contains(value));
return prefix + std::to_string(variableNames.lookup(value));
}

// Return the unique integer assigned to a given value.
int getIntForValue(Value value) const {
assert(variableNames.contains(value));
return variableNames.lookup(value);
}

private:
llvm::DenseMap<Value, std::string> variableNames;
llvm::DenseMap<Value, int> variableNames;

std::string prefix{"v"};
};

} // namespace heir
Expand Down
1 change: 1 addition & 0 deletions include/Target/TfheRust/BUILD
Expand Up @@ -17,6 +17,7 @@ cc_library(
"@heir//lib/Analysis/SelectVariableNames",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
8 changes: 7 additions & 1 deletion include/Target/TfheRust/TfheRustEmitter.h
Expand Up @@ -7,7 +7,8 @@
#include "include/Analysis/SelectVariableNames/SelectVariableNames.h"
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRust/IR/TfheRustOps.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
Expand Down Expand Up @@ -51,6 +52,7 @@ class TfheRustEmitter {
// Functions for printing individual ops
LogicalResult printOperation(::mlir::ModuleOp op);
LogicalResult printOperation(::mlir::arith::ConstantOp op);
LogicalResult printOperation(::mlir::arith::IndexCastOp op);
LogicalResult printOperation(::mlir::arith::ShLIOp op);
LogicalResult printOperation(::mlir::arith::AndIOp op);
LogicalResult printOperation(::mlir::arith::ShRSIOp op);
Expand All @@ -60,6 +62,7 @@ class TfheRustEmitter {
LogicalResult printOperation(AddOp op);
LogicalResult printOperation(BitAndOp op);
LogicalResult printOperation(CreateTrivialOp op);
LogicalResult printOperation(affine::AffineForOp op);
LogicalResult printOperation(tensor::ExtractOp op);
LogicalResult printOperation(tensor::FromElementsOp op);
LogicalResult printOperation(memref::AllocOp op);
Expand All @@ -77,6 +80,9 @@ class TfheRustEmitter {
SmallVector<std::string> operandTypes = {});
LogicalResult printBinaryOp(::mlir::Value result, ::mlir::Value lhs,
::mlir::Value rhs, std::string_view op);
void printStoreOp(memref::StoreOp op, std::string valueToStore);
void printLoadOp(memref::LoadOp op);
std::string operationType(Operation *op);

// Emit a TfheRust type
LogicalResult emitType(Type type);
Expand Down
5 changes: 2 additions & 3 deletions lib/Analysis/SelectVariableNames/SelectVariableNames.cpp
Expand Up @@ -12,16 +12,15 @@ namespace heir {

SelectVariableNames::SelectVariableNames(Operation *op) {
int i = 0;
std::string prefix = "v";
op->walk<WalkOrder::PreOrder>([&](Operation *op) {
for (Value result : op->getResults()) {
variableNames.try_emplace(result, prefix + std::to_string(i++));
variableNames.try_emplace(result, i++);
}

for (Region &region : op->getRegions()) {
for (Block &block : region) {
for (Value arg : block.getArguments()) {
variableNames.try_emplace(arg, prefix + std::to_string(i++));
variableNames.try_emplace(arg, i++);
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Target/TfheRust/BUILD
Expand Up @@ -13,10 +13,13 @@ cc_library(
"@heir//include/Target/TfheRust:TfheRustEmitter.h",
],
deps = [
"@heir//include/Graph",
"@heir//lib/Analysis/SelectVariableNames",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Target:Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down

0 comments on commit 5a60d78

Please sign in to comment.