Skip to content

Commit

Permalink
tests update
Browse files Browse the repository at this point in the history
  • Loading branch information
Wouter Legiest committed Mar 19, 2024
2 parents 63410a8 + 3c818ba commit 7a8759a
Show file tree
Hide file tree
Showing 54 changed files with 2,175 additions and 101 deletions.
2 changes: 1 addition & 1 deletion bazel/import_llvm.bzl
Expand Up @@ -7,7 +7,7 @@ load(

def import_llvm(name):
"""Imports LLVM."""
LLVM_COMMIT = "a83f8e0314fcdda162e54cbba1c9dcf230dff093"
LLVM_COMMIT = "a4ca07f13b560b4f6fa5459eef7159e4f9ee9a6b"

new_git_repository(
name = name,
Expand Down
9 changes: 9 additions & 0 deletions include/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,9 @@
# TargetSlotAnalysis analysis pass.
package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

# Expose the analysis header so dependent targets can include it directly.
exports_files(
    ["TargetSlotAnalysis.h"],
)
134 changes: 134 additions & 0 deletions include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
@@ -0,0 +1,134 @@
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_

#include <cassert>
#include <cstdint>
#include <optional>

#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h"  // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h"    // from @llvm-project
#include "mlir/include/mlir/IR/Value.h"        // from @llvm-project

namespace mlir {
namespace heir {
namespace target_slot_analysis {

/// A target slot is an identification of a downstream tensor index at which an
/// SSA value will be used. To make the previous sentence even mildly
/// comprehensible, consider it in the following example.
///
/// %c3 = arith.constant 3 : index
/// %c4 = arith.constant 4 : index
/// %c11 = arith.constant 11 : index
/// %c15 = arith.constant 15 : index
/// %v11 = tensor.extract %arg1[%c11] : tensor<16xi32>
/// %v15 = tensor.extract %arg1[%c15] : tensor<16xi32>
/// %1 = arith.addi %v11, %v15: i32
/// %v3 = tensor.extract %arg1[%c3] : tensor<16xi32>
/// %2 = arith.addi %v3, %1 : i32
/// %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32>
///
/// In vectorized FHE schemes like BGV, the computation model does not
/// efficiently support extracting values at particular indices; instead, it
/// supports SIMD additions of entire vectors, and cyclic rotations of vectors
/// by constant shifts. To optimize the above computation, we want to convert
/// the extractions to rotations, and minimize rotations as much as possible.
///
/// A naive conversion would convert each tensor.extract %arg1[Z] to
/// arith.rotate %arg1, Z, always placing the needed values in the zero-th
/// slot. However, the last
/// line above indicates that the downstream dependencies of these computations
/// are ultimately needed in slot 4 of the %output tensor. So one could reduce
/// the number of rotations by rotating instead to slot 4, so that the final
/// rotation is not needed.
///
/// This analysis identifies that downstream insertion index, and propagates it
/// backward through the IR to attach it to each SSA value, enabling later
/// optimization passes to access it easily.
///
/// As it turns out, if the IR is well-structured, such as an unrolled affine
/// for loop with simple iteration strides, then aligning to target slots in
/// this way leads to many common sub-expressions that can be eliminated. Cf.
/// the insert-rotate pass for more on that.

class TargetSlot {
public:
TargetSlot() : value(std::nullopt) {}
TargetSlot(int64_t value) : value(value) {}
~TargetSlot() = default;

/// Whether the slot target is initialized. It can be uninitialized when the
/// state hasn't been set during the analysis.
bool isInitialized() const { return value.has_value(); }

/// Get a known slot target.
const int64_t &getValue() const {
assert(isInitialized());
return *value;
}

bool operator==(const TargetSlot &rhs) const { return value == rhs.value; }

/// Join two target slots.
static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) {
if (!lhs.isInitialized()) return rhs;
if (!rhs.isInitialized()) return lhs;
// If they are both initialized, use an arbitrary deterministic rule to
// select one. A more sophisticated analysis could try to determine which
// slot is more likely to lead to beneficial optimizations.
return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue()
: rhs.getValue()};
}

void print(raw_ostream &os) const { os << value; }

private:
/// The target slot, if known.
std::optional<int64_t> value;

friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic,
const TargetSlot &foo) {
if (foo.isInitialized()) {
return diagnostic << foo.getValue();
}
return diagnostic << "uninitialized";
}
};

/// Stream a TargetSlot to a raw_ostream by delegating to its print method.
inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &slot) {
  slot.print(os);
  return os;
}

/// The dataflow lattice attached to each SSA value by TargetSlotAnalysis;
/// wraps a TargetSlot and inherits all lattice machinery from the base.
class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
 public:
  using Lattice::Lattice;
};

/// An analysis that identifies a target slot for an SSA value in a program.
/// This is used by downstream passes to determine how to align rotations in
/// vectorized FHE schemes.
///
/// We use a backward dataflow analysis because the target slot propagates
/// backward from its final use to the arithmetic operations at which rotations
/// can be optimized.
class TargetSlotAnalysis
    : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
 public:
  // Inherit the base constructor (DataFlowSolver &, SymbolTableCollection &);
  // a hand-written forwarding constructor would only duplicate it.
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
  ~TargetSlotAnalysis() override = default;

  // Given the computed results of the operation, update its operand lattice
  // values.
  void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
                      ArrayRef<const TargetSlotLattice *> results) override;

  // Branch operands, call operands, and exit states carry no target-slot
  // information in this analysis, so these hooks are deliberate no-ops.
  void visitBranchOperand(OpOperand &operand) override {}
  void visitCallOperand(OpOperand &operand) override {}
  void setToExitState(TargetSlotLattice *lattice) override {}
};

} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
12 changes: 12 additions & 0 deletions include/Dialect/Comb/IR/Combinational.td
Expand Up @@ -80,6 +80,9 @@ def XorOp : UTVariadicOp<"xor", [Commutative]> {
bool isBinaryNot();
}];
}
// Additional variadic gate definitions sharing the UTVariadicOp base class.
def XNorOp : UTVariadicOp<"xnor">;
def NandOp : UTVariadicOp<"nand">;
def NorOp : UTVariadicOp<"nor">;

//===----------------------------------------------------------------------===//
// Comparisons
Expand Down Expand Up @@ -154,6 +157,15 @@ def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> {
// Unary Operations
//===----------------------------------------------------------------------===//

// Base class for unary combinational operations: one integer input and a
// result of the same type, plus a UnitAttr $twoState that prints as `bin`
// in the assembly format.
class UnaryOp<string mnemonic, list<Trait> traits = []> :
    CombOp<mnemonic, traits # [Pure, SameOperandsAndResultType]> {
  let arguments = (ins HWIntegerType:$input, UnitAttr:$twoState);
  let results = (outs HWIntegerType:$result);

  let assemblyFormat = "(`bin` $twoState^)? $input attr-dict `:` qualified(type($input))";
}

// Bitwise inversion of a single operand.
def InvOp : UnaryOp<"inv">;

// Base class for unary reduction operations that produce an i1.
class UnaryI1ReductionOp<string mnemonic, list<Trait> traits = []> :
CombOp<mnemonic, traits # [Pure]> {
Expand Down
2 changes: 1 addition & 1 deletion include/Dialect/Polynomial/IR/PolynomialTypes.td
Expand Up @@ -14,7 +14,7 @@ class Polynomial_Type<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}

def Polynomial : Polynomial_Type<"Polynomial", "polynomial", [MemRefElementTypeInterface]> {
def Polynomial : Polynomial_Type<"Polynomial", "polynomial"> {
let summary = "An element of a polynomial quotient ring";

let description = [{
Expand Down
5 changes: 3 additions & 2 deletions include/Dialect/TensorExt/Transforms/BUILD
@@ -1,4 +1,4 @@
# InsertRotate tablegen and headers.
# TensorExt pass tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

Expand Down Expand Up @@ -50,7 +50,8 @@ gentbl_cc_library(
)

exports_files([
"Passes.h",
"CollapseInsertionChains.h",
"InsertRotate.h",
"Passes.h",
"RotateAndReduce.h",
])
5 changes: 4 additions & 1 deletion include/Dialect/TensorExt/Transforms/InsertRotate.h
@@ -1,7 +1,10 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/TensorExt/Transforms/Passes.h
Expand Up @@ -4,6 +4,7 @@
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h"
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h"
#include "include/Dialect/TensorExt/Transforms/RotateAndReduce.h"

namespace mlir {
namespace heir {
Expand Down
43 changes: 43 additions & 0 deletions include/Dialect/TensorExt/Transforms/Passes.td
Expand Up @@ -60,4 +60,47 @@ def CollapseInsertionChains : Pass<"collapse-insertion-chains"> {
let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"];
}

// Replaces an unrolled linear reduction of a tensor's entries with a
// logarithmic ladder of rotate-and-combine operations.
def RotateAndReduce : Pass<"rotate-and-reduce"> {
  let summary = "Use a logarithmic number of rotations to reduce a tensor.";
  let description = [{
  This pass identifies when a commutative, associative binary operation is used
  to reduce all of the entries of a tensor to a single value, and optimizes the
  operations by using a logarithmic number of reduction operations.

  In particular, this pass identifies an unrolled set of operations of the form
  (the binary ops may come in any order):

  ```mlir
  %0 = tensor.extract %t[0] : tensor<8xi32>
  %1 = tensor.extract %t[1] : tensor<8xi32>
  %2 = tensor.extract %t[2] : tensor<8xi32>
  %3 = tensor.extract %t[3] : tensor<8xi32>
  %4 = tensor.extract %t[4] : tensor<8xi32>
  %5 = tensor.extract %t[5] : tensor<8xi32>
  %6 = tensor.extract %t[6] : tensor<8xi32>
  %7 = tensor.extract %t[7] : tensor<8xi32>
  %8 = arith.addi %0, %1 : i32
  %9 = arith.addi %8, %2 : i32
  %10 = arith.addi %9, %3 : i32
  %11 = arith.addi %10, %4 : i32
  %12 = arith.addi %11, %5 : i32
  %13 = arith.addi %12, %6 : i32
  %14 = arith.addi %13, %7 : i32
  ```

  and replaces it with a logarithmic number of `rotate` and `addi` operations:

  ```mlir
  %0 = tensor_ext.rotate %t, 4 : tensor<8xi32>
  %1 = arith.addi %t, %0 : tensor<8xi32>
  %2 = tensor_ext.rotate %1, 2 : tensor<8xi32>
  %3 = arith.addi %1, %2 : tensor<8xi32>
  %4 = tensor_ext.rotate %3, 1 : tensor<8xi32>
  %5 = arith.addi %3, %4 : tensor<8xi32>
  ```
  }];
  // The rewrite emits tensor_ext.rotate ops, so that dialect must be loaded.
  let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"];
}


#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/Transforms/RotateAndReduce.h
@@ -0,0 +1,17 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_

#include "mlir/include/mlir/Pass/Pass.h"  // from @llvm-project

namespace mlir {
namespace heir {
namespace tensor_ext {

// Pulls in the tablegen-generated declaration for the rotate-and-reduce pass.
#define GEN_PASS_DECL_ROTATEANDREDUCE
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

}  // namespace tensor_ext
}  // namespace heir
}  // namespace mlir

#endif  // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_
35 changes: 35 additions & 0 deletions include/Transforms/ElementwiseToAffine/BUILD
@@ -0,0 +1,35 @@
# ElementwiseToAffine tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

# Expose the pass header so dependent targets can include it directly.
exports_files([
    "ElementwiseToAffine.h",
])

# Generates the pass declarations and markdown documentation from the
# tablegen pass definition.
gentbl_cc_library(
    name = "pass_inc_gen",
    tbl_outs = [
        (
            [
                "-gen-pass-decls",
                "-name=ElementwiseToAffine",
            ],
            "ElementwiseToAffine.h.inc",
        ),
        (
            ["-gen-pass-doc"],
            "ElementwiseToAffinePasses.md",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "ElementwiseToAffine.td",
    deps = [
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)
18 changes: 18 additions & 0 deletions include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h
@@ -0,0 +1,18 @@
#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_
#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_

#include "mlir/include/mlir/Pass/Pass.h"  // from @llvm-project

namespace mlir {
namespace heir {

// Tablegen-generated pass declarations.
#define GEN_PASS_DECL
#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc"

// Tablegen-generated pass registration hooks.
#define GEN_PASS_REGISTRATION
#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc"

}  // namespace heir
}  // namespace mlir

#endif  // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_
18 changes: 18 additions & 0 deletions include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td
@@ -0,0 +1,18 @@
#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_
#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_

include "mlir/Pass/PassBase.td"

// Lowers ElementwiseMappable tensor ops to affine loop nests over scalars.
def ElementwiseToAffine : Pass<"convert-elementwise-to-affine"> {
  let summary = "This pass lowers ElementwiseMappable operations to Affine loops.";
  let description = [{
    This pass lowers ElementwiseMappable operations over tensors
    to affine loop nests that instead apply the operation to the underlying scalar values.
  }];
  // The lowering creates affine loops and tensor extract/insert ops.
  let dependentDialects = [
    "mlir::affine::AffineDialect",
    "mlir::tensor::TensorDialect"
  ];
}

#endif  // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_

0 comments on commit 7a8759a

Please sign in to comment.