Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
150 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# FullLoopUnroll tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files([ | ||
"FullLoopUnroll.h", | ||
]) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=FullLoopUnroll", | ||
], | ||
"FullLoopUnroll.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"FullLoopUnrollPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "FullLoopUnroll.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ | ||
#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ | ||
#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def FullLoopUnroll : Pass<"full-loop-unroll"> { | ||
let summary = "Fully unroll all loops"; | ||
let description = [{ | ||
Scan the IR for affine.for loops and unroll them all. | ||
}]; | ||
let dependentDialects = [ | ||
|
||
]; | ||
} | ||
|
||
#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "FullLoopUnroll", | ||
srcs = ["FullLoopUnroll.cpp"], | ||
hdrs = [ | ||
"@heir//include/Transforms/FullLoopUnroll:FullLoopUnroll.h", | ||
], | ||
deps = [ | ||
"@heir//include/Transforms/FullLoopUnroll:pass_inc_gen", | ||
"@llvm-project//mlir:AffineDialect", | ||
"@llvm-project//mlir:AffineUtils", | ||
"@llvm-project//mlir:FuncDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" | ||
|
||
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DEF_FULLLOOPUNROLL | ||
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" | ||
|
||
struct FullLoopUnroll : impl::FullLoopUnrollBase<FullLoopUnroll> { | ||
using FullLoopUnrollBase::FullLoopUnrollBase; | ||
|
||
void runOnOperation() override { | ||
auto walkResult = | ||
getOperation()->walk<WalkOrder::PostOrder>([&](affine::AffineForOp op) { | ||
auto result = mlir::affine::loopUnrollFull(op); | ||
if (failed(result)) { | ||
return WalkResult::interrupt(); | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
|
||
if (walkResult.wasInterrupted()) { | ||
signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// RUN: heir-opt --full-loop-unroll %s | FileCheck %s | ||
|
||
!sks = !tfhe_rust.server_key | ||
|
||
// CHECK-LABEL: func @test_move_out_of_loop | ||
func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> memref<10x!tfhe_rust.eui3> { | ||
// CHECK-NOT: affine.for | ||
%0 = arith.constant 1 : i3 | ||
%1 = arith.constant 2 : i3 | ||
%memref = memref.alloca() : memref<10x!tfhe_rust.eui3> | ||
|
||
affine.for %i = 0 to 10 { | ||
%e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 | ||
%shiftAmount = arith.constant 1 : i8 | ||
%e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 | ||
%e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 | ||
%eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 | ||
%out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 | ||
memref.store %out, %memref[%i] : memref<10x!tfhe_rust.eui3> | ||
affine.yield | ||
} | ||
return %memref : memref<10x!tfhe_rust.eui3> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters