From 6eb3b9e79124a038b57f01eb8b6ec9c18f80d436 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 29 Feb 2024 17:44:09 -0800 Subject: [PATCH] Generalize cggi-straight-line-vectorizer Supports any elementwise mappable ops, so we can apply it to the BGV path as well. --- include/Dialect/CGGI/IR/CGGIOps.td | 4 +- include/Dialect/CGGI/Transforms/BUILD | 1 - include/Dialect/CGGI/Transforms/Passes.h | 1 - include/Dialect/CGGI/Transforms/Passes.td | 9 - .../CGGI/Transforms/StraightLineVectorizer.h | 17 -- .../Transforms/StraightLineVectorizer/BUILD | 35 ++++ .../StraightLineVectorizer.h | 18 ++ .../StraightLineVectorizer.td | 15 ++ lib/Dialect/CGGI/Transforms/BUILD | 23 --- lib/Transforms/StraightLineVectorizer/BUILD | 24 +++ .../StraightLineVectorizer.cpp | 164 ++++++++++++++++++ tests/cggi/straight_line_vectorizer.mlir | 4 +- tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 14 files changed, 263 insertions(+), 55 deletions(-) delete mode 100644 include/Dialect/CGGI/Transforms/StraightLineVectorizer.h create mode 100644 include/Transforms/StraightLineVectorizer/BUILD create mode 100644 include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h create mode 100644 include/Transforms/StraightLineVectorizer/StraightLineVectorizer.td create mode 100644 lib/Transforms/StraightLineVectorizer/BUILD create mode 100644 lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.cpp diff --git a/include/Dialect/CGGI/IR/CGGIOps.td b/include/Dialect/CGGI/IR/CGGIOps.td index 99459586c..1d139a614 100644 --- a/include/Dialect/CGGI/IR/CGGIOps.td +++ b/include/Dialect/CGGI/IR/CGGIOps.td @@ -28,7 +28,7 @@ class CGGI_BinaryGateOp Pure, Commutative, SameOperandsAndResultType, - Elementwise, + ElementwiseMappable, Scalarizable ]> { let arguments = (ins LWECiphertextLike:$lhs, LWECiphertextLike:$rhs); @@ -53,7 +53,7 @@ class CGGI_LutOp traits = []> : CGGI_Op diff --git a/include/Dialect/CGGI/Transforms/BUILD b/include/Dialect/CGGI/Transforms/BUILD index f901ebe2a..a7e82a43f 100644 --- a/include/Dialect/CGGI/Transforms/BUILD +++ b/include/Dialect/CGGI/Transforms/BUILD @@ -33,5 +33,4 @@ gentbl_cc_library( exports_files([ "Passes.h", "SetDefaultParameters.h", - "StraightLineVectorizer.h", ]) diff --git a/include/Dialect/CGGI/Transforms/Passes.h b/include/Dialect/CGGI/Transforms/Passes.h index 8ba6143de..69097351a 100644 --- a/include/Dialect/CGGI/Transforms/Passes.h +++ b/include/Dialect/CGGI/Transforms/Passes.h @@ -3,7 +3,6 @@ #include "include/Dialect/CGGI/IR/CGGIDialect.h" #include "include/Dialect/CGGI/Transforms/SetDefaultParameters.h" -#include "include/Dialect/CGGI/Transforms/StraightLineVectorizer.h" namespace mlir { namespace heir { diff --git a/include/Dialect/CGGI/Transforms/Passes.td b/include/Dialect/CGGI/Transforms/Passes.td index 48862d1e7..8787e08f4 100644 --- a/include/Dialect/CGGI/Transforms/Passes.td +++ b/include/Dialect/CGGI/Transforms/Passes.td @@ -19,13 +19,4 @@ def SetDefaultParameters : Pass<"cggi-set-default-parameters"> { let dependentDialects = ["mlir::heir::cggi::CGGIDialect"]; } -def StraightLineVectorizer : Pass<"cggi-straight-line-vectorizer"> { - let summary = "A straight-line vectorizer for CGGI bootstrapping ops."; - let description = [{ - This pass vectorizes CGGI ops. It ignores control flow and only vectorizes - straight-line programs within a given region. - }]; - let dependentDialects = ["mlir::heir::cggi::CGGIDialect"]; -} - #endif // INCLUDE_DIALECT_CGGI_TRANSFORMS_PASSES_TD_ diff --git a/include/Dialect/CGGI/Transforms/StraightLineVectorizer.h b/include/Dialect/CGGI/Transforms/StraightLineVectorizer.h deleted file mode 100644 index e97a4b3ef..000000000 --- a/include/Dialect/CGGI/Transforms/StraightLineVectorizer.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef INCLUDE_DIALECT_CGGI_TRANSFORMS_STRAIGHTLINEVECTORIZER_H_ -#define INCLUDE_DIALECT_CGGI_TRANSFORMS_STRAIGHTLINEVECTORIZER_H_ - -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace heir { -namespace cggi { - -#define GEN_PASS_DECL_STRAIGHTLINEVECTORIZER -#include "include/Dialect/CGGI/Transforms/Passes.h.inc" - -} // namespace cggi -} // namespace heir -} // namespace mlir - -#endif // INCLUDE_DIALECT_CGGI_TRANSFORMS_STRAIGHTLINEVECTORIZER_H_ diff --git a/include/Transforms/StraightLineVectorizer/BUILD b/include/Transforms/StraightLineVectorizer/BUILD new file mode 100644 index 000000000..7b2465611 --- /dev/null +++ b/include/Transforms/StraightLineVectorizer/BUILD @@ -0,0 +1,35 @@ +# StraightLineVectorizer tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "StraightLineVectorizer.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=StraightLineVectorizer", + ], + "StraightLineVectorizer.h.inc", + ), + ( + ["-gen-pass-doc"], + "StraightLineVectorizerPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "StraightLineVectorizer.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h b/include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h new file mode 100644 index 000000000..64427c7c4 --- /dev/null +++ b/include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_ +#define INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_ diff --git a/include/Transforms/StraightLineVectorizer/StraightLineVectorizer.td b/include/Transforms/StraightLineVectorizer/StraightLineVectorizer.td new file mode 100644 index 000000000..88ada6862 --- /dev/null +++ b/include/Transforms/StraightLineVectorizer/StraightLineVectorizer.td @@ -0,0 +1,15 @@ +#ifndef INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_ +#define INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_ + +include "mlir/Pass/PassBase.td" + +def StraightLineVectorizer : Pass<"straight-line-vectorize"> { + let summary = "A vectorizer for straight line programs."; + let description = [{ + This pass ignores control flow and only vectorizes straight-line programs + within a given region. + }]; + let dependentDialects = []; +} + +#endif // INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_ diff --git a/lib/Dialect/CGGI/Transforms/BUILD b/lib/Dialect/CGGI/Transforms/BUILD index cba7a00b1..c4f9d0bed 100644 --- a/lib/Dialect/CGGI/Transforms/BUILD +++ b/lib/Dialect/CGGI/Transforms/BUILD @@ -10,7 +10,6 @@ cc_library( ], deps = [ ":SetDefaultParameters", - ":StraightLineVectorizer", "@heir//include/Dialect/CGGI/Transforms:pass_inc_gen", "@heir//lib/Dialect/CGGI/IR:Dialect", "@llvm-project//mlir:IR", @@ -32,25 +31,3 @@ cc_library( "@llvm-project//mlir:Pass", ], ) - -cc_library( - name = "StraightLineVectorizer", - srcs = ["StraightLineVectorizer.cpp"], - hdrs = [ - "@heir//include/Dialect/CGGI/Transforms:StraightLineVectorizer.h", - ], - deps = [ - "@heir//include/Dialect/CGGI/Transforms:pass_inc_gen", - "@heir//include/Graph", - "@heir//lib/Dialect/CGGI/IR:Dialect", - "@heir//lib/Dialect/LWE/IR:Dialect", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/lib/Transforms/StraightLineVectorizer/BUILD b/lib/Transforms/StraightLineVectorizer/BUILD new file mode 100644 index 000000000..78ad1ead1 --- /dev/null +++ b/lib/Transforms/StraightLineVectorizer/BUILD @@ -0,0 +1,24 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "StraightLineVectorizer", + srcs = ["StraightLineVectorizer.cpp"], + hdrs = [ + "@heir//include/Transforms/StraightLineVectorizer:StraightLineVectorizer.h", + ], + deps = [ + "@heir//include/Graph", + "@heir//include/Transforms/StraightLineVectorizer:pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.cpp b/lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.cpp new file mode 100644 index 000000000..bfce9e08a --- /dev/null +++ b/lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.cpp @@ -0,0 +1,164 @@ +#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h" + +#include "include/Graph/Graph.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/TopologicalSortUtils.h" // from @llvm-project + +#define DEBUG_TYPE "straight-line-vectorizer" + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_STRAIGHTLINEVECTORIZER +#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc" + +/// Returns true if the two operations can be combined into a single vectorized +/// operation. +bool areCompatible(Operation *lhs, Operation *rhs) { + if (lhs->getName() != rhs->getName() || + lhs->getDialect() != rhs->getDialect() || + lhs->getResultTypes() != rhs->getResultTypes()) { + return false; + } + return OpTrait::hasElementwiseMappableTraits(lhs); +} + +bool tryVectorizeBlock(Block *block) { + graph::Graph graph; + for (auto &op : block->getOperations()) { + if (!op.hasTrait()) { + continue; + } + + graph.addVertex(&op); + SetVector backwardSlice; + BackwardSliceOptions options; + options.omitBlockArguments = true; + + getBackwardSlice(&op, &backwardSlice, options); + for (auto *upstreamDep : backwardSlice) { + // An edge from upstreamDep to `op` means that upstreamDep must be + // computed before `op`. + graph.addEdge(upstreamDep, &op); + } + } + + if (graph.empty()) { + return false; + } + + auto result = graph.sortGraphByLevels(); + assert(succeeded(result) && + "Only possible failure is a cycle in the SSA graph!"); + auto levels = result.value(); + + LLVM_DEBUG({ + llvm::dbgs() + << "Found operations to vectorize. In topo-sorted level order:\n"; + int level_num = 0; + for (const auto &level : levels) { + llvm::dbgs() << "\nLevel " << level_num++ << ":\n"; + for (auto op : level) { + llvm::dbgs() << " - " << *op << "\n"; + } + } + }); + + bool madeReplacement = false; + for (const auto &level : levels) { + DenseMap> compatibleOps; + for (auto *op : level) { + bool foundCompatible = false; + for (auto &[key, bucket] : compatibleOps) { + if (areCompatible(key, op)) { + compatibleOps[key].push_back(op); + foundCompatible = true; + } + } + if (!foundCompatible) { + compatibleOps[op].push_back(op); + } + } + LLVM_DEBUG(llvm::dbgs() + << "Partitioned level of size " << level.size() << " into " + << compatibleOps.size() << " groups of compatible ops\n"); + + for (auto &[key, bucket] : compatibleOps) { + if (bucket.size() < 2) { + continue; + } + + LLVM_DEBUG({ + llvm::dbgs() << "Vectorizing ops:\n"; + for (auto op : bucket) { + llvm::dbgs() << " - " << *op << "\n"; + } + }); + + OpBuilder builder(bucket.back()); + // relies on CGGI ops having a single result type + Type elementType = key->getResultTypes()[0]; + RankedTensorType tensorType = RankedTensorType::get( + {static_cast(bucket.size())}, elementType); + + SmallVector vectorizedOperands; + for (int operandIndex = 0; operandIndex < key->getNumOperands(); + ++operandIndex) { + SmallVector operands; + operands.reserve(bucket.size()); + for (auto *op : bucket) { + operands.push_back(op->getOperand(operandIndex)); + } + auto fromElementsOp = builder.create( + key->getLoc(), tensorType, operands); + vectorizedOperands.push_back(fromElementsOp.getResult()); + } + + Operation *vectorizedOp = builder.clone(*key); + vectorizedOp->setOperands(vectorizedOperands); + vectorizedOp->getResult(0).setType(tensorType); + + int bucketIndex = 0; + for (auto *op : bucket) { + auto extractionIndex = builder.create( + op->getLoc(), builder.getIndexAttr(bucketIndex)); + auto extractOp = builder.create( + op->getLoc(), elementType, vectorizedOp->getResult(0), + extractionIndex.getResult()); + op->replaceAllUsesWith(ValueRange{extractOp.getResult()}); + bucketIndex++; + } + + for (auto *op : bucket) { + op->erase(); + } + madeReplacement = true; + } + } + + return madeReplacement; +} + +struct StraightLineVectorizer + : impl::StraightLineVectorizerBase { + using StraightLineVectorizerBase::StraightLineVectorizerBase; + + void runOnOperation() override { + getOperation()->walk([&](Block *block) { + if (tryVectorizeBlock(block)) { + sortTopologically(block); + } + }); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tests/cggi/straight_line_vectorizer.mlir b/tests/cggi/straight_line_vectorizer.mlir index 9d93f7cae..c5609c806 100644 --- a/tests/cggi/straight_line_vectorizer.mlir +++ b/tests/cggi/straight_line_vectorizer.mlir @@ -1,11 +1,11 @@ -// RUN: heir-opt --cggi-straight-line-vectorizer %s | FileCheck %s +// RUN: heir-opt --straight-line-vectorize %s | FileCheck %s #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext // CHECK-LABEL: add_one -// CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 105 : ui8} : tensor<6x!lwe.lwe_ciphertext +// CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 105 : ui8} : tensor<8x!lwe.lwe_ciphertext func.func @add_one(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { %true = arith.constant true %false = arith.constant false diff --git a/tools/BUILD b/tools/BUILD index 0c9bc44a6..5cb84a371 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -57,6 +57,7 @@ cc_binary( "@heir//lib/Transforms/ForwardStoreToLoad", "@heir//lib/Transforms/FullLoopUnroll", "@heir//lib/Transforms/Secretize", + "@heir//lib/Transforms/StraightLineVectorizer", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 8eb0ed41e..e7c25029e 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -24,6 +24,7 @@ #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" #include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "include/Transforms/Secretize/Passes.h" +#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h" #include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project @@ -291,6 +292,7 @@ int main(int argc, char **argv) { registerSecretizePasses(); registerFullLoopUnrollPasses(); registerForwardStoreToLoadPasses(); + registerStraightLineVectorizerPasses(); // Register yosys optimizer pipeline if configured. #ifndef HEIR_NO_YOSYS const char *abcEnvPath = std::getenv("HEIR_ABC_BINARY");