From 877a922cc945dbe28d0a22dc74a1b3244072c9f4 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 29 Feb 2024 17:20:46 -0800 Subject: [PATCH] add a full loop unroll pass --- include/Transforms/FullLoopUnroll/BUILD | 35 ++++++++++++++++++ .../FullLoopUnroll/FullLoopUnroll.h | 18 ++++++++++ .../FullLoopUnroll/FullLoopUnroll.td | 16 +++++++++ lib/Transforms/FullLoopUnroll/BUILD | 22 ++++++++++++ .../FullLoopUnroll/FullLoopUnroll.cpp | 36 +++++++++++++++++++ tests/full_loop_unroll.mlir | 23 ++++++++++++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 ++ 8 files changed, 153 insertions(+) create mode 100644 include/Transforms/FullLoopUnroll/BUILD create mode 100644 include/Transforms/FullLoopUnroll/FullLoopUnroll.h create mode 100644 include/Transforms/FullLoopUnroll/FullLoopUnroll.td create mode 100644 lib/Transforms/FullLoopUnroll/BUILD create mode 100644 lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp create mode 100644 tests/full_loop_unroll.mlir diff --git a/include/Transforms/FullLoopUnroll/BUILD b/include/Transforms/FullLoopUnroll/BUILD new file mode 100644 index 000000000..a51a8eb14 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/BUILD @@ -0,0 +1,35 @@ +# FullLoopUnroll tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "FullLoopUnroll.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=FullLoopUnroll", + ], + "FullLoopUnroll.h.inc", + ), + ( + ["-gen-pass-doc"], + "FullLoopUnrollPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "FullLoopUnroll.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/FullLoopUnroll/FullLoopUnroll.h b/include/Transforms/FullLoopUnroll/FullLoopUnroll.h new file mode 100644 index 000000000..f7ed05963 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/FullLoopUnroll.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ +#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ diff --git a/include/Transforms/FullLoopUnroll/FullLoopUnroll.td b/include/Transforms/FullLoopUnroll/FullLoopUnroll.td new file mode 100644 index 000000000..e800ee384 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/FullLoopUnroll.td @@ -0,0 +1,16 @@ +#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ +#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ + +include "mlir/Pass/PassBase.td" + +def FullLoopUnroll : Pass<"full-loop-unroll"> { + let summary = "Fully unroll all loops"; + let description = [{ + Scan the IR for affine.for loops and unroll them all. + }]; + let dependentDialects = [ + + ]; +} + +#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ diff --git a/lib/Transforms/FullLoopUnroll/BUILD b/lib/Transforms/FullLoopUnroll/BUILD new file mode 100644 index 000000000..8b40f31d9 --- /dev/null +++ b/lib/Transforms/FullLoopUnroll/BUILD @@ -0,0 +1,22 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "FullLoopUnroll", + srcs = ["FullLoopUnroll.cpp"], + hdrs = [ + "@heir//include/Transforms/FullLoopUnroll:FullLoopUnroll.h", + ], + deps = [ + "@heir//include/Transforms/FullLoopUnroll:pass_inc_gen", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp b/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp new file mode 100644 index 000000000..b48b3f7b7 --- /dev/null +++ b/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp @@ -0,0 +1,36 @@ +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" + +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_FULLLOOPUNROLL +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +struct AffineFullUnrollPattern : public OpRewritePattern { + AffineFullUnrollPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult matchAndRewrite(affine::AffineForOp op, + PatternRewriter &rewriter) const override { + return mlir::affine::loopUnrollFull(op); + } +}; + +struct FullLoopUnroll : impl::FullLoopUnrollBase { + using FullLoopUnrollBase::FullLoopUnrollBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tests/full_loop_unroll.mlir b/tests/full_loop_unroll.mlir new file mode 100644 index 000000000..4170b2322 --- /dev/null +++ b/tests/full_loop_unroll.mlir @@ -0,0 +1,23 @@ +// RUN: heir-opt --full-loop-unroll %s | FileCheck %s + +!sks = !tfhe_rust.server_key + +// CHECK-LABEL: func @test_move_out_of_loop +func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> memref<10x!tfhe_rust.eui3> { + // CHECK-NOT: affine.for + %0 = arith.constant 1 : i3 + %1 = arith.constant 2 : i3 + %memref = memref.alloca() : memref<10x!tfhe_rust.eui3> + + affine.for %i = 0 to 10 { + %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 + %shiftAmount = arith.constant 1 : i8 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 + %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 + memref.store %out, %memref[%i] : memref<10x!tfhe_rust.eui3> + affine.yield + } + return %memref : memref<10x!tfhe_rust.eui3> +} diff --git a/tools/BUILD b/tools/BUILD index 823635c91..0c9bc44a6 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -55,6 +55,7 @@ cc_binary( "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", "@heir//lib/Transforms/ForwardStoreToLoad", + "@heir//lib/Transforms/FullLoopUnroll", "@heir//lib/Transforms/Secretize", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 1e9d54b42..8eb0ed41e 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -22,6 +22,7 @@ #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "include/Transforms/Secretize/Passes.h" #include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project @@ -288,6 +289,7 @@ int main(int argc, char **argv) { lwe::registerLWEPasses(); secret::registerSecretPasses(); registerSecretizePasses(); + registerFullLoopUnrollPasses(); registerForwardStoreToLoadPasses(); // Register yosys optimizer pipeline if configured. #ifndef HEIR_NO_YOSYS