diff --git a/include/Transforms/FullLoopUnroll/BUILD b/include/Transforms/FullLoopUnroll/BUILD new file mode 100644 index 000000000..a51a8eb14 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/BUILD @@ -0,0 +1,35 @@ +# FullLoopUnroll tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "FullLoopUnroll.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=FullLoopUnroll", + ], + "FullLoopUnroll.h.inc", + ), + ( + ["-gen-pass-doc"], + "FullLoopUnrollPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "FullLoopUnroll.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/FullLoopUnroll/FullLoopUnroll.h b/include/Transforms/FullLoopUnroll/FullLoopUnroll.h new file mode 100644 index 000000000..f7ed05963 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/FullLoopUnroll.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ +#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ diff --git a/include/Transforms/FullLoopUnroll/FullLoopUnroll.td b/include/Transforms/FullLoopUnroll/FullLoopUnroll.td new file mode 100644 index 000000000..640e22bbc --- /dev/null +++ b/include/Transforms/FullLoopUnroll/FullLoopUnroll.td @@ -0,0 +1,13 @@ +#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ +#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ + +include "mlir/Pass/PassBase.td" + +def FullLoopUnroll : Pass<"full-loop-unroll"> { + let summary = "Fully unroll all loops"; + let description = [{ + Scan the IR for affine.for loops and unroll them all. + }]; +} + +#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ diff --git a/lib/Transforms/FullLoopUnroll/BUILD b/lib/Transforms/FullLoopUnroll/BUILD new file mode 100644 index 000000000..8b40f31d9 --- /dev/null +++ b/lib/Transforms/FullLoopUnroll/BUILD @@ -0,0 +1,22 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "FullLoopUnroll", + srcs = ["FullLoopUnroll.cpp"], + hdrs = [ + "@heir//include/Transforms/FullLoopUnroll:FullLoopUnroll.h", + ], + deps = [ + "@heir//include/Transforms/FullLoopUnroll:pass_inc_gen", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp b/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp new file mode 100644 index 000000000..9ec59f8ee --- /dev/null +++ b/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp @@ -0,0 +1,33 @@ +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" + +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_FULLLOOPUNROLL +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +struct FullLoopUnroll : impl::FullLoopUnrollBase { + using FullLoopUnrollBase::FullLoopUnrollBase; + + void runOnOperation() override { + auto walkResult = + getOperation()->walk([&](affine::AffineForOp op) { + auto result = mlir::affine::loopUnrollFull(op); + if (failed(result)) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) { + signalPassFailure(); + } + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tests/full_loop_unroll.mlir b/tests/full_loop_unroll.mlir new file mode 100644 index 000000000..4170b2322 --- /dev/null +++ b/tests/full_loop_unroll.mlir @@ -0,0 +1,23 @@ +// RUN: heir-opt --full-loop-unroll %s | FileCheck %s + +!sks = !tfhe_rust.server_key + +// CHECK-LABEL: func @test_move_out_of_loop +func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> memref<10x!tfhe_rust.eui3> { + // CHECK-NOT: affine.for + %0 = arith.constant 1 : i3 + %1 = arith.constant 2 : i3 + %memref = memref.alloca() : memref<10x!tfhe_rust.eui3> + + affine.for %i = 0 to 10 { + %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 + %shiftAmount = arith.constant 1 : i8 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 + %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 + memref.store %out, %memref[%i] : memref<10x!tfhe_rust.eui3> + affine.yield + } + return %memref : memref<10x!tfhe_rust.eui3> +} diff --git a/tools/BUILD b/tools/BUILD index 823635c91..0c9bc44a6 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -55,6 +55,7 @@ cc_binary( "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", "@heir//lib/Transforms/ForwardStoreToLoad", + "@heir//lib/Transforms/FullLoopUnroll", "@heir//lib/Transforms/Secretize", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 1e9d54b42..8eb0ed41e 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -22,6 +22,7 @@ #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "include/Transforms/Secretize/Passes.h" #include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project @@ -288,6 +289,7 @@ int main(int argc, char **argv) { lwe::registerLWEPasses(); secret::registerSecretPasses(); registerSecretizePasses(); + registerFullLoopUnrollPasses(); registerForwardStoreToLoadPasses(); // Register yosys optimizer pipeline if configured. #ifndef HEIR_NO_YOSYS