Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a full loop unroll pass #477

Merged
merged 1 commit into from Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/Transforms/FullLoopUnroll/BUILD
@@ -0,0 +1,35 @@
# FullLoopUnroll tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files([
"FullLoopUnroll.h",
])

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=FullLoopUnroll",
],
"FullLoopUnroll.h.inc",
),
(
["-gen-pass-doc"],
"FullLoopUnrollPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "FullLoopUnroll.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
18 changes: 18 additions & 0 deletions include/Transforms/FullLoopUnroll/FullLoopUnroll.h
@@ -0,0 +1,18 @@
#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_
#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc"

#define GEN_PASS_REGISTRATION
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc"

} // namespace heir
} // namespace mlir

#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_
13 changes: 13 additions & 0 deletions include/Transforms/FullLoopUnroll/FullLoopUnroll.td
@@ -0,0 +1,13 @@
#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_
#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_

include "mlir/Pass/PassBase.td"

def FullLoopUnroll : Pass<"full-loop-unroll"> {
let summary = "Fully unroll all loops";
let description = [{
Scan the IR for affine.for loops and unroll them all.
}];
}

#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_
22 changes: 22 additions & 0 deletions lib/Transforms/FullLoopUnroll/BUILD
@@ -0,0 +1,22 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "FullLoopUnroll",
srcs = ["FullLoopUnroll.cpp"],
hdrs = [
"@heir//include/Transforms/FullLoopUnroll:FullLoopUnroll.h",
],
deps = [
"@heir//include/Transforms/FullLoopUnroll:pass_inc_gen",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
33 changes: 33 additions & 0 deletions lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp
@@ -0,0 +1,33 @@
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h"

#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_FULLLOOPUNROLL
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc"

struct FullLoopUnroll : impl::FullLoopUnrollBase<FullLoopUnroll> {
using FullLoopUnrollBase::FullLoopUnrollBase;

void runOnOperation() override {
auto walkResult =
getOperation()->walk<WalkOrder::PostOrder>([&](affine::AffineForOp op) {
auto result = mlir::affine::loopUnrollFull(op);
if (failed(result)) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (walkResult.wasInterrupted()) {
signalPassFailure();
}
}
};

} // namespace heir
} // namespace mlir
23 changes: 23 additions & 0 deletions tests/full_loop_unroll.mlir
@@ -0,0 +1,23 @@
// RUN: heir-opt --full-loop-unroll %s | FileCheck %s

!sks = !tfhe_rust.server_key

// CHECK-LABEL: func @test_move_out_of_loop
func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> memref<10x!tfhe_rust.eui3> {
// CHECK-NOT: affine.for
%0 = arith.constant 1 : i3
%1 = arith.constant 2 : i3
%memref = memref.alloca() : memref<10x!tfhe_rust.eui3>

affine.for %i = 0 to 10 {
%e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3
%shiftAmount = arith.constant 1 : i8
%e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3
%eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
memref.store %out, %memref[%i] : memref<10x!tfhe_rust.eui3>
affine.yield
}
return %memref : memref<10x!tfhe_rust.eui3>
}
1 change: 1 addition & 0 deletions tools/BUILD
Expand Up @@ -55,6 +55,7 @@ cc_binary(
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
"@heir//lib/Transforms/ForwardStoreToLoad",
"@heir//lib/Transforms/FullLoopUnroll",
"@heir//lib/Transforms/Secretize",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Expand Up @@ -22,6 +22,7 @@
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"
#include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h"
#include "include/Transforms/Secretize/Passes.h"
#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
Expand Down Expand Up @@ -288,6 +289,7 @@ int main(int argc, char **argv) {
lwe::registerLWEPasses();
secret::registerSecretPasses();
registerSecretizePasses();
registerFullLoopUnrollPasses();
registerForwardStoreToLoadPasses();
// Register yosys optimizer pipeline if configured.
#ifndef HEIR_NO_YOSYS
Expand Down