Skip to content

Commit

Permalink
Generalize cggi-straight-line-vectorizer
Browse files Browse the repository at this point in the history
Supports any elementwise mappable ops, so we can apply it to
the BGV path as well.
  • Loading branch information
j2kun committed Mar 2, 2024
1 parent 7b6751c commit 6eb3b9e
Show file tree
Hide file tree
Showing 14 changed files with 263 additions and 55 deletions.
4 changes: 2 additions & 2 deletions include/Dialect/CGGI/IR/CGGIOps.td
Expand Up @@ -28,7 +28,7 @@ class CGGI_BinaryGateOp<string mnemonic>
Pure,
Commutative,
SameOperandsAndResultType,
Elementwise,
ElementwiseMappable,
Scalarizable
]> {
let arguments = (ins LWECiphertextLike:$lhs, LWECiphertextLike:$rhs);
Expand All @@ -53,7 +53,7 @@ class CGGI_LutOp<string mnemonic, list<Trait> traits = []>
: CGGI_Op<mnemonic, traits # [
Pure,
Commutative,
Elementwise,
ElementwiseMappable,
Scalarizable,
DeclareOpInterfaceMethods<LUTOpInterface>

Expand Down
1 change: 0 additions & 1 deletion include/Dialect/CGGI/Transforms/BUILD
Expand Up @@ -33,5 +33,4 @@ gentbl_cc_library(
exports_files([
"Passes.h",
"SetDefaultParameters.h",
"StraightLineVectorizer.h",
])
1 change: 0 additions & 1 deletion include/Dialect/CGGI/Transforms/Passes.h
Expand Up @@ -3,7 +3,6 @@

#include "include/Dialect/CGGI/IR/CGGIDialect.h"
#include "include/Dialect/CGGI/Transforms/SetDefaultParameters.h"
#include "include/Dialect/CGGI/Transforms/StraightLineVectorizer.h"

namespace mlir {
namespace heir {
Expand Down
9 changes: 0 additions & 9 deletions include/Dialect/CGGI/Transforms/Passes.td
Expand Up @@ -19,13 +19,4 @@ def SetDefaultParameters : Pass<"cggi-set-default-parameters"> {
let dependentDialects = ["mlir::heir::cggi::CGGIDialect"];
}

def StraightLineVectorizer : Pass<"cggi-straight-line-vectorizer"> {
let summary = "A straight-line vectorizer for CGGI bootstrapping ops.";
let description = [{
This pass vectorizes CGGI ops. It ignores control flow and only vectorizes
straight-line programs within a given region.
}];
let dependentDialects = ["mlir::heir::cggi::CGGIDialect"];
}

#endif // INCLUDE_DIALECT_CGGI_TRANSFORMS_PASSES_TD_
17 changes: 0 additions & 17 deletions include/Dialect/CGGI/Transforms/StraightLineVectorizer.h

This file was deleted.

35 changes: 35 additions & 0 deletions include/Transforms/StraightLineVectorizer/BUILD
@@ -0,0 +1,35 @@
# StraightLineVectorizer tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files([
"StraightLineVectorizer.h",
])

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=StraightLineVectorizer",
],
"StraightLineVectorizer.h.inc",
),
(
["-gen-pass-doc"],
"StraightLineVectorizerPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "StraightLineVectorizer.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
18 changes: 18 additions & 0 deletions include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h
@@ -0,0 +1,18 @@
#ifndef INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_
#define INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc"

#define GEN_PASS_REGISTRATION
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc"

} // namespace heir
} // namespace mlir

#endif // INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_
@@ -0,0 +1,15 @@
#ifndef INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_
#define INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_

include "mlir/Pass/PassBase.td"

def StraightLineVectorizer : Pass<"straight-line-vectorize"> {
let summary = "A vectorizer for straight line programs.";
let description = [{
This pass ignores control flow and only vectorizes straight-line programs
within a given region.
}];
let dependentDialects = [];
}

#endif // INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_
23 changes: 0 additions & 23 deletions lib/Dialect/CGGI/Transforms/BUILD
Expand Up @@ -10,7 +10,6 @@ cc_library(
],
deps = [
":SetDefaultParameters",
":StraightLineVectorizer",
"@heir//include/Dialect/CGGI/Transforms:pass_inc_gen",
"@heir//lib/Dialect/CGGI/IR:Dialect",
"@llvm-project//mlir:IR",
Expand All @@ -32,25 +31,3 @@ cc_library(
"@llvm-project//mlir:Pass",
],
)

cc_library(
name = "StraightLineVectorizer",
srcs = ["StraightLineVectorizer.cpp"],
hdrs = [
"@heir//include/Dialect/CGGI/Transforms:StraightLineVectorizer.h",
],
deps = [
"@heir//include/Dialect/CGGI/Transforms:pass_inc_gen",
"@heir//include/Graph",
"@heir//lib/Dialect/CGGI/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
24 changes: 24 additions & 0 deletions lib/Transforms/StraightLineVectorizer/BUILD
@@ -0,0 +1,24 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "StraightLineVectorizer",
srcs = ["StraightLineVectorizer.cpp"],
hdrs = [
"@heir//include/Transforms/StraightLineVectorizer:StraightLineVectorizer.h",
],
deps = [
"@heir//include/Graph",
"@heir//include/Transforms/StraightLineVectorizer:pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
164 changes: 164 additions & 0 deletions lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.cpp
@@ -0,0 +1,164 @@
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h"

#include "include/Graph/Graph.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/TopologicalSortUtils.h" // from @llvm-project

#define DEBUG_TYPE "straight-line-vectorizer"

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_STRAIGHTLINEVECTORIZER
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc"

/// Returns true if the two operations can be combined into a single vectorized
/// operation.
bool areCompatible(Operation *lhs, Operation *rhs) {
if (lhs->getName() != rhs->getName() ||
lhs->getDialect() != rhs->getDialect() ||
lhs->getResultTypes() != rhs->getResultTypes()) {
return false;
}
return OpTrait::hasElementwiseMappableTraits(lhs);
}

bool tryVectorizeBlock(Block *block) {
graph::Graph<Operation *> graph;
for (auto &op : block->getOperations()) {
if (!op.hasTrait<OpTrait::Elementwise>()) {
continue;
}

graph.addVertex(&op);
SetVector<Operation *> backwardSlice;
BackwardSliceOptions options;
options.omitBlockArguments = true;

getBackwardSlice(&op, &backwardSlice, options);
for (auto *upstreamDep : backwardSlice) {
// An edge from upstreamDep to `op` means that upstreamDep must be
// computed before `op`.
graph.addEdge(upstreamDep, &op);
}
}

if (graph.empty()) {
return false;
}

auto result = graph.sortGraphByLevels();
assert(succeeded(result) &&
"Only possible failure is a cycle in the SSA graph!");
auto levels = result.value();

LLVM_DEBUG({
llvm::dbgs()
<< "Found operations to vectorize. In topo-sorted level order:\n";
int level_num = 0;
for (const auto &level : levels) {
llvm::dbgs() << "\nLevel " << level_num++ << ":\n";
for (auto op : level) {
llvm::dbgs() << " - " << *op << "\n";
}
}
});

bool madeReplacement = false;
for (const auto &level : levels) {
DenseMap<Operation *, SmallVector<Operation *, 4>> compatibleOps;
for (auto *op : level) {
bool foundCompatible = false;
for (auto &[key, bucket] : compatibleOps) {
if (areCompatible(key, op)) {
compatibleOps[key].push_back(op);
foundCompatible = true;
}
}
if (!foundCompatible) {
compatibleOps[op].push_back(op);
}
}
LLVM_DEBUG(llvm::dbgs()
<< "Partitioned level of size " << level.size() << " into "
<< compatibleOps.size() << " groups of compatible ops\n");

for (auto &[key, bucket] : compatibleOps) {
if (bucket.size() < 2) {
continue;
}

LLVM_DEBUG({
llvm::dbgs() << "Vectorizing ops:\n";
for (auto op : bucket) {
llvm::dbgs() << " - " << *op << "\n";
}
});

OpBuilder builder(bucket.back());
// relies on CGGI ops having a single result type
Type elementType = key->getResultTypes()[0];
RankedTensorType tensorType = RankedTensorType::get(
{static_cast<int64_t>(bucket.size())}, elementType);

SmallVector<Value, 4> vectorizedOperands;
for (int operandIndex = 0; operandIndex < key->getNumOperands();
++operandIndex) {
SmallVector<Value, 4> operands;
operands.reserve(bucket.size());
for (auto *op : bucket) {
operands.push_back(op->getOperand(operandIndex));
}
auto fromElementsOp = builder.create<tensor::FromElementsOp>(
key->getLoc(), tensorType, operands);
vectorizedOperands.push_back(fromElementsOp.getResult());
}

Operation *vectorizedOp = builder.clone(*key);
vectorizedOp->setOperands(vectorizedOperands);
vectorizedOp->getResult(0).setType(tensorType);

int bucketIndex = 0;
for (auto *op : bucket) {
auto extractionIndex = builder.create<arith::ConstantOp>(
op->getLoc(), builder.getIndexAttr(bucketIndex));
auto extractOp = builder.create<tensor::ExtractOp>(
op->getLoc(), elementType, vectorizedOp->getResult(0),
extractionIndex.getResult());
op->replaceAllUsesWith(ValueRange{extractOp.getResult()});
bucketIndex++;
}

for (auto *op : bucket) {
op->erase();
}
madeReplacement = true;
}
}

return madeReplacement;
}

struct StraightLineVectorizer
: impl::StraightLineVectorizerBase<StraightLineVectorizer> {
using StraightLineVectorizerBase::StraightLineVectorizerBase;

void runOnOperation() override {
getOperation()->walk<WalkOrder::PreOrder>([&](Block *block) {
if (tryVectorizeBlock(block)) {
sortTopologically(block);
}
});
}
};

} // namespace heir
} // namespace mlir
4 changes: 2 additions & 2 deletions tests/cggi/straight_line_vectorizer.mlir
@@ -1,11 +1,11 @@
// RUN: heir-opt --cggi-straight-line-vectorizer %s | FileCheck %s
// RUN: heir-opt --straight-line-vectorize %s | FileCheck %s

#encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 3>
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>

// CHECK-LABEL: add_one
// CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 105 : ui8} : tensor<6x!lwe.lwe_ciphertext
// CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 105 : ui8} : tensor<8x!lwe.lwe_ciphertext
func.func @add_one(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
%true = arith.constant true
%false = arith.constant false
Expand Down
1 change: 1 addition & 0 deletions tools/BUILD
Expand Up @@ -57,6 +57,7 @@ cc_binary(
"@heir//lib/Transforms/ForwardStoreToLoad",
"@heir//lib/Transforms/FullLoopUnroll",
"@heir//lib/Transforms/Secretize",
"@heir//lib/Transforms/StraightLineVectorizer",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineToStandard",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Expand Up @@ -24,6 +24,7 @@
#include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h"
#include "include/Transforms/Secretize/Passes.h"
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h"
#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
Expand Down Expand Up @@ -291,6 +292,7 @@ int main(int argc, char **argv) {
registerSecretizePasses();
registerFullLoopUnrollPasses();
registerForwardStoreToLoadPasses();
registerStraightLineVectorizerPasses();
// Register yosys optimizer pipeline if configured.
#ifndef HEIR_NO_YOSYS
const char *abcEnvPath = std::getenv("HEIR_ABC_BINARY");
Expand Down

0 comments on commit 6eb3b9e

Please sign in to comment.