Skip to content

Commit

Permalink
add a full loop unroll pass
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 4, 2024
1 parent 1939b79 commit e50e7e0
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 0 deletions.
35 changes: 35 additions & 0 deletions include/Transforms/FullLoopUnroll/BUILD
@@ -0,0 +1,35 @@
# FullLoopUnroll tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files([
"FullLoopUnroll.h",
])

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=FullLoopUnroll",
],
"FullLoopUnroll.h.inc",
),
(
["-gen-pass-doc"],
"FullLoopUnrollPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "FullLoopUnroll.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
18 changes: 18 additions & 0 deletions include/Transforms/FullLoopUnroll/FullLoopUnroll.h
@@ -0,0 +1,18 @@
#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_
#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc"

#define GEN_PASS_REGISTRATION
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc"

} // namespace heir
} // namespace mlir

#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_
16 changes: 16 additions & 0 deletions include/Transforms/FullLoopUnroll/FullLoopUnroll.td
@@ -0,0 +1,16 @@
#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_
#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_

include "mlir/Pass/PassBase.td"

def FullLoopUnroll : Pass<"full-loop-unroll"> {
let summary = "Fully unroll all loops";
let description = [{
Scan the IR for affine.for loops and unroll them all.
}];
let dependentDialects = [

];
}

#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_
22 changes: 22 additions & 0 deletions lib/Transforms/FullLoopUnroll/BUILD
@@ -0,0 +1,22 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "FullLoopUnroll",
srcs = ["FullLoopUnroll.cpp"],
hdrs = [
"@heir//include/Transforms/FullLoopUnroll:FullLoopUnroll.h",
],
deps = [
"@heir//include/Transforms/FullLoopUnroll:pass_inc_gen",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
33 changes: 33 additions & 0 deletions lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp
@@ -0,0 +1,33 @@
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h"

#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_FULLLOOPUNROLL
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc"

struct FullLoopUnroll : impl::FullLoopUnrollBase<FullLoopUnroll> {
using FullLoopUnrollBase::FullLoopUnrollBase;

void runOnOperation() override {
auto walkResult =
getOperation()->walk<WalkOrder::PostOrder>([&](affine::AffineForOp op) {
auto result = mlir::affine::loopUnrollFull(op);
if (failed(result)) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (walkResult.wasInterrupted()) {
signalPassFailure();
}
}
};

} // namespace heir
} // namespace mlir
23 changes: 23 additions & 0 deletions tests/full_loop_unroll.mlir
@@ -0,0 +1,23 @@
// RUN: heir-opt --full-loop-unroll %s | FileCheck %s

!sks = !tfhe_rust.server_key

// CHECK-LABEL: func @test_move_out_of_loop
func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> memref<10x!tfhe_rust.eui3> {
// CHECK-NOT: affine.for
%0 = arith.constant 1 : i3
%1 = arith.constant 2 : i3
%memref = memref.alloca() : memref<10x!tfhe_rust.eui3>

affine.for %i = 0 to 10 {
%e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3
%shiftAmount = arith.constant 1 : i8
%e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3
%eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
memref.store %out, %memref[%i] : memref<10x!tfhe_rust.eui3>
affine.yield
}
return %memref : memref<10x!tfhe_rust.eui3>
}
1 change: 1 addition & 0 deletions tools/BUILD
Expand Up @@ -55,6 +55,7 @@ cc_binary(
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
"@heir//lib/Transforms/ForwardStoreToLoad",
"@heir//lib/Transforms/FullLoopUnroll",
"@heir//lib/Transforms/Secretize",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Expand Up @@ -22,6 +22,7 @@
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"
#include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h"
#include "include/Transforms/Secretize/Passes.h"
#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
Expand Down Expand Up @@ -288,6 +289,7 @@ int main(int argc, char **argv) {
lwe::registerLWEPasses();
secret::registerSecretPasses();
registerSecretizePasses();
registerFullLoopUnrollPasses();
registerForwardStoreToLoadPasses();
// Register yosys optimizer pipeline if configured.
#ifndef HEIR_NO_YOSYS
Expand Down

0 comments on commit e50e7e0

Please sign in to comment.