diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl
index 6fe18055c..9503b1690 100644
--- a/bazel/import_llvm.bzl
+++ b/bazel/import_llvm.bzl
@@ -7,7 +7,7 @@ load(
 
 def import_llvm(name):
     """Imports LLVM."""
-    LLVM_COMMIT = "a83f8e0314fcdda162e54cbba1c9dcf230dff093"
+    LLVM_COMMIT = "a4ca07f13b560b4f6fa5459eef7159e4f9ee9a6b"
 
     new_git_repository(
         name = name,
diff --git a/include/Analysis/TargetSlotAnalysis/BUILD b/include/Analysis/TargetSlotAnalysis/BUILD
new file mode 100644
index 000000000..085e170b2
--- /dev/null
+++ b/include/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,9 @@
+# TargetSlotAnalysis analysis pass
+package(
+    default_applicable_licenses = ["@heir//:license"],
+    default_visibility = ["//visibility:public"],
+)
+
+exports_files(
+    ["TargetSlotAnalysis.h"],
+)
diff --git a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
new file mode 100644
index 000000000..984fb338c
--- /dev/null
+++ b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
@@ -0,0 +1,134 @@
+#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
+#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
+
+#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+
+namespace mlir {
+namespace heir {
+namespace target_slot_analysis {
+
+/// A target slot is an identification of a downstream tensor index at which an
+/// SSA value will be used. To make the previous sentence even mildly
+/// comprehensible, consider it in the following example.
+///
+///   %c3 = arith.constant 3 : index
+///   %c4 = arith.constant 4 : index
+///   %c11 = arith.constant 11 : index
+///   %c15 = arith.constant 15 : index
+///   %v11 = tensor.extract %arg1[%c11] : tensor<16xi32>
+///   %v15 = tensor.extract %arg1[%c15] : tensor<16xi32>
+///   %1 = arith.addi %v11, %v15 : i32
+///   %v3 = tensor.extract %arg1[%c3] : tensor<16xi32>
+///   %2 = arith.addi %v3, %1 : i32
+///   %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32>
+///
+/// In vectorized FHE schemes like BGV, the computation model does not
+/// efficiently support extracting values at particular indices; instead, it
+/// supports SIMD additions of entire vectors, and cyclic rotations of vectors
+/// by constant shifts. To optimize the above computation, we want to convert
+/// the extractions to rotations, and minimize rotations as much as possible.
+///
+/// A naive conversion converts tensor.extract %arg1[Z] to tensor_ext.rotate
+/// %arg1, Z, always placing the needed values in the zero-th slot. However,
+/// the last line above indicates that the downstream dependencies of these
+/// computations are ultimately needed in slot 4 of the %output tensor. So one
+/// could reduce the number of rotations by rotating instead to slot 4, so
+/// that the final rotation is not needed.
+///
+/// This analysis identifies that downstream insertion index, and propagates it
+/// backward through the IR to attach it to each SSA value, enabling later
+/// optimization passes to access it easily.
+///
+/// As it turns out, if the IR is well-structured, such as an unrolled affine
+/// for loop with simple iteration strides, then aligning to target slots in
+/// this way leads to many common sub-expressions that can be eliminated. Cf.
+/// the insert-rotate pass for more on that.
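+///
+/// In the example above, the target slot is 4: the analysis attaches slot 4
+/// to %2 at the tensor.insert, then propagates it backward to %1, %v3, %v11,
+/// and %v15.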
+
+class TargetSlot {
+ public:
+  TargetSlot() : value(std::nullopt) {}
+  TargetSlot(int64_t value) : value(value) {}
+  ~TargetSlot() = default;
+
+  /// Whether the slot target is initialized. It can be uninitialized when the
+  /// state hasn't been set during the analysis.
+  bool isInitialized() const { return value.has_value(); }
+
+  /// Get a known slot target.
+  const int64_t &getValue() const {
+    assert(isInitialized());
+    return *value;
+  }
+
+  bool operator==(const TargetSlot &rhs) const { return value == rhs.value; }
+
+  /// Join two target slots.
+  static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) {
+    if (!lhs.isInitialized()) return rhs;
+    if (!rhs.isInitialized()) return lhs;
+    // If they are both initialized, use an arbitrary deterministic rule to
+    // select one. A more sophisticated analysis could try to determine which
+    // slot is more likely to lead to beneficial optimizations.
+    return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue()
+                                                      : rhs.getValue()};
+  }
+
+  void print(raw_ostream &os) const { os << value; }
+
+ private:
+  /// The target slot, if known.
+  std::optional<int64_t> value;
+
+  friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic,
+                                      const TargetSlot &foo) {
+    if (foo.isInitialized()) {
+      return diagnostic << foo.getValue();
+    }
+    return diagnostic << "uninitialized";
+  }
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) {
+  v.print(os);
+  return os;
+}
+
+class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
+ public:
+  using Lattice::Lattice;
+};
+
+/// An analysis that identifies a target slot for an SSA value in a program.
+/// This is used by downstream passes to determine how to align rotations in
+/// vectorized FHE schemes.
+///
+/// We use a backward dataflow analysis because the target slot propagates
+/// backward from its final use to the arithmetic operations at which rotations
+/// can be optimized.
+class TargetSlotAnalysis
+    : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
+ public:
+  explicit TargetSlotAnalysis(DataFlowSolver &solver,
+                              SymbolTableCollection &symbolTable)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+  ~TargetSlotAnalysis() override = default;
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  // Given the computed results of the operation, update its operand lattice
+  // values.
+  void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
+                      ArrayRef<const TargetSlotLattice *> results) override;
+
+  void visitBranchOperand(OpOperand &operand) override{};
+  void visitCallOperand(OpOperand &operand) override{};
+  void setToExitState(TargetSlotLattice *lattice) override{};
+};
+
+}  // namespace target_slot_analysis
+}  // namespace heir
+}  // namespace mlir
+
+#endif  // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
diff --git a/include/Dialect/Comb/IR/Combinational.td b/include/Dialect/Comb/IR/Combinational.td
index faa8bd2a0..639da10a0 100644
--- a/include/Dialect/Comb/IR/Combinational.td
+++ b/include/Dialect/Comb/IR/Combinational.td
@@ -80,6 +80,9 @@ def XorOp : UTVariadicOp<"xor", [Commutative]> {
     bool isBinaryNot();
   }];
 }
+def XNorOp : UTVariadicOp<"xnor">;
+def NandOp : UTVariadicOp<"nand">;
+def NorOp : UTVariadicOp<"nor">;
 
 //===----------------------------------------------------------------------===//
 // Comparisons
@@ -154,6 +157,15 @@ def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> {
 // Unary Operations
 //===----------------------------------------------------------------------===//
 
+class UnaryOp<string mnemonic, list<Trait> traits = []> :
+    CombOp<mnemonic, traits # [Pure]> {
+  let arguments = (ins HWIntegerType:$input, UnitAttr:$twoState);
+  let results = (outs HWIntegerType:$result);
+
+  let assemblyFormat = "(`bin` $twoState^)? $input attr-dict `:` qualified(type($input))";
+}
+def InvOp : UnaryOp<"inv">;
+
 // Base class for unary reduction operations that produce an i1.
 class UnaryI1ReductionOp<string mnemonic, list<Trait> traits = []> :
     CombOp<mnemonic, traits # [Pure]> {
diff --git a/include/Dialect/Polynomial/IR/PolynomialTypes.td b/include/Dialect/Polynomial/IR/PolynomialTypes.td
index 2637131da..14d282464 100644
--- a/include/Dialect/Polynomial/IR/PolynomialTypes.td
+++ b/include/Dialect/Polynomial/IR/PolynomialTypes.td
@@ -14,7 +14,7 @@ class Polynomial_Type<string name, string typeMnemonic, list<Trait> traits = []>
   let mnemonic = typeMnemonic;
 }
 
-def Polynomial : Polynomial_Type<"Polynomial", "polynomial", [MemRefElementTypeInterface]> {
+def Polynomial : Polynomial_Type<"Polynomial", "polynomial"> {
   let summary = "An element of a polynomial quotient ring";
 
   let description = [{
diff --git a/include/Dialect/TensorExt/Transforms/BUILD b/include/Dialect/TensorExt/Transforms/BUILD
index 814553ce4..5ea104990 100644
--- a/include/Dialect/TensorExt/Transforms/BUILD
+++ b/include/Dialect/TensorExt/Transforms/BUILD
@@ -1,4 +1,4 @@
-# InsertRotate tablegen and headers.
+# TensorExt pass tablegen and headers.
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") @@ -50,7 +50,8 @@ gentbl_cc_library( ) exports_files([ - "Passes.h", "CollapseInsertionChains.h", "InsertRotate.h", + "Passes.h", + "RotateAndReduce.h", ]) diff --git a/include/Dialect/TensorExt/Transforms/InsertRotate.h b/include/Dialect/TensorExt/Transforms/InsertRotate.h index b07d3af46..38f037a11 100644 --- a/include/Dialect/TensorExt/Transforms/InsertRotate.h +++ b/include/Dialect/TensorExt/Transforms/InsertRotate.h @@ -1,7 +1,10 @@ #ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ #define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project namespace mlir { namespace heir { diff --git a/include/Dialect/TensorExt/Transforms/Passes.h b/include/Dialect/TensorExt/Transforms/Passes.h index 01eafff25..6e101454b 100644 --- a/include/Dialect/TensorExt/Transforms/Passes.h +++ b/include/Dialect/TensorExt/Transforms/Passes.h @@ -4,6 +4,7 @@ #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" #include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h" #include "include/Dialect/TensorExt/Transforms/InsertRotate.h" +#include "include/Dialect/TensorExt/Transforms/RotateAndReduce.h" namespace mlir { namespace heir { diff --git a/include/Dialect/TensorExt/Transforms/Passes.td b/include/Dialect/TensorExt/Transforms/Passes.td index 88d396611..1b9f2b487 100644 --- a/include/Dialect/TensorExt/Transforms/Passes.td +++ b/include/Dialect/TensorExt/Transforms/Passes.td @@ -60,4 +60,47 @@ def CollapseInsertionChains : Pass<"collapse-insertion-chains"> { let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"]; } +def RotateAndReduce : Pass<"rotate-and-reduce"> { + let summary = "Use a logarithmic number of rotations to reduce a tensor."; + let description = [{ + This pass identifies when a commutative, associative binary operation is used + to reduce all of the entries of a tensor to a single value, and optimizes the + operations by using a logarithmic number of reduction operations. 
+ + In particular, this pass identifies an unrolled set of operations of the form + (the binary ops may come in any order): + + ```mlir + %0 = tensor.extract %t[0] : tensor<8xi32> + %1 = tensor.extract %t[1] : tensor<8xi32> + %2 = tensor.extract %t[2] : tensor<8xi32> + %3 = tensor.extract %t[3] : tensor<8xi32> + %4 = tensor.extract %t[4] : tensor<8xi32> + %5 = tensor.extract %t[5] : tensor<8xi32> + %6 = tensor.extract %t[6] : tensor<8xi32> + %7 = tensor.extract %t[7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + ``` + + and replaces it with a logarithmic number of `rotate` and `addi` operations: + + ```mlir + %0 = tensor_ext.rotate %t, 4 : tensor<8xi32> + %1 = arith.addi %t, %0 : tensor<8xi32> + %2 = tensor_ext.rotate %1, 2 : tensor<8xi32> + %3 = arith.addi %1, %2 : tensor<8xi32> + %4 = tensor_ext.rotate %3, 1 : tensor<8xi32> + %5 = arith.addi %3, %4 : tensor<8xi32> + ``` + }]; + let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"]; +} + + #endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ diff --git a/include/Dialect/TensorExt/Transforms/RotateAndReduce.h b/include/Dialect/TensorExt/Transforms/RotateAndReduce.h new file mode 100644 index 000000000..51605ec1e --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/RotateAndReduce.h @@ -0,0 +1,17 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ +#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace tensor_ext { + +#define GEN_PASS_DECL_ROTATEANDREDUCE +#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ diff --git a/include/Transforms/ElementwiseToAffine/BUILD b/include/Transforms/ElementwiseToAffine/BUILD new file mode 100644 index 000000000..07fd7c045 --- /dev/null +++ b/include/Transforms/ElementwiseToAffine/BUILD @@ -0,0 +1,35 @@ +# ElementwiseToAffine tablegen and headers. 
+ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "ElementwiseToAffine.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ElementwiseToAffine", + ], + "ElementwiseToAffine.h.inc", + ), + ( + ["-gen-pass-doc"], + "ElementwiseToAffinePasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ElementwiseToAffine.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h new file mode 100644 index 000000000..48b4a1b49 --- /dev/null +++ b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ +#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ diff --git a/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td new file mode 100644 index 000000000..cd83ee554 --- /dev/null +++ b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ +#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ + +include "mlir/Pass/PassBase.td" + +def ElementwiseToAffine : Pass<"convert-elementwise-to-affine"> { + let summary = "This pass lowers ElementwiseMappable operations to Affine loops."; + let description = [{ + This pass lowers ElementwiseMappable operations over tensors + to affine loop nests that instead apply the operation to the underlying scalar values. 
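+
+    For example, an elementwise addition
+
+    ```mlir
+    %0 = arith.addi %a, %b : tensor<4xi32>
+    ```
+
+    becomes, roughly (the exact loop structure and SSA names below are
+    illustrative, not verbatim pass output):
+
+    ```mlir
+    %init = tensor.empty() : tensor<4xi32>
+    %0 = affine.for %i = 0 to 4 iter_args(%t = %init) -> (tensor<4xi32>) {
+      %va = tensor.extract %a[%i] : tensor<4xi32>
+      %vb = tensor.extract %b[%i] : tensor<4xi32>
+      %s = arith.addi %va, %vb : i32
+      %r = tensor.insert %s into %t[%i] : tensor<4xi32>
+      affine.yield %r : tensor<4xi32>
+    }
+    ```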
+  }];
+  let dependentDialects = [
+    "mlir::affine::AffineDialect",
+    "mlir::tensor::TensorDialect"
+  ];
+}
+
+#endif  // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_
diff --git a/include/Transforms/YosysOptimizer/YosysOptimizer.h b/include/Transforms/YosysOptimizer/YosysOptimizer.h
index 17b9939b2..c7339eea0 100644
--- a/include/Transforms/YosysOptimizer/YosysOptimizer.h
+++ b/include/Transforms/YosysOptimizer/YosysOptimizer.h
@@ -1,14 +1,19 @@
 #ifndef INCLUDE_TRANSFORMS_YOSYSOPTIMIZER_YOSYSOPTIMIZER_H_
 #define INCLUDE_TRANSFORMS_YOSYSOPTIMIZER_YOSYSOPTIMIZER_H_
 
-#include "mlir/include/mlir/Pass/Pass.h"  // from @llvm-project
+#include <string>
+
+#include "llvm/include/llvm/Support/CommandLine.h"  // from @llvm-project
+#include "mlir/include/mlir/Pass/Pass.h"  // from @llvm-project
 
 namespace mlir {
 namespace heir {
 
+enum Mode { Boolean, LUT };
+
 std::unique_ptr<Pass> createYosysOptimizer(
     const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast,
-    int unrollFactor = 0, bool printStats = false);
+    int unrollFactor = 0, Mode mode = LUT, bool printStats = false);
 
 #define GEN_PASS_DECL
 #include "include/Transforms/YosysOptimizer/YosysOptimizer.h.inc"
 
@@ -25,6 +30,13 @@ struct YosysOptimizerPipelineOptions
           "value of zero (default) prevents unrolling."),
       llvm::cl::init(0)};
 
+  PassOptions::Option<Mode> mode{
+      *this, "mode",
+      llvm::cl::desc("Map gates to boolean gates or lookup table gates."),
+      llvm::cl::init(LUT),
+      llvm::cl::values(clEnumVal(Boolean, "use boolean gates"),
+                       clEnumVal(LUT, "use lookup tables"))};
+
   PassOptions::Option<bool> printStats{
       *this, "print-stats",
       llvm::cl::desc("Prints statistics about the optimized circuit"),
diff --git a/include/Transforms/YosysOptimizer/YosysOptimizer.td b/include/Transforms/YosysOptimizer/YosysOptimizer.td
index a4e52f67d..9bfcbb4de 100644
--- a/include/Transforms/YosysOptimizer/YosysOptimizer.td
+++ b/include/Transforms/YosysOptimizer/YosysOptimizer.td
@@ -24,6 +24,7 @@ def YosysOptimizer : Pass<"yosys-optimizer"> {
   - `unroll-factor`: Before optimizing the circuit, unroll loops by a given
     factor. If unset, this pass will not unroll any loops.
   - `print-stats`: Prints statistics about the optimized circuits.
+  - `mode={Boolean,LUT}`: Map gates to boolean gates or lookup table gates.
 }];
   // TODO(#257): add option for the pass to select the unroll factor
   // automatically.
diff --git a/lib/Analysis/TargetSlotAnalysis/BUILD b/lib/Analysis/TargetSlotAnalysis/BUILD
new file mode 100644
index 000000000..d04a520b5
--- /dev/null
+++ b/lib/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,19 @@
+# TargetSlotAnalysis analysis pass
+package(
+    default_applicable_licenses = ["@heir//:license"],
+    default_visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "TargetSlotAnalysis",
+    srcs = ["TargetSlotAnalysis.cpp"],
+    hdrs = ["@heir//include/Analysis/TargetSlotAnalysis:TargetSlotAnalysis.h"],
+    deps = [
+        "@heir//lib/Dialect:Utils",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+    ],
+)
diff --git a/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
new file mode 100644
index 000000000..3fe126af1
--- /dev/null
+++ b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
@@ -0,0 +1,57 @@
+#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
+
+#include "lib/Dialect/Utils.h"
+#include "llvm/include/llvm/ADT/TypeSwitch.h"  // from @llvm-project
+#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlowFramework.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+
+#define DEBUG_TYPE "target-slot-analysis"
+
+namespace mlir {
+namespace heir {
+namespace target_slot_analysis {
+
+void TargetSlotAnalysis::visitOperation(
+    Operation *op, ArrayRef<TargetSlotLattice *> operands,
+    ArrayRef<const TargetSlotLattice *> results) {
+  llvm::TypeSwitch<Operation &>(*op)
+      .Case<tensor::InsertOp>([&](auto insertOp) {
+        LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
+        auto insertIndexRes = get1DExtractionIndex(insertOp);
+        // If the target slot can't be statically determined, we can't
+        // propagate anything through the IR.
+        if (failed(insertIndexRes)) return;
+
+        // The target slot propagates to the value inserted, which is the
+        // first positional argument.
+        TargetSlotLattice *lattice = operands[0];
+        TargetSlot newSlot = TargetSlot{insertIndexRes.value()};
+        LLVM_DEBUG({
+          llvm::dbgs() << "Joining " << lattice->getValue() << " and "
+                       << newSlot << " --> "
+                       << TargetSlot::join(lattice->getValue(), newSlot)
+                       << "\n";
+        });
+        ChangeResult changed = lattice->join(newSlot);
+        propagateIfChanged(lattice, changed);
+      })
+      .Default([&](Operation &op) {
+        // By default, an op propagates its result target slots to all its
+        // operands.
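+        // For example, if %1 = arith.addi %a, %b has target slot 4, then
+        // slot 4 is joined into the lattices of both %a and %b.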
+        for (const TargetSlotLattice *r : results) {
+          for (TargetSlotLattice *operand : operands) {
+            ChangeResult result = operand->join(*r);
+            propagateIfChanged(operand, result);
+          }
+        }
+      });
+}
+
+}  // namespace target_slot_analysis
+}  // namespace heir
+}  // namespace mlir
diff --git a/lib/Conversion/BUILD b/lib/Conversion/BUILD
index 56f0fcf03..f3ecd6420 100644
--- a/lib/Conversion/BUILD
+++ b/lib/Conversion/BUILD
@@ -8,9 +8,14 @@ cc_library(
     srcs = ["Utils.cpp"],
    hdrs = ["Utils.h"],
     deps = [
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncTransforms",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:SCFTransforms",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
 )
diff --git a/lib/Conversion/PolynomialToStandard/BUILD b/lib/Conversion/PolynomialToStandard/BUILD
index 41234d13c..ab7d20000 100644
--- a/lib/Conversion/PolynomialToStandard/BUILD
+++ b/lib/Conversion/PolynomialToStandard/BUILD
@@ -12,6 +12,8 @@ cc_library(
     deps = [
         "@heir//include/Conversion/PolynomialToStandard:pass_inc_gen",
         "@heir//lib/Conversion:Utils",
+        "@heir//lib/Dialect/Polynomial/IR:Polynomial",
+        "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes",
         "@heir//lib/Dialect/Polynomial/IR:PolynomialOps",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
@@ -23,6 +25,7 @@ cc_library(
         "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
diff --git a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp
index d493dbb7c..b6e376463 100644
--- a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp
+++ b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp
@@ -1,18 +1,45 @@
 #include "include/Conversion/PolynomialToStandard/PolynomialToStandard.h"
 
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "include/Dialect/Polynomial/IR/Polynomial.h"
+#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h"
+#include "include/Dialect/Polynomial/IR/PolynomialDialect.h"
 #include "include/Dialect/Polynomial/IR/PolynomialOps.h"
 #include "include/Dialect/Polynomial/IR/PolynomialTypes.h"
 #include "lib/Conversion/Utils.h"
+#include "llvm/include/llvm/Support/Casting.h"  // from @llvm-project
 #include "llvm/include/llvm/Support/FormatVariadic.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/AffineExpr.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/AffineMap.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/TypeRange.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/ValueRange.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/include/mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 
@@ -474,11 +501,12 @@ struct ConvertMul : public OpConversionPattern<MulOp> {
     // 2N - 1 sized result tensor -> reduce modulo ideal to get a N sized tensor
     func::FuncOp divMod = getFuncOpCallback(funcType, polyTy.getRing());
-    if (!divMod)
+    if (!divMod) {
       return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
         diag << "Missing software implementation for polynomial mod op of type"
             << funcType << " and for ring " << polyTy.getRing();
       });
+    }
 
     rewriter.replaceOpWithNewOp<func::CallOp>(op, divMod, polyMul.getResult(0));
     return success();
@@ -739,8 +767,8 @@ void PolynomialToStandard::runOnOperation() {
                ConvertConstant, ConvertMulScalar>(typeConverter, context);
   patterns.add<ConvertMul>(typeConverter, patterns.getContext(), getDivmodOp);
   addStructuralConversionPatterns(typeConverter, patterns, target);
+  addTensorOfTensorConversionPatterns(typeConverter, patterns, target);
 
-  // TODO(#143): Handle tensor of polys.
   if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
     signalPassFailure();
   }
diff --git a/lib/Conversion/Utils.cpp b/lib/Conversion/Utils.cpp
index 2bf992392..32cb51e09 100644
--- a/lib/Conversion/Utils.cpp
+++ b/lib/Conversion/Utils.cpp
@@ -1,8 +1,23 @@
 #include "lib/Conversion/Utils.h"
 
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/IRMapping.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Region.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/include/mlir/Transforms/DialectConversion.h"  // from @llvm-project
 
 namespace mlir {
@@ -12,6 +27,233 @@ using ::mlir::func::CallOp;
 using ::mlir::func::FuncOp;
 using ::mlir::func::ReturnOp;
 
+struct ConvertAny : public ConversionPattern {
+  ConvertAny(const TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
+                          /*benefit=*/1, context) {
+    setDebugName("ConvertAny");
+    setHasBoundedRewriteRecursion(true);
+  }
+
+  // generate a new op where all operands have been replaced with their
+  // materialized/typeconverted versions
+  LogicalResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> newOperandTypes;
+    if (failed(getTypeConverter()->convertTypes(op->getOperandTypes(),
+                                                newOperandTypes)))
+      return failure();
+
+    SmallVector<Type> newResultTypes;
+    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+                                                newResultTypes)))
+      return failure();
+
+    SmallVector<std::unique_ptr<Region>, 1> regions;
+    IRMapping mapping;
+    for (auto &r : op->getRegions()) {
+      Region *newRegion = new Region();
+      rewriter.cloneRegionBefore(r, *newRegion, newRegion->end(), mapping);
+      if (failed(rewriter.convertRegionTypes(newRegion, *this->typeConverter)))
+        return failure();
+      regions.emplace_back(newRegion);
+    }
+
+    Operation *newOp = rewriter.create(OperationState(
+        op->getLoc(), op->getName().getStringRef(), operands, newResultTypes,
+        op->getAttrs(), op->getSuccessors(), regions));
+
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+struct ConvertExtract : public OpConversionPattern<tensor::ExtractOp> {
+  ConvertExtract(mlir::MLIRContext *context)
+      : OpConversionPattern(context) {}
+
+  using OpConversionPattern::OpConversionPattern;
+
+  // Convert a tensor.extract that would type-convert to extracting a tensor
+  // to a tensor.extract_slice operation instead. Specifically, this targets
+  // extracting SourceType from tensor<...xSourceType> when SourceType would
+  // be type converted to tensor<...>.
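+  //
+  // For example (the types here are illustrative only):
+  //
+  //   %v = tensor.extract %t[%c0] : tensor<4x!poly.poly<#ring>>
+  //
+  // becomes, once !poly.poly<#ring> is converted to tensor<1024xi32>,
+  //
+  //   %v = tensor.extract_slice %t[%c0, 0] [1, 1024] [1, 1]
+  //        : tensor<4x1024xi32> to tensor<1024xi32>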
+  LogicalResult matchAndRewrite(
+      tensor::ExtractOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // replace tensor.extract %t[%i] from tensor<...xSourceType>
+    // with an equivalent tensor.extract_slice from tensor<...x...>
+    auto shape = op.getTensor().getType().getShape();
+    auto resultType = getTypeConverter()
+                          ->convertType(op.getResult().getType())
+                          .cast<RankedTensorType>();
+    auto resultShape = resultType.getShape();
+
+    // expand op's list of indices by appending as many zeros as there are
+    // dimensions in resultShape
+    SmallVector<OpFoldResult> offsets;
+    offsets.append(op.getIndices().begin(), op.getIndices().end());
+    for (size_t i = 0; i < resultShape.size(); ++i) {
+      offsets.push_back(rewriter.getIndexAttr(0));
+    }
+
+    // expand resultShape by prepending as many ones as there are dimensions
+    // in shape
+    SmallVector<OpFoldResult> sizes;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      sizes.push_back(rewriter.getIndexAttr(1));
+    }
+    for (int64_t i : resultShape) {
+      sizes.push_back(rewriter.getIndexAttr(i));
+    }
+
+    // strides are all 1, and we need as many as there are dimensions in
+    // both shape and resultShape together
+    SmallVector<OpFoldResult> strides;
+    for (size_t i = 0; i < shape.size() + resultShape.size(); ++i) {
+      strides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        op, resultType, adaptor.getTensor(), offsets, sizes, strides);
+
+    return success();
+  }
+};
+
+struct ConvertInsert : public OpConversionPattern<tensor::InsertOp> {
+  ConvertInsert(mlir::MLIRContext *context)
+      : OpConversionPattern(context) {}
+
+  using OpConversionPattern::OpConversionPattern;
+
+  // Convert a tensor.insert that would type-convert to inserting a tensor to
+  // a tensor.insert_slice operation instead. Specifically, this targets
+  // inserting SourceType into tensor<...xSourceType> when SourceType would be
+  // type converted to tensor<...>.
+  LogicalResult matchAndRewrite(
+      tensor::InsertOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // replace tensor.insert %s into %t[%i] with tensor<...xSourceType>
+    // with an equivalent tensor.insert_slice with tensor<...x...>
+    auto shape = op.getDest().getType().getShape();
+    auto resultType = getTypeConverter()
+                          ->convertType(op.getScalar().getType())
+                          .cast<RankedTensorType>();
+    auto resultShape = resultType.getShape();
+
+    // expand op's list of indices by appending as many zeros as there are
+    // dimensions in resultShape
+    SmallVector<OpFoldResult> offsets;
+    offsets.append(op.getIndices().begin(), op.getIndices().end());
+    for (size_t i = 0; i < resultShape.size(); ++i) {
+      offsets.push_back(rewriter.getIndexAttr(0));
+    }
+
+    // expand resultShape by prepending as many ones as there are dimensions
+    // in shape
+    SmallVector<OpFoldResult> sizes;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      sizes.push_back(rewriter.getIndexAttr(1));
+    }
+    for (int64_t i : resultShape) {
+      sizes.push_back(rewriter.getIndexAttr(i));
+    }
+
+    // strides are all 1, and we need as many as there are dimensions in
+    // both shape and resultShape together
+    SmallVector<OpFoldResult> strides;
+    for (size_t i = 0; i < shape.size() + resultShape.size(); ++i) {
+      strides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        op, adaptor.getScalar(), adaptor.getDest(), offsets, sizes, strides);
+
+    return success();
+  }
+};
+
+struct ConvertFromElements
+    : public OpConversionPattern<tensor::FromElementsOp> {
+  ConvertFromElements(mlir::MLIRContext *context)
+      : OpConversionPattern(context) {}
+
+  using OpConversionPattern::OpConversionPattern;
+
+  // Converts a tensor.from_elements %s0, %s1, ... : tensor<...xSourceType>
+  // where SourceType would be type-converted to tensor<...> to
+  // a concatenation of the converted operands (with appropriate reshape)
+  LogicalResult matchAndRewrite(
+      tensor::FromElementsOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // Expand each of the (converted) operands:
+    SmallVector<Value> newOperands;
+    for (auto o : adaptor.getElements()) {
+      // extend tensor<...xT> to tensor<1x...xT>
+      if (auto tensorType = o.getType().dyn_cast<RankedTensorType>()) {
+        auto shape = tensorType.getShape();
+        SmallVector<int64_t> newShape(1, 1);
+        newShape.append(shape.begin(), shape.end());
+
+        // Create a dense constant for targetShape
+        auto shapeOp = rewriter.create<arith::ConstantOp>(
+            op.getLoc(),
+            RankedTensorType::get(newShape.size(), rewriter.getIndexType()),
+            rewriter.getIndexTensorAttr(newShape));
+
+        auto reshapeOp = rewriter.create<tensor::ReshapeOp>(
+            op.getLoc(),
+            RankedTensorType::get(newShape, tensorType.getElementType()), o,
+            shapeOp);
+        newOperands.push_back(reshapeOp);
+      } else {
+        newOperands.push_back(o);
+      }
+    }
+    // Create the final tensor.concat operation
+    rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, 0, newOperands);
+
+    return success();
+  }
+};
+
+void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         ConversionTarget &target) {
+  target.addDynamicallyLegalDialect<tensor::TensorDialect>(
+      [&](Operation *op) { return typeConverter.isLegal(op); });
+
+  typeConverter.addConversion([&](TensorType type) -> Type {
+    if (!typeConverter.isLegal(type.getElementType())) {
+      typeConverter.convertType(type.getElementType()).dump();
+      if (auto convertedType =
+              typeConverter.convertType(type.getElementType())) {
+        if (auto castConvertedType =
+                convertedType.dyn_cast<RankedTensorType>()) {
+          // Create the combined shape
+          auto polyShape = castConvertedType.getShape();
+          auto tensorShape = type.getShape();
+          SmallVector<int64_t> combinedShape(tensorShape.begin(),
+                                             tensorShape.end());
+          combinedShape.append(polyShape.begin(), polyShape.end());
+          auto combinedType = RankedTensorType::get(
+              combinedShape, castConvertedType.getElementType());
+          return combinedType;
+        }
+      }
+    }
+    return type;
+  });
+
+  target.addDynamicallyLegalDialect<arith::ArithDialect>(
+      [&](Operation *op) { return typeConverter.isLegal(op); });
+
+  patterns.add<ConvertAny, ConvertExtract, ConvertInsert, ConvertFromElements>(
+      typeConverter, patterns.getContext());
+}
+
 void addStructuralConversionPatterns(TypeConverter &typeConverter,
                                      RewritePatternSet &patterns,
                                      ConversionTarget &target) {
diff --git a/lib/Conversion/Utils.h b/lib/Conversion/Utils.h
index e2f953dc6..48054f295 100644
--- a/lib/Conversion/Utils.h
+++ b/lib/Conversion/Utils.h
@@ -1,11 +1,18 @@
 #ifndef LIB_CONVERSION_UTILS_H_
 #define LIB_CONVERSION_UTILS_H_
 
+#include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/include/mlir/Transforms/DialectConversion.h"  // from @llvm-project
 
 namespace mlir {
 namespace heir {
 
+// Adds conversion patterns that deal with tensor<..xsource_type>
+// when source_type will be type converted to tensor<...>, too.
+void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         ConversionTarget &target);
+
 // Adds the standard set of conversion patterns for
 // converting types involved in func, cf, etc., which
 // don't depend on the logic of the dialect beyond the
diff --git a/lib/Dialect/BUILD b/lib/Dialect/BUILD
index fb5356f75..4667ec052 100644
--- a/lib/Dialect/BUILD
+++ b/lib/Dialect/BUILD
@@ -18,3 +18,16 @@ cc_library(
         "@llvm-project//mlir:IR",
     ],
 )
+
+cc_library(
+    name = "Utils",
+    hdrs = [
+        "Utils.h",
+    ],
+    deps = [
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+    ],
+)
diff --git a/lib/Dialect/Comb/IR/CombOps.cpp b/lib/Dialect/Comb/IR/CombOps.cpp
index d0024c25f..0f77c8485 100644
--- a/lib/Dialect/Comb/IR/CombOps.cpp
+++ b/lib/Dialect/Comb/IR/CombOps.cpp
@@ -160,6 +160,12 @@ LogicalResult OrOp::verify() { return verifyUTBinOp(*this); }
 
 LogicalResult XorOp::verify() { return verifyUTBinOp(*this); }
 
+LogicalResult XNorOp::verify() { return verifyUTBinOp(*this); }
+
+LogicalResult NandOp::verify() { return verifyUTBinOp(*this); }
+
+LogicalResult NorOp::verify() { return verifyUTBinOp(*this); }
+
 //===----------------------------------------------------------------------===//
 // ConcatOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD
index f06b7becd..9e534b610 100644
--- a/lib/Dialect/TensorExt/Transforms/BUILD
+++ b/lib/Dialect/TensorExt/Transforms/BUILD
@@ -11,6 +11,7 @@ cc_library(
     deps = [
         ":CollapseInsertionChains",
         ":InsertRotate",
+        ":RotateAndReduce",
         "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
         "@heir//lib/Dialect/TensorExt/IR:Dialect",
         "@llvm-project//mlir:IR",
@@ -27,10 +28,14 @@ cc_library(
         "@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen",
         "@heir//include/Dialect/TensorExt/Transforms:insert_rotate_inc_gen",
         "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
+        "@heir//lib/Analysis/TargetSlotAnalysis",
         "@heir//lib/Dialect/TensorExt/IR:Dialect",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
@@ -44,6 +49,7 @@ cc_library(
     ],
     deps = [
         "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
+        "@heir//lib/Dialect:Utils",
         "@heir//lib/Dialect/TensorExt/IR:Dialect",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
@@ -54,3 +60,22 @@ cc_library(
         "@llvm-project//mlir:Transforms",
     ],
 )
+
+cc_library(
+    name = "RotateAndReduce",
+    srcs = ["RotateAndReduce.cpp"],
+    hdrs = [
+        "@heir//include/Dialect/TensorExt/Transforms:RotateAndReduce.h",
+    ],
+    deps = [
+        "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
+        "@heir//lib/Dialect/TensorExt/IR:Dialect",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+    ],
+)
diff --git a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp
index 63c2a9983..1ba85db49 100644
--- a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp
+++ b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp
@@ -4,14 +4,13 @@
 #include <utility>
 
 #include "include/Dialect/TensorExt/IR/TensorExtOps.h"
-#include "llvm/include/llvm/Support/Casting.h"  // from @llvm-project
+#include "lib/Dialect/Utils.h"
 #include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
@@ -27,25 +26,6 @@ namespace tensor_ext {
 #define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS
 #include "include/Dialect/TensorExt/Transforms/Passes.h.inc"
 
-template <typename Op>
-FailureOr<int64_t> get1DExtractionIndex(Op op) {
-  auto insertIndices = op.getIndices();
-  if (insertIndices.size() != 1) return failure();
-
-  // Each index must be constant; this may require running --canonicalize or
-  // -sccp before this pass to apply folding rules (use -sccp if you need to
-  // fold constants through control flow).
-  Value insertIndex = *insertIndices.begin();
-  auto insertIndexConstOp = insertIndex.getDefiningOp<arith::ConstantOp>();
-  if (!insertIndexConstOp) return failure();
-
-  auto insertOffsetAttr =
-      llvm::dyn_cast<IntegerAttr>(insertIndexConstOp.getValue());
-  if (!insertOffsetAttr) return failure();
-
-  return insertOffsetAttr.getInt();
-}
-
 /// A pattern that searches for sequences of extract + insert, where the
 /// indices extracted and inserted have the same offset, and replaces them
 /// with a single rotate operation.
diff --git a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
index c829ba34d..373980db7 100644
--- a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
+++ b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
@@ -2,14 +2,21 @@
 
 #include <utility>
 
-#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
-#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
-#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/Matchers.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
+#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlowFramework.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 
+#define DEBUG_TYPE "insert-rotate"
+
 namespace mlir {
 namespace heir {
 namespace tensor_ext {
@@ -33,6 +40,31 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
 
+    SymbolTableCollection symbolTable;
+    DataFlowSolver solver;
+    // These two upstream analyses are required dependencies for any sparse
+    // dataflow analysis, or else the analysis will be a no-op. Cf.
+    // https://github.com/llvm/llvm-project/issues/58922
+    solver.load<dataflow::DeadCodeAnalysis>();
+    solver.load<dataflow::SparseConstantPropagation>();
+    solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable);
+    if (failed(solver.initializeAndRun(getOperation()))) {
+      getOperation()->emitOpError() << "Failed to run the analysis.\n";
+      signalPassFailure();
+      return;
+    }
+
+    LLVM_DEBUG({
+      getOperation()->walk([&](Operation *op) {
+        if (op->getNumResults() == 0) return;
+        auto *targetSlotLattice =
+            solver.lookupState<target_slot_analysis::TargetSlotLattice>(
+                op->getResult(0));
+        llvm::dbgs() << "Target slot for op " << *op << ": "
+                     << targetSlotLattice->getValue() << "\n";
+      });
+    });
+
     alignment::populateWithGenerated(patterns);
     canonicalization::populateWithGenerated(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
new file mode 100644
index 000000000..fe5808271
--- /dev/null
+++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
@@ -0,0 +1,202 @@
+#include "include/Dialect/TensorExt/Transforms/RotateAndReduce.h"
+
+#include <cstdint>
+
+#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
+#include "llvm/include/llvm/ADT/DenseSet.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/StringRef.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/TypeSwitch.h"  // from @llvm-project
+#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/SliceAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Iterators.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
+
+#define DEBUG_TYPE "rotate-and-reduce"
+
+namespace mlir {
+namespace heir {
+namespace tensor_ext {
+
+#define GEN_PASS_DEF_ROTATEANDREDUCE
+#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"
+
+/// A pass that searches for a length N sequence of binary operations that
+/// reduces a length N vector to a single scalar, and replaces it with a
+/// logarithmic number of rotations and binary operations.
+struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
+  using RotateAndReduceBase::RotateAndReduceBase;
+
+  template <typename ArithOp>
+  void tryReplace(ArithOp op, DenseSet<Operation *> &visited) {
+    LLVM_DEBUG(llvm::dbgs() << "Trying to replace " << *op << "\n");
+    SetVector<Operation *> backwardSlice;
+    BackwardSliceOptions options;
+    // asserts that the parent op has a single region with a single block.
+    options.omitBlockArguments = false;
+
+    DenseSet<Value> inputTensors;
+    DenseSet<Operation *> visitedReductionOps;
+    DenseSet<int64_t> accessIndices;
+    DenseMap<StringRef, int> opCounts;
+    opCounts[op->getName().getStringRef()]++;
+
+    // TODO(#523): replace backward slice with a dataflow analysis
+    getBackwardSlice(op.getOperation(), &backwardSlice, options);
+    for (Operation *upstreamOpPtr : backwardSlice) {
+      auto result =
+          llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
+              .Case<arith::ConstantOp>(
+                  [&](auto upstreamOp) { return success(); })
+              .template Case<ArithOp>([&](auto upstreamOp) {
+                opCounts[upstreamOp->getName().getStringRef()]++;
+                // More than one reduction op is mixed in the reduction.
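+                // For example, a chain mixing arith.addi and arith.muli
+                // cannot be rewritten as a single rotate-and-reduce ladder.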
+                if (opCounts.size() > 1) {
+                  LLVM_DEBUG(llvm::dbgs()
+                             << "Not replacing op because reduction "
+                                "contains multiple incompatible ops "
+                             << op->getName() << " and "
+                             << upstreamOp->getName() << "\n");
+                  return failure();
+                }
+
+                // TODO(#522): support these non-tensor-extract operands by
+                // saving the values, and applying them again to the final
+                // result.
+                for (Value operand : upstreamOp->getOperands()) {
+                  if (operand.getDefiningOp<arith::ConstantOp>()) {
+                    LLVM_DEBUG(llvm::dbgs()
+                               << "Not replacing op because reduction "
+                                  "includes non-tensor value operands "
+                               << operand << "\n");
+                    return failure();
+                  }
+                }
+                visitedReductionOps.insert(upstreamOp);
+                return success();
+              })
+              .template Case<tensor::ExtractOp>([&](auto tensorOp) {
+                inputTensors.insert(tensorOp.getTensor());
+                if (inputTensors.size() > 1) {
+                  LLVM_DEBUG(
+                      llvm::dbgs()
+                      << "Not replacing op due to multiple input tensors\n");
+                  return failure();
+                }
+
+                // If the tensor is not 1D, we can't replace it with a rotate.
+                if (tensorOp.getIndices().size() != 1) {
+                  LLVM_DEBUG(llvm::dbgs()
+                             << "Not replacing op due to >1D input tensor\n");
+                  return failure();
+                }
+
+                // If the access index is not constant, we can't tell if we
+                // are reducing the entire vector (each index occurs exactly
+                // once in the reduction).
+                arith::ConstantOp indexConstant =
+                    tensorOp.getIndices()
+                        .front()
+                        .template getDefiningOp<arith::ConstantOp>();
+                if (!indexConstant) {
+                  LLVM_DEBUG(
+                      llvm::dbgs()
+                      << "Not replacing op due to non constant index access;"
+                      << " (do you need to run --canonicalize or --sccp?)\n");
+                  return failure();
+                }
+                int64_t accessIndex =
+                    indexConstant.getValue().cast<IntegerAttr>().getInt();
+
+                // If the access index was already seen, then fail because
+                // some tensor element contributes more than once to the
+                // reduction.
+                if (accessIndices.count(accessIndex)) {
+                  LLVM_DEBUG(
+                      llvm::dbgs()
+                      << "Not replacing op because input tensor was accessed "
+                         "multiple times at the same index\n");
+                  return failure();
+                }
+                LLVM_DEBUG(llvm::dbgs()
+                           << "Adding valid index " << accessIndex << "\n");
+                accessIndices.insert(accessIndex);
+                return success();
+              })
+              .Default([&](Operation *op) { return failure(); });
+
+      if (failed(result)) {
+        return;
+      }
+    }
+
+    // The test for a match is now: does the number of accessed indices
+    // exactly match the size of the tensor? I.e., does each tensor element
+    // show up exactly once in the reduction?
+    auto tensorShape =
+        inputTensors.begin()->getType().cast<RankedTensorType>().getShape();
+    if (tensorShape.size() != 1 || tensorShape[0] != accessIndices.size()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Not replacing op because tensor shape ("
+                 << inputTensors.begin()->getType()
+                 << ") is not fully reduced. Only " << accessIndices.size()
+                 << " of " << tensorShape[0] << " indices were accessed\n");
+      return;
+    }
+
+    // From here we know we will succeed.
+    auto b = ImplicitLocOpBuilder(op->getLoc(), op);
+    Value inputTensor = *inputTensors.begin();
+    Operation *finalOp;
+    for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0;
+         shiftSize /= 2) {
+      auto rotatedTensor = b.create<RotateOp>(
+          inputTensor,
+          b.create<arith::ConstantOp>(b.getIndexAttr(shiftSize)));
+      auto addOp = b.create<ArithOp>(inputTensor, rotatedTensor);
+      finalOp = addOp;
+      inputTensor = addOp->getResult(0);
+    }
+
+    [[maybe_unused]] auto *parentOp = op->getParentOp();
+    // We can extract at any index; every index contains the same reduced
+    // value.
+    auto extractOp = b.create<tensor::ExtractOp>(
+        finalOp->getResult(0),
+        b.create<arith::ConstantIndexOp>(0).getResult());
+    op->replaceAllUsesWith(extractOp);
+    LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");
+
+    // Mark all ops in the reduction as visited so we don't try to replace
+    // them twice.
+    for (Operation *visitedOp : visitedReductionOps) {
+      visited.insert(visitedOp);
+    }
+  }
+
+  void runOnOperation() override {
+    DenseSet<Operation *> visited;
+    // Traverse the IR in reverse order so that we can eagerly compute
+    // backward slices for each operation.
+    getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
+        [&](Operation *op) {
+          if (visited.count(op)) {
+            return;
+          }
+          llvm::TypeSwitch<Operation &>(*op)
+              .Case<arith::AddIOp>(
+                  [&](auto arithOp) { tryReplace(arithOp, visited); })
+              .Case<arith::MulIOp>(
+                  [&](auto arithOp) { tryReplace(arithOp, visited); });
+        });
+  }
+};
+
+}  // namespace tensor_ext
+}  // namespace heir
+}  // namespace mlir
diff --git a/lib/Dialect/Utils.h b/lib/Dialect/Utils.h
new file mode 100644
index 000000000..ac0aef50a
--- /dev/null
+++ b/lib/Dialect/Utils.h
@@ -0,0 +1,41 @@
+#ifndef INCLUDE_DIALECT_UTILS_H_
+#define INCLUDE_DIALECT_UTILS_H_
+
+#include <cstdint>
+
+#include "llvm/include/llvm/Support/Casting.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
+
+namespace mlir {
+namespace heir {
+
+/// Given a tensor::InsertOp or tensor::ExtractOp, and assuming the shape
+/// of the input tensor is 1-dimensional and the input index is constant,
+/// return the constant index value. If any of these conditions are not
+/// met, return a failure.
+template <typename Op>
+FailureOr<int64_t> get1DExtractionIndex(Op op) {
+  auto insertIndices = op.getIndices();
+  if (insertIndices.size() != 1) return failure();
+
+  // Each index must be constant; this may require running --canonicalize or
+  // -sccp before this pass to apply folding rules (use -sccp if you need to
+  // fold constants through control flow).
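+  // For example, given %c7 = arith.constant 7 : index and
+  // tensor.insert %v into %t[%c7] : tensor<16xi32>, this returns 7.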
+  Value insertIndex = *insertIndices.begin();
+  auto insertIndexConstOp = insertIndex.getDefiningOp<arith::ConstantOp>();
+  if (!insertIndexConstOp) return failure();
+
+  auto insertOffsetAttr =
+      llvm::dyn_cast<IntegerAttr>(insertIndexConstOp.getValue());
+  if (!insertOffsetAttr) return failure();
+
+  return insertOffsetAttr.getInt();
+}
+
+}  // namespace heir
+}  // namespace mlir
+
+#endif  // INCLUDE_DIALECT_UTILS_H_
diff --git a/lib/Transforms/ElementwiseToAffine/BUILD b/lib/Transforms/ElementwiseToAffine/BUILD
new file mode 100644
index 000000000..b1b274632
--- /dev/null
+++ b/lib/Transforms/ElementwiseToAffine/BUILD
@@ -0,0 +1,23 @@
+package(
+    default_applicable_licenses = ["@heir//:license"],
+    default_visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "ElementwiseToAffine",
+    srcs = ["ElementwiseToAffine.cpp"],
+    hdrs = [
+        "@heir//include/Transforms/ElementwiseToAffine:ElementwiseToAffine.h",
+    ],
+    deps = [
+        "@heir//include/Transforms/ElementwiseToAffine:pass_inc_gen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
diff --git a/lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.cpp b/lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.cpp
new file mode 100644
index 000000000..8878b88b5
--- /dev/null
+++ b/lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.cpp
@@ -0,0 +1,146 @@
+#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h"
+
+#include <cstddef>
+#include <utility>
+
+#include "llvm/include/llvm/ADT/STLExtras.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/SmallVector.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/include/mlir/Transforms/DialectConversion.h"  // from @llvm-project
+
+namespace mlir {
+namespace heir {
+
+#define GEN_PASS_DEF_ELEMENTWISETOAFFINE
+#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc"
+
+// All of this is based on the ElementwiseToLinalg Pass in
+// mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!OpTrait::hasElementwiseMappableTraits(op)) return false;
+
+  // TODO(#534): Test ElementwiseToAffine with `any_of` constraints
+  // as the pass should (in theory) support scalar operands, too
+  return llvm::all_of(op->getOperandTypes(),
+                      [](Type type) { return isa<RankedTensorType>(type); });
+}
+
+namespace {
+
+struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
+  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    auto resultType = cast<RankedTensorType>(op->getResult(0).getType());
+    auto elementType = resultType.getElementType();
+    auto shape = resultType.getShape();
+    auto rank = resultType.getRank();
+
+    // Save insertion point prior to entering loop nest
+    auto ip = rewriter.saveInsertionPoint();
+
+    // Create an empty tensor as initial value of the iter_args
+    Value target =
+        rewriter.create<tensor::EmptyOp>(op->getLoc(), shape, elementType);
+
+    llvm::SmallVector<Value> indices;
+
+    // Create an affine.for loop nest of depth `rank`
+    for (size_t i = 0; i < rank; ++i) {
+      auto loop = rewriter.create<affine::AffineForOp>(
+          op->getLoc(), /* lowerBound*/ 0,
+          /* upperBound*/ shape[i],
+          /* step*/ 1,
+          /* iterArgs*/ target);
+
+      // Update target & indices
+      target = loop.getRegionIterArgs().front();
+      indices.push_back(loop.getInductionVar());
+
+      // If first loop: replace scalar op
+      if (i == 0) {
+        rewriter.replaceOp(op, loop);
+      } else {  // yield the result of this loop
+        rewriter.create<affine::AffineYieldOp>(op->getLoc(),
+                                               loop->getResults());
+      }
+      rewriter.setInsertionPointToStart(loop.getBody());
+    }
+
+    // Create the innermost body
+    auto resultTypes = llvm::to_vector<6>(
+        llvm::map_range(op->getResultTypes(), [](Type type) {
+          return cast<TensorType>(type).getElementType();
+        }));
+
+    // Generate a `tensor.extract` for each tensor operand
+    SmallVector<Value> newOperands;
+    for (auto operand : op->getOperands()) {
+      if (operand.getType().isa<RankedTensorType>()) {
+        // We don't need to check the shape, as ElementwiseMappable
+        // requires all tensor operands to have compatible shapes
+        auto extractOp = rewriter.create<tensor::ExtractOp>(operand.getLoc(),
+                                                            operand, indices);
+        newOperands.push_back(extractOp);
+      } else {
+        // scalar (technically, "non-tensor") operands can be reused as-is
+        newOperands.push_back(operand);
+      }
+    }
+
+    // The "lowered" operation is the same operation, but over non-tensor
+    // operands
+    auto *scalarOp =
+        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+                        newOperands, resultTypes, op->getAttrs());
+
+    // insert scalarOp into the tensor at the right index
+    Value inserted = rewriter.create<tensor::InsertOp>(
+        op->getLoc(), scalarOp->getResult(0), target, indices);
+
+    // replace linalg.yield scalarOp with affine.yield insertedOp
+    rewriter.create<affine::AffineYieldOp>(op->getLoc(), inserted);
+
+    // reset insertion point
+    rewriter.restoreInsertionPoint(ip);
+
+    return success();
+  }
+};
+}  // namespace
+
+struct ElementwiseToAffine
+    : impl::ElementwiseToAffineBase<ElementwiseToAffine> {
+  using ElementwiseToAffineBase::ElementwiseToAffineBase;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    RewritePatternSet patterns(context);
+
+    patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(context);
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+}  // namespace heir
+}  // namespace mlir
diff --git a/lib/Transforms/YosysOptimizer/BUILD b/lib/Transforms/YosysOptimizer/BUILD
index dc46f338d..7b9926124 100644
--- a/lib/Transforms/YosysOptimizer/BUILD
+++ b/lib/Transforms/YosysOptimizer/BUILD
@@ -62,6 +62,22 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "BooleanGateImporter",
+    srcs = ["BooleanGateImporter.cpp"],
+    hdrs = ["BooleanGateImporter.h"],
+    deps = [
+        ":RTLILImporter",
+        "@at_clifford_yosys//:kernel",
+        "@heir//lib/Dialect/Comb/IR:Dialect",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
 cc_library(
     name = "YosysOptimizer",
     srcs = ["YosysOptimizer.cpp"],
@@ -73,6 +89,7 @@ cc_library(
         "@heir//lib/Transforms/YosysOptimizer/yosys:share_files",
     ],
     deps = [
+        ":BooleanGateImporter",
         ":LUTImporter",
         ":RTLILImporter",
         "@at_clifford_yosys//:kernel",
diff --git a/lib/Transforms/YosysOptimizer/BooleanGateImporter.cpp b/lib/Transforms/YosysOptimizer/BooleanGateImporter.cpp
new file mode 100644
index 000000000..6d2613645
--- /dev/null
+++ b/lib/Transforms/YosysOptimizer/BooleanGateImporter.cpp
@@ -0,0 +1,58 @@
+#include "lib/Transforms/YosysOptimizer/BooleanGateImporter.h"
+
+#include <cassert>
+
+#include "include/Dialect/Comb/IR/CombOps.h"
+#include "llvm/include/llvm/Support/ErrorHandling.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+
+// Block clang-format from reordering
+// clang-format off
+#include "kernel/rtlil.h"  // from @at_clifford_yosys
+// clang-format on
+
+namespace mlir {
+namespace heir {
+
+mlir::Operation *BooleanGateImporter::createOp(Yosys::RTLIL::Cell *cell,
+                                               SmallVector<Value> &inputs,
+                                               ImplicitLocOpBuilder &b) const {
+  auto op = llvm::StringSwitch<Operation *>(cell->type.substr(1))
+                .Case("inv", b.create<comb::InvOp>(inputs[0], false))
+                .Case("xnor2", b.create<comb::XNorOp>(inputs, false))
+                .Case("and2", b.create<comb::AndOp>(inputs, false))
+                .Case("xor2", b.create<comb::XorOp>(inputs, false))
+                .Case("nand2", b.create<comb::NandOp>(inputs, false))
+                .Case("nor2", b.create<comb::NorOp>(inputs, false))
+                .Case("or2", b.create<comb::OrOp>(inputs, false))
+                .Default(nullptr);
+  if (op == nullptr) {
+    llvm_unreachable("unexpected cell type");
+  }
+  return op;
+}
+
+SmallVector<Yosys::RTLIL::SigSpec> BooleanGateImporter::getInputs(
+    Yosys::RTLIL::Cell *cell) const {
+  // Return all non-Y named connections.
+  SmallVector<Yosys::RTLIL::SigSpec> inputs;
+  for (auto &conn : cell->connections()) {
+    if (conn.first.contains("Y")) {
+      continue;
+    }
+    inputs.push_back(conn.second);
+  }
+
+  return inputs;
+}
+
+Yosys::RTLIL::SigSpec BooleanGateImporter::getOutput(
+    Yosys::RTLIL::Cell *cell) const {
+  return cell->getPort(Yosys::RTLIL::IdString("\\Y"));
+}
+
+}  // namespace heir
+}  // namespace mlir
diff --git a/lib/Transforms/YosysOptimizer/BooleanGateImporter.h b/lib/Transforms/YosysOptimizer/BooleanGateImporter.h
new file mode 100644
index 000000000..b527a397b
--- /dev/null
+++ b/lib/Transforms/YosysOptimizer/BooleanGateImporter.h
@@ -0,0 +1,37 @@
+#ifndef THIRD_PARTY_HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_BOOLEANGATEIMPORTER_H_
+#define THIRD_PARTY_HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_BOOLEANGATEIMPORTER_H_
+
+#include "lib/Transforms/YosysOptimizer/RTLILImporter.h"
+#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
+
+// Block clang-format from reordering
+// clang-format off
+#include "kernel/rtlil.h"  // from @at_clifford_yosys
+// clang-format on
+namespace mlir {
+namespace heir {
+
+// BooleanGateImporter implements the RTLILConfig for importing RTLIL that uses
+// boolean gates.
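+//
+// As an illustrative sketch of the mapping (cell names as declared in
+// tfhe-rs_cells.liberty): a cell of type \inv with connection A -> Y maps to
+// comb.inv on the Value wired to A, and a two-input cell such as \and2 with
+// connections A, B -> Y maps to comb.and on the Values wired to A and B; the
+// Y connection always carries the cell's single output.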
+class BooleanGateImporter : public RTLILImporter {
+ public:
+  BooleanGateImporter(MLIRContext *context) : RTLILImporter(context) {}
+
+ protected:
+  Operation *createOp(Yosys::RTLIL::Cell *cell, SmallVector<Value> &inputs,
+                      ImplicitLocOpBuilder &b) const override;
+
+  SmallVector<Yosys::RTLIL::SigSpec> getInputs(
+      Yosys::RTLIL::Cell *cell) const override;
+
+  Yosys::RTLIL::SigSpec getOutput(Yosys::RTLIL::Cell *cell) const override;
+};
+
+}  // namespace heir
+}  // namespace mlir
+
+#endif  // THIRD_PARTY_HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_BOOLEANGATEIMPORTER_H_
diff --git a/lib/Transforms/YosysOptimizer/LUTImporter.cpp b/lib/Transforms/YosysOptimizer/LUTImporter.cpp
index b4e9fb572..511e457b2 100644
--- a/lib/Transforms/YosysOptimizer/LUTImporter.cpp
+++ b/lib/Transforms/YosysOptimizer/LUTImporter.cpp
@@ -3,19 +3,23 @@
 #include <cassert>
 
 #include "include/Dialect/Comb/IR/CombOps.h"
-#include "kernel/rtlil.h"  // from @at_clifford_yosys
-#include "llvm/include/llvm/ADT/ArrayRef.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/ArrayRef.h"  // from @llvm-project
 #include "llvm/include/llvm/Support/FormatVariadic.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
 
+// Block clang-format from reordering
+// clang-format off
+#include "kernel/rtlil.h"  // from @at_clifford_yosys
+// clang-format on
+
 namespace mlir {
 namespace heir {
 
 mlir::Operation *LUTImporter::createOp(Yosys::RTLIL::Cell *cell,
-                                       SmallVector<Value, 4> &inputs,
+                                       SmallVector<Value> &inputs,
                                        ImplicitLocOpBuilder &b) const {
   assert(cell->type.begins_with("\\lut"));
@@ -36,7 +40,7 @@ mlir::Operation *LUTImporter::createOp(Yosys::RTLIL::Cell *cell,
   return b.create<comb::TruthTableOp>(inputs, lookupTable);
 }
 
-SmallVector<Yosys::RTLIL::SigSpec, 4> LUTImporter::getInputs(
+SmallVector<Yosys::RTLIL::SigSpec> LUTImporter::getInputs(
     Yosys::RTLIL::Cell *cell) const {
   assert(cell->type.begins_with("\\lut") && "expected lut cells");
diff --git a/lib/Transforms/YosysOptimizer/LUTImporter.h b/lib/Transforms/YosysOptimizer/LUTImporter.h
index bf5eef253..b829bf8fa 100644
--- a/lib/Transforms/YosysOptimizer/LUTImporter.h
+++ b/lib/Transforms/YosysOptimizer/LUTImporter.h
@@ -1,7 +1,6 @@
 #ifndef HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_LUTIMPORTER_H_
 #define HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_LUTIMPORTER_H_
 
-#include "kernel/rtlil.h"  // from @at_clifford_yosys
 #include "lib/Transforms/YosysOptimizer/RTLILImporter.h"
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/MLIRContext.h"  // from @llvm-project
@@ -9,6 +8,11 @@
 #include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
 
+// Block clang-format from reordering
+// clang-format off
+#include "kernel/rtlil.h"  // from @at_clifford_yosys
+// clang-format on
+
 namespace mlir {
 namespace heir {
 
@@ -18,10 +22,10 @@ class LUTImporter : public RTLILImporter {
   LUTImporter(MLIRContext *context) : RTLILImporter(context) {}
 
  protected:
-  Operation *createOp(Yosys::RTLIL::Cell *cell, SmallVector<Value, 4> &inputs,
+  Operation *createOp(Yosys::RTLIL::Cell *cell, SmallVector<Value> &inputs,
                       ImplicitLocOpBuilder &b) const override;
 
-  SmallVector<Yosys::RTLIL::SigSpec, 4> getInputs(
+  SmallVector<Yosys::RTLIL::SigSpec> getInputs(
       Yosys::RTLIL::Cell *cell) const override;
 
   Yosys::RTLIL::SigSpec getOutput(Yosys::RTLIL::Cell *cell) const override;
diff --git a/lib/Transforms/YosysOptimizer/RTLILImporter.cpp b/lib/Transforms/YosysOptimizer/RTLILImporter.cpp
index 57dd1a0c8..9705a4662 100644
--- a/lib/Transforms/YosysOptimizer/RTLILImporter.cpp
+++ b/lib/Transforms/YosysOptimizer/RTLILImporter.cpp
@@ -185,7 +185,7 @@ func::FuncOp RTLILImporter::importModule(
          "expected cell in RTLIL design");
   auto *cell = module->cells_["\\" + cellName];
 
-  SmallVector<Value, 4> inputValues;
+  SmallVector<Value> inputValues;
   for (const auto &conn : getInputs(cell)) {
     inputValues.push_back(getBit(conn, b, retBitValues));
   }
diff --git a/lib/Transforms/YosysOptimizer/RTLILImporter.h b/lib/Transforms/YosysOptimizer/RTLILImporter.h
index 12a1a9c85..349439642 100644
--- a/lib/Transforms/YosysOptimizer/RTLILImporter.h
+++ b/lib/Transforms/YosysOptimizer/RTLILImporter.h
@@ -1,9 +1,8 @@
 #ifndef HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_RTLILIMPORTER_H_
 #define HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_RTLILIMPORTER_H_
 
-#include "kernel/rtlil.h"  // from @at_clifford_yosys
-#include "llvm/include/llvm/ADT/MapVector.h"  // from @llvm-project
-#include "llvm/include/llvm/ADT/StringMap.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/MapVector.h"  // from @llvm-project
+#include "llvm/include/llvm/ADT/StringMap.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
 #include "mlir/include/mlir/IR/MLIRContext.h"  // from @llvm-project
@@ -11,6 +10,11 @@
 #include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/include/mlir/Support/LLVM.h"  // from @llvm-project
 
+// Block clang-format from reordering
+// clang-format off
+#include "kernel/rtlil.h"  // from @at_clifford_yosys
+// clang-format on
+
 namespace mlir {
 namespace heir {
 
@@ -41,11 +45,11 @@ class RTLILImporter {
 protected:
  // createOp converts an RTLIL cell to an MLIR operation.
  virtual Operation *createOp(Yosys::RTLIL::Cell *cell,
-                              SmallVector<Value, 4> &inputs,
+                              SmallVector<Value> &inputs,
                               ImplicitLocOpBuilder &b) const = 0;
 
   // Returns a list of RTLIL cell inputs.
-  virtual SmallVector<Yosys::RTLIL::SigSpec, 4> getInputs(
+  virtual SmallVector<Yosys::RTLIL::SigSpec> getInputs(
       Yosys::RTLIL::Cell *cell) const = 0;
 
   // Returns an RTLIL cell output.
diff --git a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
index 1addfcead..0a92b50ae 100644
--- a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
+++ b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
@@ -16,6 +16,7 @@
 #include "include/Dialect/Secret/IR/SecretPatterns.h"
 #include "include/Dialect/Secret/IR/SecretTypes.h"
 #include "include/Target/Verilog/VerilogEmitter.h"
+#include "lib/Transforms/YosysOptimizer/BooleanGateImporter.h"
 #include "lib/Transforms/YosysOptimizer/LUTImporter.h"
 #include "lib/Transforms/YosysOptimizer/RTLILImporter.h"
 #include "llvm/include/llvm/ADT/SmallVector.h"  // from @llvm-project
@@ -66,7 +67,7 @@ using std::string;
 // $2: yosys runfiles
 // $3: abc path
 // $4: abc fast option -fast
-constexpr std::string_view kYosysTemplate = R"(
+constexpr std::string_view kYosysLutTemplate = R"(
 read_verilog {0};
 hierarchy -check -top \{1};
 proc; memory; stat;
@@ -81,6 +82,24 @@ clean;
 stat;
 )";
 
+// $0: verilog filename
+// $1: function name
+// $2: abc path
+// $3: yosys runfiles path
+// $4: abc fast option -fast
+constexpr std::string_view kYosysBooleanTemplate = R"(
+read_verilog {0};
+hierarchy -check -top \{1};
+proc; memory; stat;
+techmap -map {3}/techmap.v; opt; stat;
+abc -exe {2} -liberty {3}/tfhe-rs_cells.liberty {4}; stat;
+opt_clean -purge; stat;
+rename -hide */c:*; rename -enumerate */c:*;
+hierarchy -generate * o:Y i:*; opt; opt_clean -purge;
+clean;
+stat;
+)";
+
 struct RelativeOptimizationStatistics {
   std::string originalOp;
   int64_t numArithOps;
@@ -91,17 +110,13 @@ struct YosysOptimizer : public impl::YosysOptimizerBase<YosysOptimizer> {
   using YosysOptimizerBase::YosysOptimizerBase;
 
   YosysOptimizer(std::string yosysFilesPath, std::string abcPath, bool abcFast,
-                 int unrollFactor, bool printStats)
+                 int unrollFactor, Mode mode, bool printStats)
       : yosysFilesPath(std::move(yosysFilesPath)),
         abcPath(std::move(abcPath)),
        abcFast(abcFast),
         printStats(printStats),
-        unrollFactor(unrollFactor) {}
-
-  void getDependentDialects(mlir::DialectRegistry &registry) const override {
-    registry.insert<comb::CombDialect>();
-  }
+        unrollFactor(unrollFactor),
+        mode(mode) {}
 
   void runOnOperation() override;
 
@@ -116,6 +131,7 @@ struct YosysOptimizer : public impl::YosysOptimizerBase<YosysOptimizer> {
   bool abcFast;
   bool printStats;
   int unrollFactor;
+  Mode mode;
   llvm::SmallVector<RelativeOptimizationStatistics> optStatistics;
 };
@@ -183,7 +199,6 @@ LogicalResult convertOpOperands(secret::GenericOp op, func::FuncOp func,
 
 /// Convert a secret.generic's results from secret.secret<memref<3xi1>>
 /// to secret.secret<i3>.
-// genericOp has the original, func op has the memref's yosys optimized
 LogicalResult convertOpResults(secret::GenericOp op,
                                SmallVector<Type> originalResultTy,
                                DenseSet<Operation *> &castOps,
@@ -195,7 +210,6 @@ LogicalResult convertOpResults(secret::GenericOp op,
       opResult.getType().cast<secret::SecretType>();
 
   IntegerType elementType;
-  int numElements = 1;
   if (MemRefType convertedType =
           dyn_cast<MemRefType>(secretType.getValueType())) {
     if (!convertedType.getElementType().isa<IntegerType>() ||
@@ -206,7 +220,6 @@ LogicalResult convertOpResults(secret::GenericOp op,
       return failure();
     }
     elementType = convertedType.getElementType().cast<IntegerType>();
-    numElements = convertedType.getNumElements();
   } else {
     elementType = secretType.getValueType().cast<IntegerType>();
   }
@@ -388,9 +401,22 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) {
 
   // Invoke Yosys to translate to a combinational circuit and optimize.
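+  // Which script runs depends on the pass's mode: LUT mode instantiates
+  // kYosysLutTemplate, while Boolean mode instantiates kYosysBooleanTemplate,
+  // which techmaps through {yosysFilesPath}/techmap.v and runs ABC against
+  // tfhe-rs_cells.liberty, so the resulting netlist only uses the gate cells
+  // declared there.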
   Yosys::log_error_stderr = true;
   LLVM_DEBUG(Yosys::log_streams.push_back(&std::cout));
-  Yosys::run_pass(llvm::formatv(kYosysTemplate.data(), filename, moduleName,
-                                yosysFilesPath, abcPath,
-                                abcFast ? "-fast" : ""));
+
+  LLVM_DEBUG(
+      llvm::dbgs() << "Using "
+                   << (mode == Mode::LUT ? "LUT cells" : "boolean gates"));
+  auto yosysTemplate =
+      llvm::formatv(kYosysLutTemplate.data(), filename, moduleName,
+                    yosysFilesPath, abcPath, abcFast ? "-fast" : "")
+          .str();
+  if (mode == Mode::Boolean) {
+    yosysTemplate =
+        llvm::formatv(kYosysBooleanTemplate.data(), filename, moduleName,
+                      abcPath, yosysFilesPath, abcFast ? "-fast" : "")
+            .str();
+  }
+  Yosys::run_pass(yosysTemplate);
 
   // Translate Yosys result back to MLIR and insert into the func
   LLVM_DEBUG(Yosys::run_pass("dump;"));
@@ -399,7 +425,6 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) {
   Yosys::run_pass("torder -stop * P*;");
   Yosys::log_streams.clear();
   auto topologicalOrder = getTopologicalOrder(cellOrder);
-  LUTImporter lutImporter = LUTImporter(context);
   Yosys::RTLIL::Design *design = Yosys::yosys_get_design();
   auto numCells = design->top_module()->cells().size();
   totalCircuitSize += numCells;
@@ -408,9 +433,14 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) {
   }
 
   LLVM_DEBUG(llvm::dbgs() << "Importing RTLIL module\n");
-
+  std::unique_ptr<RTLILImporter> importer;
+  if (mode == Mode::LUT) {
+    importer = std::make_unique<LUTImporter>(context);
+  } else {
+    importer = std::make_unique<BooleanGateImporter>(context);
+  }
   func::FuncOp func =
-      lutImporter.importModule(design->top_module(), topologicalOrder);
+      importer->importModule(design->top_module(), topologicalOrder);
   Yosys::run_pass("delete;");
 
   LLVM_DEBUG(llvm::dbgs() << "Done importing RTLIL, now type-converting ops\n");
@@ -554,9 +584,9 @@ void YosysOptimizer::runOnOperation() {
 
 std::unique_ptr<Pass> createYosysOptimizer(
     const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast,
-    int unrollFactor, bool printStats) {
+    int unrollFactor, Mode mode, bool printStats) {
   return std::make_unique<YosysOptimizer>(yosysFilesPath, abcPath, abcFast,
-                                          unrollFactor, printStats);
+                                          unrollFactor, mode, printStats);
 }
 
 void registerYosysOptimizerPipeline(const std::string &yosysFilesPath,
@@ -567,7 +597,7 @@ void registerYosysOptimizerPipeline(const std::string &yosysFilesPath,
       const YosysOptimizerPipelineOptions &options) {
         pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, options.abcFast,
                                         options.unrollFactor,
-                                        options.printStats));
+                                        options.mode, options.printStats));
         pm.addPass(mlir::createCSEPass());
       });
 }
diff --git a/lib/Transforms/YosysOptimizer/yosys/BUILD b/lib/Transforms/YosysOptimizer/yosys/BUILD
index e89c5930c..f12d811d9 100644
--- a/lib/Transforms/YosysOptimizer/yosys/BUILD
+++ b/lib/Transforms/YosysOptimizer/yosys/BUILD
@@ -9,5 +9,6 @@ filegroup(
     name = "share_files",
     srcs = glob([
         "*.v",
+        "*.liberty",
     ]),
 )
diff --git a/lib/Transforms/YosysOptimizer/yosys/tfhe-rs_cells.liberty b/lib/Transforms/YosysOptimizer/yosys/tfhe-rs_cells.liberty
new file mode 100644
index 000000000..4fa43fca3
--- /dev/null
+++ b/lib/Transforms/YosysOptimizer/yosys/tfhe-rs_cells.liberty
@@ -0,0 +1,260 @@
+/********************************************/
+/*                                          */
+/* Supergate cell library for Bench marking */
+/*                                          */
+/* Symbiotic EDA GmbH / Moseley Instruments */
+/* Niels A. Moseley                         */
+/*                                          */
+/* Process: none                            */
+/*                                          */
+/* Date : 02-11-2018                        */
+/* Version: 1.0                             */
+/*                                          */
+/********************************************/
+
+library(supergate) {
+  delay_model : table_lookup;
+  time_unit : "1ns";
+
+  /* Inverter */
+  cell(inv) {
+    area : 0;
+    pin(A) {
+      direction : input;
+    }
+
+    pin(Y) {
+      direction : output;
+      function : "A'";
+      timing() {
+        related_pin : "A";
+        timing_sense : negative_unate;
+        cell_rise(scalar) {
+          values("0.0");
+        }
+        cell_fall(scalar) {
+          values("0.0");
+        }
+        rise_transition(scalar) {
+          values("0.0");
+        }
+        fall_transition(scalar) {
+          values("0.0");
+        }
+      }
+    }
+  }
+
+  cell(buffer) {
+    area : 0;
+    pin(A) {
+      direction : input;
+    }
+    pin(Y) {
+      direction : output;
+      function : "A";
+      timing() {
+        related_pin : "A";
+        timing_sense : positive_unate;
+        cell_rise(scalar) {
+          values("0.0");
+        }
+        cell_fall(scalar) {
+          values("0.0");
+        }
+        rise_transition(scalar) {
+          values("0.0");
+        }
+        fall_transition(scalar) {
+          values("0.0");
+        }
+      }
+    }
+  }
+
+  /* 2-input AND gate */
+  cell(and2) {
+    area : 100;
+    pin(A) {
+      direction : input;
+    }
+    pin(B) {
+      direction : input;
+    }
+    pin(Y) {
+      direction: output;
+      function : "(A * B)";
+      timing() {
+        related_pin : "A B";
+        timing_sense : positive_unate;
+        cell_rise(scalar) {
+          values("1000.0");
+        }
+        cell_fall(scalar) {
+          values("1000.0");
+        }
+        rise_transition(scalar) {
+          values("1000.0");
+        }
+        fall_transition(scalar) {
+          values("1000.0");
+        }
+      }
+    }
+  }
+
+  /* 2-input NAND gate */
+  cell(nand2) {
+    area : 100;
+    pin(A) {
+      direction : input;
+    }
+    pin(B) {
+      direction : input;
+    }
+    pin(Y) {
+      direction: output;
+      function : "(A * B)'";
+      timing() {
+        related_pin : "A B";
+        timing_sense : negative_unate;
+        cell_rise(scalar) {
+          values("1000.0");
+        }
+        cell_fall(scalar) {
+          values("1000.0");
+        }
+        rise_transition(scalar) {
+          values("1000.0");
+        }
+        fall_transition(scalar) {
+          values("1000.0");
+        }
+      }
+    }
+  }
+
+  /* 2-input OR gate */
+  cell(or2) {
+    area : 100;
+    pin(A) {
+      direction : input;
+    }
+    pin(B) {
+      direction : input;
+    }
+    pin(Y) {
+      direction: output;
+      function : "(A + B)";
+      timing() {
+        related_pin : "A B";
+        timing_sense : positive_unate;
+        cell_rise(scalar) {
+          values("1000.0");
+        }
+        cell_fall(scalar) {
+          values("1000.0");
+        }
+        rise_transition(scalar) {
+          values("1000.0");
+        }
+        fall_transition(scalar) {
+          values("1000.0");
+        }
+      }
+    }
+  }
+
+  /* 2-input NOR gate */
+  cell(nor2) {
+    area : 100;
+    pin(A) {
+      direction : input;
+    }
+    pin(B) {
+      direction : input;
+    }
+    pin(Y) {
+      direction: output;
+      function : "(A + B)'";
+      timing() {
+        related_pin : "A B";
+        timing_sense : negative_unate;
+        cell_rise(scalar) {
+          values("1000.0");
+        }
+        cell_fall(scalar) {
+          values("1000.0");
+        }
+        rise_transition(scalar) {
+          values("1000.0");
+        }
+        fall_transition(scalar) {
+          values("1000.0");
+        }
+      }
+    }
+  }
+
+  /* 2-input XOR */
+  cell(xor2) {
+    area : 100;
+    pin(A) {
+      direction : input;
+    }
+    pin(B) {
+      direction : input;
+    }
+    pin(Y) {
+      direction: output;
+      function : "(A * (B')) + ((A') * B)";
+      timing() {
+        related_pin : "A B";
+        timing_sense : non_unate;
+        cell_rise(scalar) {
+          values("1000.0");
+        }
+        cell_fall(scalar) {
+          values("1000.0");
+        }
+        rise_transition(scalar) {
+          values("1000.0");
+        }
+        fall_transition(scalar) {
+          values("1000.0");
+        }
+      }
+    }
+  }
+
+  /* 2-input XNOR */
+  cell(xnor2) {
+    area : 100;
+    pin(A) {
+      direction : input;
+    }
+    pin(B) {
+      direction : input;
+    }
+    pin(Y) {
+      direction: output;
+      function : "((A * (B')) + ((A') * B))'";
+      timing() {
+        related_pin : "A B";
+        timing_sense : non_unate;
+        cell_rise(scalar) {
+          values("1000.0");
+        }
+        cell_fall(scalar) {
+          values("1000.0");
+        }
+        rise_transition(scalar) {
+          values("1000.0");
+        }
+        fall_transition(scalar) {
+          values("1000.0");
+        }
+      }
+    }
+  }
+} /* end */
diff --git a/tests/cggi/straight_line_vectorizer.mlir b/tests/cggi/straight_line_vectorizer.mlir
index 490d6dd94..f466b64c0 100644
--- a/tests/cggi/straight_line_vectorizer.mlir
+++ b/tests/cggi/straight_line_vectorizer.mlir
@@ -1,4 +1,5 @@
-// RUN: heir-opt --straight-line-vectorize %s | FileCheck %s
+// TODO(#519): disable FileChecks until nondeterminism issues are resolved
+// RUN: heir-opt --straight-line-vectorize %s
 
 #encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 3>
 !ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
diff --git a/tests/polynomial/elementwise.mlir b/tests/polynomial/elementwise.mlir
deleted file mode 100644
index 0720672b8..000000000
--- a/tests/polynomial/elementwise.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: heir-opt --convert-elementwise-to-linalg --one-shot-bufferize --convert-linalg-to-affine-loops %s | FileCheck %s
-
-!poly = !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
-
-// CHECK-LABEL: @test_bin_ops
-// CHECK: affine.for
-func.func @test_bin_ops(%arg0: tensor<2x!poly>, %arg1: tensor<2x!poly>) -> tensor<2x!poly> {
-  %0 = polynomial.add(%arg0, %arg1) : tensor<2x!poly>
-  return %0 : tensor<2x!poly>
-}
diff --git a/tests/polynomial/elementwise_to_affine.mlir b/tests/polynomial/elementwise_to_affine.mlir
new file mode 100644
index 000000000..27780257d
--- /dev/null
+++ b/tests/polynomial/elementwise_to_affine.mlir
@@ -0,0 +1,90 @@
+// RUN: heir-opt --convert-elementwise-to-affine %s | FileCheck --enable-var-scope %s
+
+!poly = !polynomial.polynomial<<cmod=33538049, ideal=#polynomial.polynomial<1 + x**1024>>>
+
+// CHECK-LABEL: @test_elementwise
+// CHECK: {{.*}} -> [[T:tensor<2x!polynomial.*33538049.*]] {
+func.func @test_elementwise(%arg0: tensor<2x!poly>, %arg1: tensor<2x!poly>) -> tensor<2x!poly> {
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x!polynomial.*33538049.*]]
+  // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) {
+  // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]]] : [[T]]
+  // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]]] : [[T]]
+  // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]]
+  // CHECK-NOT: polynomial.add{{.*}} : [[T]]
+  // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T0]][[[I]]] : [[T]]
+  // CHECK: affine.yield [[R]] : [[T]]
+  %0 = polynomial.add(%arg0, %arg1) : tensor<2x!poly>
+  // CHECK: return [[LOOP]] : [[T]]
+  return %0 : tensor<2x!poly>
+}
+
+// This is just here to make sure the FileCheck commands above work as expected
+// CHECK-LABEL: @lowered_elementwise
+// CHECK: {{.*}} -> [[T:tensor<2x!polynomial.*33538049.*]] {
+func.func @lowered_elementwise(%arg0: tensor<2x!poly>, %arg1: tensor<2x!poly>) -> tensor<2x!poly> {
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x!polynomial.*33538049.*]]
+  %empty = tensor.empty() : tensor<2x!poly>
+  // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) {
+  %0 = affine.for %i = 0 to 2 iter_args(%t0 = %empty) -> (tensor<2x!poly>) {
+    // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]]] : [[T]]
+    %a = tensor.extract %arg0[%i] : tensor<2x!poly>
+    // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]]] : [[T]]
+    %b = tensor.extract %arg1[%i] : tensor<2x!poly>
+    // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]]
+    // CHECK-NOT: polynomial.add{{.*}} : [[T]]
+    %s = polynomial.add(%a, %b) : !poly
+    // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T0]][[[I]]] : [[T]]
+    %r = tensor.insert %s into %t0[%i] : tensor<2x!poly>
+    // CHECK: affine.yield [[R]] : [[T]]
+    affine.yield %r : tensor<2x!poly>
+  }
+  // CHECK: return [[LOOP]] : [[T]]
+  return %0 : tensor<2x!poly>
+}
+
+// CHECK-LABEL: @test_elementwise_multidim
+// CHECK: {{.*}} -> [[T:tensor<2x3x!polynomial.*33538049.*]] {
+func.func @test_elementwise_multidim(%arg0: tensor<2x3x!poly>, %arg1: tensor<2x3x!poly>) -> tensor<2x3x!poly> {
+  %0 = polynomial.add(%arg0, %arg1) : tensor<2x3x!poly>
+  return %0 : tensor<2x3x!poly>
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x3x!polynomial.*33538049.*]]
+  // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) {
+  // CHECK: [[INNERLOOP:%.+]] = affine.for [[J:%.+]] = 0 to 3 iter_args([[T1:%.+]] = [[T0]]) -> ([[T]]) {
+  // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]], [[J]]] : [[T]]
+  // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]], [[J]]] : [[T]]
+  // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]]
+  // CHECK-NOT: polynomial.add{{.*}} : [[T]]
+  // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T1]][[[I]], [[J]]] : [[T]]
+  // CHECK: affine.yield [[R]] : [[T]]
+  // CHECK: affine.yield [[INNERLOOP]] : [[T]]
+  // CHECK: return [[LOOP]] : [[T]]
+}
+
+// This is just here to make sure the FileCheck commands above work as expected
+// CHECK-LABEL: @lowered_elementwise_multidim
+// CHECK: {{.*}} -> [[T:tensor<2x3x!polynomial.*33538049.*]] {
+func.func @lowered_elementwise_multidim(%arg0: tensor<2x3x!poly>, %arg1: tensor<2x3x!poly>) -> tensor<2x3x!poly> {
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x3x!polynomial.*33538049.*]]
+  %empty = tensor.empty() : tensor<2x3x!poly>
+  // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) {
+  %0 = affine.for %i = 0 to 2 iter_args(%t0 = %empty) -> (tensor<2x3x!poly>) {
+    // CHECK: [[INNERLOOP:%.+]] = affine.for [[J:%.+]] = 0 to 3 iter_args([[T1:%.+]] = [[T0]]) -> ([[T]]) {
+    %1 = affine.for %j = 0 to 3 iter_args(%t1 = %t0) -> (tensor<2x3x!poly>) {
+      // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]], [[J]]] : [[T]]
+      %a = tensor.extract %arg0[%i, %j] : tensor<2x3x!poly>
+      // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]], [[J]]] : [[T]]
+      %b = tensor.extract %arg1[%i, %j] : tensor<2x3x!poly>
+      // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]]
+      // CHECK-NOT: polynomial.add{{.*}} : [[T]]
+      %s = polynomial.add(%a, %b) : !poly
+      // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T1]][[[I]], [[J]]] : [[T]]
+      %r = tensor.insert %s into %t1[%i, %j] : tensor<2x3x!poly>
+      // CHECK: affine.yield [[R]] : [[T]]
+      affine.yield %r : tensor<2x3x!poly>
+    }
+    // CHECK: affine.yield [[INNERLOOP]] : [[T]]
+    affine.yield %1 : tensor<2x3x!poly>
+  }
+  // CHECK: return [[LOOP]] : [[T]]
+  return %0 : tensor<2x3x!poly>
+}
diff --git a/tests/polynomial/lower_add.mlir b/tests/polynomial/lower_add.mlir
index 0d9eb74d9..2495ba14b 100644
--- a/tests/polynomial/lower_add.mlir
+++ b/tests/polynomial/lower_add.mlir
@@ -4,6 +4,8 @@
 #ring = #polynomial.ring<cmod=4294967296, ideal=#polynomial.polynomial<1 + x**1024>>
 #ring_prime = #polynomial.ring<cmod=4294967291, ideal=#polynomial.polynomial<1 + x**1024>>
 
+
+// CHECK-LABEL: @test_lower_add_power_of_two_cmod
 func.func @test_lower_add_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
   // 2 + 2x + 2x^2 + ... + 2x^{1023}
   // CHECK: [[X:%.+]] = arith.constant dense<2> : [[T:tensor<1024xi32>]]
@@ -19,6 +21,7 @@ func.func @test_lower_add_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
   return %poly2 : !polynomial.polynomial<#ring>
 }
 
+// CHECK-LABEL: @test_lower_add_prime_cmod
 func.func @test_lower_add_prime_cmod() -> !polynomial.polynomial<#ring_prime> {
   // CHECK: [[X:%.+]] = arith.constant dense<2> : [[TCOEFF:tensor<1024xi31>]]
   %coeffs1 = arith.constant dense<2> : tensor<1024xi31>
@@ -41,3 +44,52 @@ func.func @test_lower_add_prime_cmod() -> !polynomial.polynomial<#ring_prime> {
   // CHECK: return [[TRUNC_RESULT]] : [[T]]
   return %poly2 : !polynomial.polynomial<#ring_prime>
 }
+
+// CHECK-LABEL: @test_lower_add_tensor
+func.func @test_lower_add_tensor() -> tensor<2x!polynomial.polynomial<#ring>> {
+  // 2 + 2x + 2x^2 + ... + 2x^{1023}
+  // CHECK-DAG: [[A:%.+]] = arith.constant dense<2> : [[T:tensor<1024xi32>]]
+  %coeffsA = arith.constant dense<2> : tensor<1024xi32>
+  // CHECK-DAG: [[B:%.+]] = arith.constant dense<3> : [[T]]
+  %coeffsB = arith.constant dense<3> : tensor<1024xi32>
+  // CHECK-DAG: [[C:%.+]] = arith.constant dense<4> : [[T]]
+  %coeffsC = arith.constant dense<4> : tensor<1024xi32>
+  // CHECK-DAG: [[D:%.+]] = arith.constant dense<5> : [[T]]
+  %coeffsD = arith.constant dense<5> : tensor<1024xi32>
+  %polyA = polynomial.from_tensor %coeffsA : tensor<1024xi32> -> !polynomial.polynomial<#ring>
+  %polyB = polynomial.from_tensor %coeffsB : tensor<1024xi32> -> !polynomial.polynomial<#ring>
+  %polyC = polynomial.from_tensor %coeffsC : tensor<1024xi32> -> !polynomial.polynomial<#ring>
+  %polyD = polynomial.from_tensor %coeffsD : tensor<1024xi32> -> !polynomial.polynomial<#ring>
+  %tensor1 = tensor.from_elements %polyA, %polyB : tensor<2x!polynomial.polynomial<#ring>>
+  %tensor2 = tensor.from_elements %polyC, %polyD : tensor<2x!polynomial.polynomial<#ring>>
+  // CHECK: [[S1:%.+]] = arith.constant dense<[1, 1024]> : [[TI:tensor<2xindex>]]
+  // CHECK: [[T1:%.+]] = tensor.reshape [[A]]([[S1]]) : ([[T]], [[TI]]) -> [[TEX:tensor<1x1024xi32>]]
+  // CHECK: [[S2:%.+]] = arith.constant dense<[1, 1024]> : [[TI]]
+  // CHECK: [[T2:%.+]] = tensor.reshape [[B]]([[S2]]) : ([[T]], [[TI]]) -> [[TEX]]
+  // CHECK: [[C1:%.+]] = tensor.concat dim(0) [[T1]], [[T2]] : ([[TEX]], [[TEX]]) -> [[TT:tensor<2x1024xi32>]]
+  // CHECK: [[S3:%.+]] = arith.constant dense<[1, 1024]> : [[TI]]
+  // CHECK: [[T3:%.+]] = tensor.reshape [[C]]([[S3]]) : ([[T]], [[TI]]) -> [[TEX]]
+  // CHECK: [[S4:%.+]] = arith.constant dense<[1, 1024]> : [[TI]]
+  // CHECK: [[T4:%.+]] = tensor.reshape [[D]]([[S4]]) : ([[T]], [[TI]]) -> [[TEX]]
+  // CHECK: [[C2:%.+]] = tensor.concat dim(0) [[T3]], [[T4]] : ([[TEX]], [[TEX]]) -> [[TT:tensor<2x1024xi32>]]
+  // CHECK-NOT: polynomial.from_tensor
+  // CHECK-NOT: tensor.from_elements
+  %tensor3 = affine.for %i = 0 to 2 iter_args(%t0 = %tensor1) -> tensor<2x!polynomial.polynomial<#ring>> {
+  // CHECK: [[FOR:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[C1]]) -> ([[TT]]) {
+    %a = tensor.extract %tensor1[%i] : tensor<2x!polynomial.polynomial<#ring>>
+    %b = tensor.extract %tensor2[%i] : tensor<2x!polynomial.polynomial<#ring>>
+    // CHECK: [[AA:%.+]] = tensor.extract_slice [[C1]][[[I]], 0] [1, 1024] [1, 1] : [[TT]]
+    // CHECK: [[BB:%.+]] = tensor.extract_slice [[C2]][[[I]], 0] [1, 1024] [1, 1] : [[TT]]
+    // CHECK-NOT: tensor.extract %
+    %s = polynomial.add(%a, %b) : !polynomial.polynomial<#ring>
+    // CHECK: [[SUM:%.+]] = arith.addi [[AA]], [[BB]] : [[T]]
+    // CHECK-NOT: polynomial.add
+    %t = tensor.insert %s into %t0[%i] : tensor<2x!polynomial.polynomial<#ring>>
+    // CHECK: [[INS:%.+]] = tensor.insert_slice [[SUM]] into [[T0]][[[I]], 0] [1, 1024] [1, 1] : [[T]] into [[TT]]
+    // CHECK-NOT: tensor.insert %
+    affine.yield %t : tensor<2x!polynomial.polynomial<#ring>>
+    // CHECK: affine.yield [[INS]] : [[TT]]
+  }
+  return %tensor3 : tensor<2x!polynomial.polynomial<#ring>>
+  // CHECK: return [[FOR]] : [[TT]]
+}
diff --git a/tests/secretize/BUILD b/tests/secretize/BUILD
index 6c9032391..c571e6fc6 100644
--- a/tests/secretize/BUILD
+++ b/tests/secretize/BUILD
@@ -1,9 +1,6 @@
 load("//bazel:lit.bzl", "glob_lit_tests")
 
-package(
-    default_applicable_licenses = ["@heir//:license"],
-    default_visibility = ["//visibility:public"],
-)
+package(default_applicable_licenses = ["@heir//:license"])
 
 glob_lit_tests(
     name = "all_tests",
diff --git a/tests/tensor_ext/rotate_and_reduce.mlir b/tests/tensor_ext/rotate_and_reduce.mlir
new file mode 100644
index 000000000..126f70577
--- /dev/null
+++ b/tests/tensor_ext/rotate_and_reduce.mlir
@@ -0,0 +1,307 @@
+// RUN: heir-opt --rotate-and-reduce --canonicalize %s | FileCheck %s
+
+
+// Sum all entries of a tensor into a single scalar
+// CHECK-LABEL: @simple_sum
+// CHECK-SAME: (%[[arg0:.*]]: tensor<8xi32>
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[c1:.*]] = arith.constant 1
+// CHECK-NEXT: %[[c2:.*]] = arith.constant 2
+// CHECK-NEXT: %[[c4:.*]] = arith.constant 4
+// CHECK-NEXT: %[[v0:.*]] = tensor_ext.rotate %[[arg0]], %[[c4]]
+// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[arg0]], %[[v0]]
+// CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[v1]], %[[c2]]
+// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v1]], %[[v2]]
+// CHECK-NEXT: %[[v4:.*]] = tensor_ext.rotate %[[v3]], %[[c1]]
+// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]]
+// CHECK-NEXT: %[[v6:.*]] = tensor.extract %[[v5]][%[[c0]]]
+// CHECK-NEXT: return %[[v6]]
+func.func @simple_sum(%arg0: tensor<8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  %2 = tensor.extract %arg0[%c2] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<8xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<8xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<8xi32>
+  %8 = arith.addi %0, %1 : i32
+  %9 = arith.addi %8, %2 : i32
+  %10 = arith.addi %9, %3 : i32
+  %11 = arith.addi %10, %4 : i32
+  %12 = arith.addi %11, %5 : i32
+  %13 = arith.addi %12, %6 : i32
+  %14 = arith.addi %13, %7 : i32
+  return %14 : i32
+}
+
+// CHECK-LABEL: @not_supported_mixed_ops
+// CHECK-NOT: tensor_ext.rotate
+func.func @not_supported_mixed_ops(%arg0: tensor<8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  %2 = tensor.extract %arg0[%c2] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<8xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<8xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<8xi32>
+  %8 = arith.addi %0, %1 : i32
+  %9 = arith.addi %8, %2 : i32
+  %10 = arith.addi %9, %3 : i32
+  %11 = arith.muli %10, %4 : i32
+  %12 = arith.addi %11, %5 : i32
+  %13 = arith.addi %12, %6 : i32
+  %14 = arith.addi %13, %7 : i32
+  return %14 : i32
+}
+
+// CHECK-LABEL: @not_supported_missing_indices
+// CHECK-NOT: tensor_ext.rotate
+func.func @not_supported_missing_indices(%arg0: tensor<16xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %c8 = arith.constant 8 : index
+  %c9 = arith.constant 9 : index
+  %c10 = arith.constant 10 : index
+  %c11 = arith.constant 11 : index
+  %c12 = arith.constant 12 : index
+  %c13 = arith.constant 13 : index
+  %c14 = arith.constant 14 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<16xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<16xi32>
+  %2 = tensor.extract %arg0[%c2] : tensor<16xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<16xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<16xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<16xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<16xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<16xi32>
+  %8 = tensor.extract %arg0[%c8] : tensor<16xi32>
+  %9 = tensor.extract %arg0[%c9] : tensor<16xi32>
+  %10 = tensor.extract %arg0[%c10] : tensor<16xi32>
+  %11 = tensor.extract %arg0[%c11] : tensor<16xi32>
+  %12 = tensor.extract %arg0[%c12] : tensor<16xi32>
+  %13 = tensor.extract %arg0[%c13] : tensor<16xi32>
+  %14 = tensor.extract %arg0[%c14] : tensor<16xi32>
+  // missing element 15
+  %v1 = arith.addi %0, %1 : i32
+  %v2 = arith.addi %v1, %2 : i32
+  %v3 = arith.addi %v2, %3 : i32
+  %v4 = arith.addi %v3, %4 : i32
+  %v5 = arith.addi %v4, %5 : i32
+  %v6 = arith.addi %v5, %6 : i32
+  %v7 = arith.addi %v6, %7 : i32
+  %v8 = arith.addi %v7, %8 : i32
+  %v9 = arith.addi %v8, %9 : i32
+  %v10 = arith.addi %v9, %10 : i32
+  %v11 = arith.addi %v10, %11 : i32
+  %v12 = arith.addi %v11, %12 : i32
+  %v13 = arith.addi %v12, %13 : i32
+  %v14 = arith.addi %v13, %14 : i32
+  return %v14 : i32
+}
+
+// CHECK-LABEL: @not_supported_repeated_indices
+// CHECK-NOT: tensor_ext.rotate
+func.func @not_supported_repeated_indices(%arg0: tensor<8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  %2 = tensor.extract %arg0[%c2] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  // repeats element 3
+  %4 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<8xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<8xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<8xi32>
+  %8 = arith.addi %0, %1 : i32
+  %9 = arith.addi %8, %2 : i32
+  %10 = arith.addi %9, %3 : i32
+  %11 = arith.addi %10, %4 : i32
+  %12 = arith.addi %11, %5 : i32
+  %13 = arith.addi %12, %6 : i32
+  %14 = arith.addi %13, %7 : i32
+  return %14 : i32
+}
+
+// CHECK-LABEL: @not_supported_unsupported_op
+// CHECK-NOT: tensor_ext.rotate
+func.func @not_supported_unsupported_op(%arg0: tensor<8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  %2 = tensor.extract %arg0[%c2] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<8xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<8xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<8xi32>
+  %8 = arith.subi %0, %1 : i32
+  %9 = arith.subi %8, %2 : i32
+  %10 = arith.subi %9, %3 : i32
+  %11 = arith.subi %10, %4 : i32
+  %12 = arith.subi %11, %5 : i32
+  %13 = arith.subi %12, %6 : i32
+  %14 = arith.subi %13, %7 : i32
+  return %14 : i32
+}
+
+// 2D tensor not supported
+// CHECK-LABEL: @not_supported_bad_tensor_shape
+// CHECK-NOT: tensor_ext.rotate
+func.func @not_supported_bad_tensor_shape(%arg0: tensor<1x8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c1, %c0] : tensor<1x8xi32>
+  %1 = tensor.extract %arg0[%c1, %c1] : tensor<1x8xi32>
+  %2 = tensor.extract %arg0[%c1, %c2] : tensor<1x8xi32>
+  %3 = tensor.extract %arg0[%c1, %c3] : tensor<1x8xi32>
+  %4 = tensor.extract %arg0[%c1, %c4] : tensor<1x8xi32>
+  %5 = tensor.extract %arg0[%c1, %c5] : tensor<1x8xi32>
+  %6 = tensor.extract %arg0[%c1, %c6] : tensor<1x8xi32>
+  %7 = tensor.extract %arg0[%c1, %c7] : tensor<1x8xi32>
+  %8 = arith.addi %0, %1 : i32
+  %9 = arith.addi %8, %2 : i32
+  %10 = arith.addi %9, %3 : i32
+  %11 = arith.addi %10, %4 : i32
+  %12 = arith.addi %11, %5 : i32
+  %13 = arith.addi %12, %6 : i32
+  %14 = arith.addi %13, %7 : i32
+  return %14 : i32
+}
+
+// reducing from multiple input tensors
+// CHECK-LABEL: @not_supported_multiple_tensors
+// CHECK-NOT: tensor_ext.rotate
+func.func @not_supported_multiple_tensors(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  // uses %arg1
+  %2 = tensor.extract %arg1[%c2] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<8xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<8xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<8xi32>
+  %8 = arith.addi %0, %1 : i32
+  %9 = arith.addi %8, %2 : i32
+  %10 = arith.addi %9, %3 : i32
+  %11 = arith.addi %10, %4 : i32
+  %12 = arith.addi %11, %5 : i32
+  %13 = arith.addi %12, %6 : i32
+  %14 = arith.addi %13, %7 : i32
+  return %14 : i32
+}
+
+// CHECK-LABEL: @not_supported_non_constant_index_access
+// CHECK-NOT: tensor_ext.rotate
+func.func @not_supported_non_constant_index_access(%arg0: tensor<8xi32>, %arg1: index) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  // uses non-constant index
+  %2 = tensor.extract %arg0[%arg1] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<8xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<8xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<8xi32>
+  %8 = arith.addi %0, %1 : i32
+  %9 = arith.addi %8, %2 : i32
+  %10 = arith.addi %9, %3 : i32
+  %11 = arith.addi %10, %4 : i32
+  %12 = arith.addi %11, %5 : i32
+  %13 = arith.addi %12, %6 : i32
+  %14 = arith.addi %13, %7 : i32
+  return %14 : i32
+}
+
+// CHECK-LABEL: @not_supported_non_tensor_operands
+// CHECK-NOT: tensor_ext.rotate
+// TODO(#522): support this
+func.func @not_supported_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %c2_i32 = arith.constant 2 : i32
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  %2 = tensor.extract %arg0[%c2] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] : tensor<8xi32>
+  %6 = tensor.extract %arg0[%c6] : tensor<8xi32>
+  %7 = tensor.extract %arg0[%c7] : tensor<8xi32>
+  %8 = arith.addi %0, %1 : i32
+  // next op uses non-tensor operand
+  %9 = arith.addi %8, %c2_i32 : i32
+  %10 = arith.addi %9, %3 : i32
+  %11 = arith.addi %10, %4 : i32
+  %12 = arith.addi %11, %5 : i32
+  %13 = arith.addi %12, %6 : i32
+  %14 = arith.addi %13, %7 : i32
+  %15 = arith.addi %14, %2 : i32
+  return %15 : i32
+}
diff --git a/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs
index 02f61b2ef..a0ccb8f64 100644
--- a/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs
+++ b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs
@@ -32,7 +32,6 @@ pub fn decrypt(ciphertexts: &Vec<Ciphertext>, client_key: &ClientKey) -> u8 {
     accum |= (bit as u8) << i;
   }
   accum.reverse_bits()
-
 }
 
 fn main() {
diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md
index 4eb3d1103..42be8b572 100644
--- a/tests/tfhe_rust_bool/end_to_end_fpga/README.md
+++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md
@@ -3,7 +3,10 @@
 These tests exercise Rust codegen for the
 [tfhe-rs](https://github.com/zama-ai/tfhe-rs) backend library, including
 compiling the generated Rust source and running the resulting binary. This set
-of tests is specifically of the boolean plaintexts and the accompanying library.
+of tests is specifically for the boolean plaintexts, the accompanying COSIC-KU
+Leuven version of the library, and the
+[FPT accelerator](https://eprint.iacr.org/2022/1635).
+
+> :warning: These tests cannot be run without the COSIC extension of TFHE-rs
+> and the FPT accelerator.
 
 To avoid introducing these large dependencies into the entire project, these
 tests are manual, and require the system they're running on to have
diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs
new file mode 100644
index 000000000..2109a435e
--- /dev/null
+++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/fn_under_test.rs
@@ -0,0 +1,84 @@
+use tfhe::boolean::prelude::*;
+
+
+pub fn fn_under_test(
+    v0: &ServerKey,
+    v1: &Vec<Ciphertext>,
+    v2: &Vec<Ciphertext>,
+) -> Vec<Ciphertext> {
+    let v1 = v1.iter().collect();
+    let v2 = v2.iter().collect();
+    let v3 = v0.xor_packed(&v1, &v2);
+    v3
+}
+
+
+// pub fn fn_under_test(
+//   v0: &ServerKey,
+//   v1: &Vec<&Ciphertext>,
+//   v2: &Vec<&Ciphertext>,
+// ) -> Vec<Ciphertext> {
+//   let v3 = 7;
+//   let v4 = 6;
+//   let v5 = 5;
+//   let v6 = 4;
+//   let v7 = 3;
+//   let v8 = 2;
+//   let v9 = 1;
+//   let v10 = 0;
+//   let v11 = vec![v1[v10]];
+//   let v12 = vec![v1[v9]];
+//   let v13 = vec![v1[v8]];
+//   let v14 = vec![v1[v7]];
+//   let v15 = vec![v1[v6]];
+//   let v16 = vec![v1[v5]];
+//   let v17 = vec![v1[v4]];
+//   let v18 = vec![v1[v3]];
+//   let v19 = vec![v2[v10]];
+//   let v20 = vec![v2[v9]];
+//   let v21 = vec![v2[v8]];
+//   let v22 = vec![v2[v7]];
+//   let v23 = vec![v2[v6]];
+//   let v24 = vec![v2[v5]];
+//   let v25 = vec![v2[v4]];
+//   let v26 = vec![v2[v3]];
+//   let v27 = v0.xor_packed(&v11, &v19);
+//   let v27 = v27.iter().collect();
+//   let v28 = v0.and_packed(&v11, &v19);
+//   let v29 = v0.xor_packed(&v12, &v20);
+//   let v30 = v0.and_packed(&v12, &v20);
+//   let v28 = v28.iter().collect();
+//   let v29 = v29.iter().collect();
+//   let v31 = v0.and_packed(&v29, &v28);
+//   let v32 = v0.xor_packed(&v29, &v28);
+//   let v33 = v0.xor_packed(&v30, &v31);
+//   let v34 = v0.xor_packed(&v13, &v21);
+//   let v35 = v0.and_packed(&v13, &v21);
+//   let v36 = v0.and_packed(&v34, &v33);
+//   let v37 = v0.xor_packed(&v34, &v33);
+//   let v38 = v0.xor_packed(&v35, &v36);
+//   let v39 = v0.xor_packed(&v14, &v22);
+//   let v40 = v0.and_packed(&v14, &v22);
+//   let v41 = v0.and_packed(&v39, &v38);
+//   let v42 = v0.xor_packed(&v39, &v38);
+//   let v43 = v0.xor_packed(&v40, &v41);
+//   let v44 = v0.xor_packed(&v15, &v23);
+//   let v45 = v0.and_packed(&v15, &v23);
+//   let v46 = v0.and_packed(&v44, &v43);
+//   let v47 = v0.xor_packed(&v44, &v43);
+//   let v48 = v0.xor_packed(&v45, &v46);
+//   let v49 = v0.xor_packed(&v16, &v24);
+//   let v50 = v0.and_packed(&v16, &v24);
+//   let v51 = v0.and_packed(&v49, &v48);
+//   let v52 = v0.xor_packed(&v49, &v48);
+//   let v53 = v0.xor_packed(&v50, &v51);
+//   let v54 = v0.xor_packed(&v17, &v25);
+//   let v55 = v0.and_packed(&v17, &v25);
+//   let v56 = v0.and_packed(&v54, &v53);
+//   let v57 = v0.xor_packed(&v54, &v53);
+//   let v58 = v0.xor_packed(&v55, &v56);
+//   let v59 = v0.xor_packed(&v18, &v26);
+//   let v60 = v0.xor_packed(&v59, &v58);
+//   let v61 = vec![v60[0], v57[0], v52[0], v47[0], v42[0], v37[0], v32[0], v27[0]];
+//   v61
+// }
diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
index 47392eaa9..1a10fadf0 100644
--- a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
+++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs
@@ -72,8 +72,8 @@ fn main() {
     let ct_1 = encrypt(flags.input1.into(), &client_key);
     let ct_2 = encrypt(flags.input2.into(), &client_key);
 
-    let ct_1= ct_1.iter().collect();
-    let ct_2= ct_2.iter().collect();
+    // let ct_1= ct_1.into_iter().collect();
+    // let ct_2= ct_2.into_iter().collect();
 
     let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2);
diff --git a/tests/yosys_optimizer/add_one.mlir b/tests/yosys_optimizer/add_one.mlir
index 73bb8a60b..cd0cf3522 100644
--- a/tests/yosys_optimizer/add_one.mlir
+++ b/tests/yosys_optimizer/add_one.mlir
@@ -1,4 +1,5 @@
-// RUN: heir-opt --yosys-optimizer %s | FileCheck %s
+// RUN: heir-opt --yosys-optimizer --canonicalize --cse %s | FileCheck %s
+// RUN: heir-opt --yosys-optimizer="mode=Boolean" --canonicalize --cse %s | FileCheck --check-prefix=CHECK --check-prefix=BOOL %s
 
 module {
   // CHECK-LABEL: @add_one
@@ -13,6 +14,7 @@ module {
         ins(%in, %one: !secret.secret<i8>, i8) {
       ^bb0(%IN: i8, %ONE: i8) :
         // CHECK-NOT: arith.addi
+        // BOOL-COUNT-7: comb.inv
        %2 = arith.addi %IN, %ONE : i8
         secret.yield %2 : i8
       } -> (!secret.secret<i8>)
diff --git a/tools/BUILD b/tools/BUILD
index a2446bd13..04c5c5068 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -57,6 +57,7 @@ cc_binary(
         "@heir//lib/Dialect/TensorExt/Transforms",
         "@heir//lib/Dialect/TfheRust/IR:Dialect",
         "@heir//lib/Dialect/TfheRustBool/IR:Dialect",
+        "@heir//lib/Transforms/ElementwiseToAffine",
         "@heir//lib/Transforms/ForwardStoreToLoad",
         "@heir//lib/Transforms/FullLoopUnroll",
         "@heir//lib/Transforms/Secretize",
@@ -65,12 +66,14 @@ cc_binary(
         "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:AffineToStandard",
         "@llvm-project//mlir:AffineTransforms",
+        "@llvm-project//mlir:AllExtensions",
         "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ArithToLLVM",
         "@llvm-project//mlir:ArithTransforms",
         "@llvm-project//mlir:BufferizationTransforms",
         "@llvm-project//mlir:ControlFlowToLLVM",
+        "@llvm-project//mlir:ConvertToLLVM",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncToLLVM",
         "@llvm-project//mlir:FuncTransforms",
diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp
index 14dade8cd..bd6fa84c2 100644
--- a/tools/heir-opt.cpp
+++ b/tools/heir-opt.cpp
@@ -24,6 +24,7 @@
 #include "include/Dialect/TensorExt/Transforms/Passes.h"
 #include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
 #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"
+#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h"
 #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
 #include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h"
 #include "include/Transforms/Secretize/Passes.h"
@@ -32,12 +33,9 @@
 #include "llvm/include/llvm/Support/raw_ostream.h"  // from @llvm-project
 #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
 #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h"  // from @llvm-project
-#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"  // from @llvm-project
-#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
+#include "mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"  // from @llvm-project
 #include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h"  // from @llvm-project
 #include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"  // from @llvm-project
-#include "mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
-#include "mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
 #include "mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"  // from @llvm-project
 #include "mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h"  // from @llvm-project
 #include "mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h"  // from @llvm-project
@@ -57,6 +55,7 @@
 #include "mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
 #include "mlir/include/mlir/InitAllDialects.h"  // from @llvm-project
+#include "mlir/include/mlir/InitAllExtensions.h"  // from @llvm-project
 #include "mlir/include/mlir/InitAllPasses.h"  // from @llvm-project
 #include "mlir/include/mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/include/mlir/Pass/PassOptions.h"  // from @llvm-project
@@ -132,6 +131,7 @@ void tosaPipelineBuilder(OpPassManager &manager) {
 
 void polynomialToLLVMPipelineBuilder(OpPassManager &manager) {
   // Poly
+  manager.addPass(createElementwiseToAffine());
   manager.addPass(polynomial::createPolynomialToStandard());
   manager.addPass(createCanonicalizerPass());
 
@@ -140,6 +140,7 @@ void polynomialToLLVMPipelineBuilder(OpPassManager &manager) {
   // Needed to lower affine.map and affine.apply
   manager.addNestedPass<func::FuncOp>(affine::createAffineExpandIndexOpsPass());
   manager.addNestedPass<func::FuncOp>(affine::createSimplifyAffineStructuresPass());
+  manager.addPass(createLowerAffinePass());
   manager.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
   manager.addNestedPass<func::FuncOp>(memref::createExpandStridedMetadataPass());
 
@@ -160,14 +161,9 @@ void polynomialToLLVMPipelineBuilder(OpPassManager &manager) {
   // ToLLVM
   manager.addPass(arith::createArithExpandOpsPass());
   manager.addPass(createConvertSCFToCFPass());
-  manager.addPass(createConvertControlFlowToLLVMPass());
-  manager.addPass(createConvertIndexToLLVMPass());
   manager.addNestedPass<func::FuncOp>(memref::createExpandStridedMetadataPass());
-  manager.addPass(createCanonicalizerPass());
-  manager.addPass(createConvertFuncToLLVMPass());
-  manager.addPass(createArithToLLVMConversionPass());
-  manager.addPass(createFinalizeMemRefToLLVMConversionPass());
-  manager.addPass(createReconcileUnrealizedCastsPass());
+  manager.addPass(createConvertToLLVMPass());
+
   // Cleanup
   manager.addPass(createCanonicalizerPass());
   manager.addPass(createSCCPPass());
@@ -285,6 +281,7 @@ int main(int argc, char **argv) {
   registry.insert<tfhe_rust::TfheRustDialect>();
   registry.insert<tfhe_rust_bool::TfheRustBoolDialect>();
   registerAllDialects(registry);
+  registerAllExtensions(registry);
 
   // Register MLIR core passes to build pipeline.
   registerAllPasses();
@@ -294,6 +291,7 @@ int main(int argc, char **argv) {
   lwe::registerLWEPasses();
   secret::registerSecretPasses();
   tensor_ext::registerTensorExtPasses();
+  registerElementwiseToAffinePasses();
   registerSecretizePasses();
   registerFullLoopUnrollPasses();
   registerForwardStoreToLoadPasses();