From 2689afbfe086c04694046362dcdac1029758ce92 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 27 Feb 2024 16:07:48 -0800 Subject: [PATCH] Initial port of HECO auto-SIMD passes to HEIR - InsertRotate: insert rotations and apply target slot selection rules - TensorExtCanonicalization: canonicalization patterns to enable cse to remove unnecessary rotations - CollapseInsertionChains: identify extract/insert chains that can be converted to rotations Additional followup issues identified in https://github.com/google/heir/pull/471 for improvements. --- include/Dialect/TensorExt/IR/BUILD | 20 ++ .../TensorExt/IR/TensorExtCanonicalization.td | 88 +++++++++ .../Dialect/TensorExt/IR/TensorExtDialect.td | 9 +- include/Dialect/TensorExt/IR/TensorExtOps.h | 6 +- include/Dialect/TensorExt/IR/TensorExtOps.td | 1 + include/Dialect/TensorExt/Transforms/BUILD | 55 ++++++ .../Transforms/CollapseInsertionChains.h | 17 ++ .../TensorExt/Transforms/InsertRotate.h | 17 ++ .../TensorExt/Transforms/InsertRotate.td | 84 +++++++++ include/Dialect/TensorExt/Transforms/Passes.h | 19 ++ .../Dialect/TensorExt/Transforms/Passes.td | 63 +++++++ lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 6 +- lib/Dialect/TensorExt/IR/BUILD | 4 + lib/Dialect/TensorExt/IR/TensorExtDialect.cpp | 1 + lib/Dialect/TensorExt/IR/TensorExtOps.cpp | 22 +++ lib/Dialect/TensorExt/Transforms/BUILD | 53 ++++++ .../Transforms/CollapseInsertionChains.cpp | 175 ++++++++++++++++++ .../TensorExt/Transforms/InsertRotate.cpp | 41 ++++ tests/simd/BUILD | 16 ++ tests/simd/box_blur_4x4.mlir | 39 ++++ tests/simd/box_blur_64x64.mlir | 62 +++++++ tests/simd/hamming_distance.mlir | 32 ++++ tests/tensor_ext/canonicalize.mlir | 23 +++ .../tensor_ext/collapse_insertion_chains.mlir | 22 +++ tests/tensor_ext/insert_rotations.mlir | 29 +++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 27 files changed, 898 insertions(+), 9 deletions(-) create mode 100644 include/Dialect/TensorExt/IR/TensorExtCanonicalization.td create mode 100644 
include/Dialect/TensorExt/Transforms/BUILD create mode 100644 include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h create mode 100644 include/Dialect/TensorExt/Transforms/InsertRotate.h create mode 100644 include/Dialect/TensorExt/Transforms/InsertRotate.td create mode 100644 include/Dialect/TensorExt/Transforms/Passes.h create mode 100644 include/Dialect/TensorExt/Transforms/Passes.td create mode 100644 lib/Dialect/TensorExt/Transforms/BUILD create mode 100644 lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp create mode 100644 lib/Dialect/TensorExt/Transforms/InsertRotate.cpp create mode 100644 tests/simd/BUILD create mode 100644 tests/simd/box_blur_4x4.mlir create mode 100644 tests/simd/box_blur_64x64.mlir create mode 100644 tests/simd/hamming_distance.mlir create mode 100644 tests/tensor_ext/canonicalize.mlir create mode 100644 tests/tensor_ext/collapse_insertion_chains.mlir create mode 100644 tests/tensor_ext/insert_rotations.mlir diff --git a/include/Dialect/TensorExt/IR/BUILD b/include/Dialect/TensorExt/IR/BUILD index 3f546387e..f2ac37599 100644 --- a/include/Dialect/TensorExt/IR/BUILD +++ b/include/Dialect/TensorExt/IR/BUILD @@ -11,12 +11,14 @@ exports_files( [ "TensorExtDialect.h", "TensorExtOps.h", + "TensorExtPatterns.h", ], ) td_library( name = "td_files", srcs = [ + "TensorExtCanonicalization.td", "TensorExtDialect.td", "TensorExtOps.td", ], @@ -77,3 +79,21 @@ gentbl_cc_library( "@heir//include/Dialect/Polynomial/IR:td_files", ], ) + +gentbl_cc_library( + name = "canonicalize_inc_gen", + tbl_outs = [ + ( + ["-gen-rewriters"], + "TensorExtCanonicalization.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TensorExtCanonicalization.td", + deps = [ + ":ops_inc_gen", + ":td_files", + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", + ], +) diff --git a/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td b/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td new file 
mode 100644 index 000000000..5b308d13e --- /dev/null +++ b/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td @@ -0,0 +1,88 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_ +#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_ + +include "TensorExtOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "mlir/Dialect/Tensor/IR/TensorOps.td" +include "mlir/IR/PatternBase.td" + +// TODO(#515): refactor these helpers to a common file with InsertRotate.td +defvar DefOverflow = ConstantEnumCase; + +def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; + +def CreateIndexCastOp : NativeCodeCall< + "$_builder.create($0.getLoc(), $1.getType(), $0)">; + +def IsZero : + Constraint< + CPred<"llvm::cast($0).getValue().isZero()">>; + +def OutOfBoundsOfTensorDim : + Constraint< + CPred< + "llvm::cast($0).getValue().getSExtValue() < 0 " + "|| llvm::cast($0).getValue().getSExtValue() > " + "llvm::cast($1.getType()).getShape()[0]" + > + >; + +// rotate %t, 0 -> %t +def DropZeroRotation : Pat< + (TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)), + (replaceWithValue $tensor), + [(IsZero $c0)] +>; + +// rotate %t, x -> rotate %t, x mod size +def NormalizeRotationIndex : Pat< + (TensorExt_RotateOp $tensor, (Arith_ConstantOp:$shiftOp APIntAttr:$shiftAmount)), + (TensorExt_RotateOp $tensor, + (Arith_RemUIOp + $shiftOp, + // Only works for 1D tensors: index is taken modulo the tensor length, + // i.e., dim 0 + (CreateIndexCastOp + (Tensor_DimOp $tensor, (Arith_ConstantOp ConstantAttr)), + $shiftOp)) + ), + [(OutOfBoundsOfTensorDim $shiftAmount, $tensor)] +>; + +// %0 = rotate %t, x +// %1 = rotate %0, y +// ---> rotate %t (x+y) +def CombineSequentialRotates : Pat< + (TensorExt_RotateOp + (TensorExt_RotateOp $tensor, (Arith_ConstantOp:$xOp APIntAttr:$x)), + (Arith_ConstantOp:$yOp APIntAttr:$y)), + (TensorExt_RotateOp $tensor, (Arith_AddIOp $xOp, $yOp, DefOverflow)), + [] +>; + +// A rotation followed by extraction can be 
extracted directly from the +// original tensor. +def RotatePlusExtractToIndexedExtract : Pat< + (Tensor_ExtractOp + (TensorExt_RotateOp $tensor, $shift), + (variadic $index)), + (Tensor_ExtractOp + $tensor, + (MakeSingleResultVariadic (Arith_AddIOp $shift, $index, DefOverflow))) +>; + +// Rotating two tensors by the same amount can be converted to a single +// post-rotation. This can result in eliminating either the rotation (because +// it can be combined with a later rotation) or the arith op itself, if it is +// identical to an existing arith op applied before the rotation. +foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { + def FactorParallelRotationsThroughOp_#ArithOp : Pat< + (ArithOp + (TensorExt_RotateOp $t1, $i), + (TensorExt_RotateOp $t2, $i), + $ovf), + (TensorExt_RotateOp (ArithOp $t1, $t2, $ovf), $i) + >; +} + +#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_ diff --git a/include/Dialect/TensorExt/IR/TensorExtDialect.td b/include/Dialect/TensorExt/IR/TensorExtDialect.td index 218d00d6e..cc8ae4502 100644 --- a/include/Dialect/TensorExt/IR/TensorExtDialect.td +++ b/include/Dialect/TensorExt/IR/TensorExtDialect.td @@ -1,5 +1,5 @@ -#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ -#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ +#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_ +#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_ include "mlir/IR/DialectBase.td" @@ -12,6 +12,9 @@ def TensorExt_Dialect : Dialect { }]; let cppNamespace = "::mlir::heir::tensor_ext"; + let dependentDialects = [ + "tensor::TensorDialect", + ]; } -#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ +#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_ diff --git a/include/Dialect/TensorExt/IR/TensorExtOps.h b/include/Dialect/TensorExt/IR/TensorExtOps.h index 026a97b6c..cb959543c 100644 --- a/include/Dialect/TensorExt/IR/TensorExtOps.h +++ 
b/include/Dialect/TensorExt/IR/TensorExtOps.h @@ -1,5 +1,5 @@ -#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ -#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ +#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_ +#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_ #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project @@ -8,4 +8,4 @@ #define GET_OP_CLASSES #include "include/Dialect/TensorExt/IR/TensorExtOps.h.inc" -#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_ +#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_ diff --git a/include/Dialect/TensorExt/IR/TensorExtOps.td b/include/Dialect/TensorExt/IR/TensorExtOps.td index eb56fc9aa..d2f633fc9 100644 --- a/include/Dialect/TensorExt/IR/TensorExtOps.td +++ b/include/Dialect/TensorExt/IR/TensorExtOps.td @@ -35,6 +35,7 @@ def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor", let arguments = (ins AnyTensor:$tensor, SignlessIntegerLike:$shift); let results = (outs AnyTensor:$output); let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)"; + let hasCanonicalizer = 1; } #endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_ diff --git a/include/Dialect/TensorExt/Transforms/BUILD b/include/Dialect/TensorExt/Transforms/BUILD new file mode 100644 index 000000000..574ad102f --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/BUILD @@ -0,0 +1,55 @@ +# InsertRotate tablegen and headers. 
+ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=TensorExt", + ], + "Passes.h.inc", + ), + ( + ["-gen-pass-doc"], + "TensorExtPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "insert_rotate_inc_gen", + tbl_outs = [ + ( + ["-gen-rewriters"], + "InsertRotate.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "InsertRotate.td", + deps = [ + "@heir//include/Dialect/TensorExt/IR:ops_inc_gen", + "@heir//include/Dialect/TensorExt/IR:td_files", + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", + ], +) + +exports_files([ + "Passes.h", + "CollapseInsertionChains.h", + "InsertRotate.h", +]) diff --git a/include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h b/include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h new file mode 100644 index 000000000..631ca4f1d --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h @@ -0,0 +1,17 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_ +#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace tensor_ext { + +#define GEN_PASS_DECL_COLLAPSEINSERTIONCHAINS +#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_ diff --git a/include/Dialect/TensorExt/Transforms/InsertRotate.h b/include/Dialect/TensorExt/Transforms/InsertRotate.h new 
file mode 100644 index 000000000..b07d3af46 --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/InsertRotate.h @@ -0,0 +1,17 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ +#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace tensor_ext { + +#define GEN_PASS_DECL_INSERTROTATE +#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ diff --git a/include/Dialect/TensorExt/Transforms/InsertRotate.td b/include/Dialect/TensorExt/Transforms/InsertRotate.td new file mode 100644 index 000000000..6ffc2a539 --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/InsertRotate.td @@ -0,0 +1,84 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ +#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ + +include "include/Dialect/TensorExt/IR/TensorExtOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "mlir/Dialect/Tensor/IR/TensorOps.td" +include "mlir/IR/PatternBase.td" + +// TODO(#512): Support target slot selection when the downstream op is an insert. + +// The patterns in this file are intended to align with the automatic-SIMD +// batching heuristics from the HECO project. See section 4.4 of +// https://arxiv.org/abs/2202.01649 and the hir2hir passes in +// https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/ + +defvar DefOverflow = ConstantEnumCase; + +// To understand why this is needed, see +// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385 +def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; + +// Match an arith op that extracts scalar values from two tensors, and replace +// it with rotations to align slots and apply the same op in SIMD. 
Other +// patterns in this file will find better alignment of adjacent rotations, and +// canonicalization patterns will remove duplicated rotations. +foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { + def InsertRotations_#ArithOp : Pattern< + (ArithOp + (Tensor_ExtractOp $t1, (variadic $i1)), + (Tensor_ExtractOp $t2, (variadic $i2)), + $overflow), + [ + (TensorExt_RotateOp:$r1 $t1, $i1), + (TensorExt_RotateOp:$r2 $t2, $i2), + (ArithOp:$opResult $r1, $r2, $overflow), + (Tensor_ExtractOp + $opResult, + (MakeSingleResultVariadic (Arith_ConstantOp ConstantAttr))), + ] + >; +} + +// Pre-align the first op's operands to the index that the result is +// used for in a subsequent op. +// TODO(#514): handle OuterOp with two different InnerOps on the LHS and RHS +foreach InnerOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { + foreach OuterOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { + // Left associated grouping handles (add (add (rotate t1 i1) (rotate t2 i2)) (rotate t3 i3)) + def AlignRotations_LeftAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern< + (OuterOp + (InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1), + (TensorExt_RotateOp $t3, $i3), + $ovf2), + [ + (TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)), + (TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)), + (InnerOp:$addResult $r1, $r2, $ovf1), + (OuterOp:$output $addResult, $t3, $ovf2), + // Downstream ops are not updated by this pass, so we need to preserve the original + // rotation and then clean it up in a separate canonicalization pattern. 
+ (TensorExt_RotateOp $output, $i3), + ] + >; + + // Right associated grouping handles (add (rotate t1 i1) (add (rotate t2 i2) (rotate t3 i3))) + def AlignRotations_RightAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern< + (OuterOp + (TensorExt_RotateOp $t3, $i3), + (InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1), + $ovf2), + [ + (TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)), + (TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)), + (InnerOp:$addResult $r1, $r2, $ovf1), + (OuterOp:$output $addResult, $t3, $ovf2), + // Downstream ops are not updated by this pass, so we need to preserve the original + // rotation and then clean it up in a separate canonicalization pattern. + (TensorExt_RotateOp $output, $i3), + ] + >; + } +} + +#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ diff --git a/include/Dialect/TensorExt/Transforms/Passes.h b/include/Dialect/TensorExt/Transforms/Passes.h new file mode 100644 index 000000000..01eafff25 --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/Passes.h @@ -0,0 +1,19 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ +#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ + +#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h" +#include "include/Dialect/TensorExt/Transforms/InsertRotate.h" + +namespace mlir { +namespace heir { +namespace tensor_ext { + +#define GEN_PASS_REGISTRATION +#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ diff --git a/include/Dialect/TensorExt/Transforms/Passes.td b/include/Dialect/TensorExt/Transforms/Passes.td new file mode 100644 index 000000000..88d396611 --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/Passes.td @@ -0,0 +1,63 @@ +#ifndef 
INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ +#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def InsertRotate : Pass<"insert-rotate"> { + let summary = "Vectorize arithmetic FHE operations using HECO-style heuristics"; + let description = [{ + This pass implements the SIMD-vectorization passes from the + [HECO paper](https://arxiv.org/abs/2202.01649). + + The pass operates by identifying arithmetic operations that can be suitably + combined into a combination of cyclic rotations and vectorized operations + on tensors. It further identifies a suitable "slot target" for each operation + and heuristically aligns the operations to reduce unnecessary rotations. + + This pass by itself does not eliminate any operations, but instead inserts + well-chosen rotations so that, for well-structured code (like unrolled affine loops), + a subsequent `--cse` and `--canonicalize` pass will dramatically reduce the IR. + As such, the pass is designed to be paired with the canonicalization patterns + in `tensor_ext`, as well as the `collapse-insertion-chains` pass, which + cleans up remaining insertion and extraction ops after the main simplifications + are applied. + + Unlike HECO, this pass operates on plaintext types and tensors, along with + the HEIR-specific `tensor_ext` dialect for its cyclic `rotate` op. It is intended + to be run before lowering to a scheme dialect like `bgv`. + }]; + let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"]; +} + +// TODO(#512): Investigate replacing this pattern with a tensor_ext.combine op +def CollapseInsertionChains : Pass<"collapse-insertion-chains"> { + let summary = "Collapse chains of extract/insert ops into rotate ops when possible"; + let description = [{ + This pass is a cleanup pass for `insert-rotate`. 
That pass sometimes leaves + behind a chain of insertion operations like this: + + ```mlir + %extracted = tensor.extract %14[%c5] : tensor<16xi16> + %inserted = tensor.insert %extracted into %dest[%c0] : tensor<16xi16> + %extracted_0 = tensor.extract %14[%c6] : tensor<16xi16> + %inserted_1 = tensor.insert %extracted_0 into %inserted[%c1] : tensor<16xi16> + %extracted_2 = tensor.extract %14[%c7] : tensor<16xi16> + %inserted_3 = tensor.insert %extracted_2 into %inserted_1[%c2] : tensor<16xi16> + ... + %extracted_28 = tensor.extract %14[%c4] : tensor<16xi16> + %inserted_29 = tensor.insert %extracted_28 into %inserted_27[%c15] : tensor<16xi16> + yield %inserted_29 : tensor<16xi16> + ``` + + In many cases, this chain will insert into every index of the `dest` tensor, + and the extracted values all come from consistently aligned indices of the same + source tensor. In this case, the chain can be collapsed into a single `rotate`. + + Each index used for insertion or extraction must be constant; this may + require running `--canonicalize` or `--sccp` before this pass to apply + folding rules (use `--sccp` if you need to fold constants through control flow). 
+ }]; + let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"]; +} + +#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index e9793ea28..61c984f1b 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -6,13 +6,13 @@ #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project -// Required after PatternMatch.h -#include "include/Dialect/Polynomial/IR/PolynomialCanonicalize.cpp.inc" - namespace mlir { namespace heir { namespace polynomial { +// Required after PatternMatch.h +#include "include/Dialect/Polynomial/IR/PolynomialCanonicalize.cpp.inc" + void FromTensorOp::build(OpBuilder &builder, OperationState &result, Value input, RingAttr ring) { TensorType tensorType = dyn_cast(input.getType()); diff --git a/lib/Dialect/TensorExt/IR/BUILD b/lib/Dialect/TensorExt/IR/BUILD index 7681c51d2..8e649e4c8 100644 --- a/lib/Dialect/TensorExt/IR/BUILD +++ b/lib/Dialect/TensorExt/IR/BUILD @@ -20,6 +20,7 @@ cc_library( "@heir//include/Dialect/TensorExt/IR:ops_inc_gen", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:TensorDialect", ], ) @@ -33,9 +34,12 @@ cc_library( "@heir//include/Dialect/TensorExt/IR:TensorExtOps.h", ], deps = [ + "@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen", "@heir//include/Dialect/TensorExt/IR:dialect_inc_gen", "@heir//include/Dialect/TensorExt/IR:ops_inc_gen", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:TensorDialect", ], ) diff --git a/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp b/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp index 25f9e41c6..fd0dcfdf7 100644 --- a/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp +++ 
b/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp @@ -1,5 +1,6 @@ #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project // NOLINTNEXTLINE(misc-include-cleaner): Required to define TensorExtOps diff --git a/lib/Dialect/TensorExt/IR/TensorExtOps.cpp b/lib/Dialect/TensorExt/IR/TensorExtOps.cpp index 0d14a2928..394484c4c 100644 --- a/lib/Dialect/TensorExt/IR/TensorExtOps.cpp +++ b/lib/Dialect/TensorExt/IR/TensorExtOps.cpp @@ -1 +1,23 @@ #include "include/Dialect/TensorExt/IR/TensorExtOps.h" + +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace tensor_ext { + +// Kept inside a namespace because it generates a function called +// populateWithGenerated, which can conflict with other generated patterns. 
+#include "include/Dialect/TensorExt/IR/TensorExtCanonicalization.cpp.inc" + +void RotateOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + populateWithGenerated(results); +} + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD new file mode 100644 index 000000000..8853a87d5 --- /dev/null +++ b/lib/Dialect/TensorExt/Transforms/BUILD @@ -0,0 +1,53 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Transforms", + hdrs = [ + "@heir//include/Dialect/TensorExt/Transforms:Passes.h", + ], + deps = [ + ":CollapseInsertionChains", + ":InsertRotate", + "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "InsertRotate", + srcs = ["InsertRotate.cpp"], + hdrs = [ + "@heir//include/Dialect/TensorExt/Transforms:InsertRotate.h", + ], + deps = [ + "@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen", + "@heir//include/Dialect/TensorExt/Transforms:insert_rotate_inc_gen", + "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "CollapseInsertionChains", + srcs = ["CollapseInsertionChains.cpp"], + hdrs = [ + "@heir//include/Dialect/TensorExt/Transforms:CollapseInsertionChains.h", + ], + deps = [ + "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) diff --git 
a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp new file mode 100644 index 000000000..d54324642 --- /dev/null +++ b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp @@ -0,0 +1,175 @@ +#include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h" + +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +#define DEBUG_TYPE "collapse-insertion-chains" + +namespace mlir { +namespace heir { +namespace tensor_ext { + +#define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS +#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" + +template +FailureOr get1DExtractionIndex(Op op) { + auto insertIndices = op.getIndices(); + if (insertIndices.size() != 1) return failure(); + + // Each index must be constant; this may require running --canonicalize or + // -sccp before this pass to apply folding rules (use -sccp if you need to + // fold constants through control flow). + Value insertIndex = *insertIndices.begin(); + auto insertIndexConstOp = insertIndex.getDefiningOp(); + if (!insertIndexConstOp) return failure(); + + auto insertOffsetAttr = + llvm::dyn_cast(insertIndexConstOp.getValue()); + if (!insertOffsetAttr) return failure(); + + return insertOffsetAttr.getInt(); +} + +/// A pattern that searches for sequences of extract + insert, where the +/// indices extracted and inserted have the same offset, and replaces them with +/// a single rotate operation. 
+struct ConvertAlignedExtractInsertToRotate + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Given an insert and extract op, compute the shift between their two + // access indices. Only works for 1D tensors. + FailureOr calculateShift(tensor::InsertOp insertOp, + tensor::ExtractOp extractOp) const { + auto insertIndexRes = get1DExtractionIndex(insertOp); + auto extractionIndexRes = get1DExtractionIndex(extractOp); + if (failed(insertIndexRes) || failed(extractionIndexRes)) return failure(); + + auto insertTensorType = + insertOp.getDest().getType().cast(); + auto extractTensorType = + extractOp.getTensor().getType().cast(); + if (insertTensorType.getShape() != extractTensorType.getShape()) + return failure(); + + auto shift = (extractionIndexRes.value() - insertIndexRes.value()); + if (shift < 0) shift += insertTensorType.getShape()[0]; + + LLVM_DEBUG({ + llvm::dbgs() << "Insertion index: " << insertIndexRes.value() << "\n"; + llvm::dbgs() << "Extraction index: " << extractionIndexRes.value() + << "\n"; + llvm::dbgs() << "Shift: " << shift << "\n"; + }); + return shift; + } + + LogicalResult matchAndRewrite(tensor::InsertOp insertOp, + PatternRewriter &rewriter) const override { + auto extractOp = insertOp.getScalar().getDefiningOp(); + if (!extractOp) return failure(); + + auto shiftRes = calculateShift(insertOp, extractOp); + if (failed(shiftRes)) return failure(); + + int64_t shift = shiftRes.value(); + DenseSet accessedIndices; + accessedIndices.insert(get1DExtractionIndex(insertOp).value()); + + // Check if there are corresponding insertions into all other indices, + // which are extracted from the same source tensor with the same shift. 
+ // + // The problem is that because tensors have value semantics and not pointer + // semantics, we will see each new insertion use a different SSA value, like + // this: + // + // %extracted_1 = tensor.extract %original_source[%c1] : tensor<4096xi16> + // %inserted_1 = tensor.insert %extracted_1 into %original_dest[%c5] : + // tensor<4096xi16> %extracted_2 = tensor.extract %original_source[%c2] : + // tensor<4096xi16> %inserted_2 = tensor.insert %extracted_2 into + // %inserted_1[%c6] : tensor<4096xi16> + // + // Note how inserted_1 replaces original_dest for the subsequent insert, + // and inserted_2 will replace inserted_1 for the next one. + // + // So we need to traverse the insertions in order to follow the chain. Note + // that a more sophisticated pass might be able to support less + // well-structured DAGs of insertions and extractions, but we will improve + // this when that becomes necessary. + + // Also note that the greedy pattern rewriter will start from the first op + // in the block order, so we can assume that if the pattern matches, it + // matches on the first insert op encountered. + auto extractionSource = extractOp.getTensor(); + auto current = insertOp; + while (true) { + bool foundNext = false; + + for (auto *user : current.getResult().getUsers()) { + if (auto nextInsert = dyn_cast(user)) { + auto nextExtract = + nextInsert.getScalar().getDefiningOp(); + if (!nextExtract) continue; + + // We're inserting into this tensor from a different tensor than + // earlier insertions in the chain, so we can't continue. 
+ if (nextExtract.getTensor() != extractionSource) return failure(); + + auto nextShiftRes = calculateShift(nextInsert, nextExtract); + if (failed(nextShiftRes)) return failure(); + + if (nextShiftRes.value() != shift) return failure(); + + accessedIndices.insert(get1DExtractionIndex(nextInsert).value()); + current = nextInsert; + foundNext = true; + } + } + + // This can either be because we reached the end of the chain, or else + // because the chain is incomplete. + if (!foundNext) break; + } + + // We didn't cover the entire tensor, so the downstream user of this tensor + // may be depending on the original data in the untouched indices being + // intact. + if (accessedIndices.size() != + extractionSource.getType().cast().getShape()[0]) + return failure(); + + // The last insertion must be replaced because its user is the final end + // user + rewriter.replaceOpWithNewOp( + current, extractionSource, + rewriter.create(current.getLoc(), shift, + /*width=*/32)); + + // The rest of the chain of insertions and extractions itself will be + // DCE'd by canonicalization if possible. 
+ return success(); + } +}; + +struct CollapseInsertionChains + : impl::CollapseInsertionChainsBase { + using CollapseInsertionChainsBase::CollapseInsertionChainsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + patterns.add(context); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp new file mode 100644 index 000000000..bacd8407f --- /dev/null +++ b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp @@ -0,0 +1,41 @@ +#include "include/Dialect/TensorExt/Transforms/InsertRotate.h" + +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace tensor_ext { + +#define GEN_PASS_DEF_INSERTROTATE +#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" + +namespace alignment { +// In an inner namespace to avoid conflicts with canonicalization patterns +#include "include/Dialect/TensorExt/Transforms/InsertRotate.cpp.inc" +} // namespace alignment + +namespace canonicalization { +#include "include/Dialect/TensorExt/IR/TensorExtCanonicalization.cpp.inc" +} // namespace canonicalization + +struct InsertRotate : impl::InsertRotateBase { + using InsertRotateBase::InsertRotateBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + alignment::populateWithGenerated(patterns); + 
canonicalization::populateWithGenerated(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir diff --git a/tests/simd/BUILD b/tests/simd/BUILD new file mode 100644 index 000000000..c9cee5978 --- /dev/null +++ b/tests/simd/BUILD @@ -0,0 +1,16 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + size_override = { + "box_blur_64x64.mlir": "large", + }, + test_file_exts = ["mlir"], +) diff --git a/tests/simd/box_blur_4x4.mlir b/tests/simd/box_blur_4x4.mlir new file mode 100644 index 000000000..007ba1d35 --- /dev/null +++ b/tests/simd/box_blur_4x4.mlir @@ -0,0 +1,39 @@ +// RUN: heir-opt --secretize=entry-function=box_blur --wrap-generic --canonicalize --cse \ +// RUN: --full-loop-unroll \ +// RUN: --insert-rotate --cse --canonicalize --collapse-insertion-chains --canonicalize --cse \ +// RUN: %s | FileCheck %s + +module { + // CHECK-LABEL: @box_blur + // CHECK-NOT: tensor.extract + // CHECK-COUNT-7: tensor_ext.rotate + func.func @box_blur(%arg0: tensor<16xi16> {secret.secret}) -> tensor<16xi16> { + %c16 = arith.constant 16 : index + %c4 = arith.constant 4 : index + %0 = affine.for %x = 0 to 4 iter_args(%arg0_x = %arg0) -> (tensor<16xi16>) { + %1 = affine.for %y = 0 to 4 iter_args(%arg0_y = %arg0_x) -> (tensor<16xi16>) { + %c0_si16 = arith.constant 0 : i16 + %2 = affine.for %j = -1 to 2 iter_args(%value_j = %c0_si16) -> (i16) { + %6 = affine.for %i = -1 to 2 iter_args(%value_i = %value_j) -> (i16) { + %7 = arith.addi %x, %i : index + %8 = arith.muli %7, %c4 : index + %9 = arith.addi %y, %j : index + %10 = arith.addi %8, %9 : index + %11 = arith.remui %10, %c16 : index + %12 = tensor.extract %arg0[%11] : tensor<16xi16> + %13 = 
arith.addi %value_i, %12 : i16 + affine.yield %13 : i16 + } + affine.yield %6 : i16 + } + %3 = arith.muli %c4, %x : index + %4 = arith.addi %3, %y : index + %5 = arith.remui %4, %c16 : index + %6 = tensor.insert %2 into %arg0_y[%5] : tensor<16xi16> + affine.yield %6 : tensor<16xi16> + } + affine.yield %1 : tensor<16xi16> + } + return %0 : tensor<16xi16> + } +} diff --git a/tests/simd/box_blur_64x64.mlir b/tests/simd/box_blur_64x64.mlir new file mode 100644 index 000000000..21fb28d95 --- /dev/null +++ b/tests/simd/box_blur_64x64.mlir @@ -0,0 +1,62 @@ +// RUN: heir-opt --secretize=entry-function=box_blur --wrap-generic --canonicalize --cse --full-loop-unroll \ +// RUN: --insert-rotate --cse --canonicalize --collapse-insertion-chains --canonicalize --cse \ +// RUN: %s | FileCheck %s + +module { + // CHECK-LABEL: @box_blur + // CHECK-SAME: %[[arg0:.*]]: !secret.secret<tensor<4096xi16>>) -> !secret.secret<tensor<4096xi16>> { + // CHECK-NEXT: %[[c65:.*]] = arith.constant 65 : i32 + // CHECK-NEXT: %[[c127:.*]] = arith.constant 127 : index + // CHECK-NEXT: %[[c4032:.*]] = arith.constant 4032 : index + // CHECK-NEXT: %[[c3968:.*]] = arith.constant 3968 : index + // CHECK-NEXT: %[[c63:.*]] = arith.constant 63 : index + // CHECK-NEXT: %[[v0:.*]] = secret.generic ins(%[[arg0]] : !secret.secret<tensor<4096xi16>>) { + // CHECK-NEXT: ^bb0(%[[arg1:.*]]: tensor<4096xi16>): + // CHECK-NEXT: %[[v1:.*]] = tensor_ext.rotate %[[arg1]], %[[c3968]] + // CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[arg1]], %[[c4032]] + // CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v1]], %[[v2]] + // CHECK-NEXT: %[[v4:.*]] = arith.addi %[[v3]], %arg1 + // CHECK-NEXT: %[[v5:.*]] = tensor_ext.rotate %[[v4]], %[[c63]] + // CHECK-NEXT: %[[v6:.*]] = arith.addi %[[v5]], %[[v2]] + // CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v6]], %arg1 + // CHECK-NEXT: %[[v8:.*]] = tensor_ext.rotate %[[v7]], %[[c63]] + // CHECK-NEXT: %[[v9:.*]] = tensor_ext.rotate %arg1, %[[c127]] + // CHECK-NEXT: %[[v10:.*]] = arith.addi %[[v8]], %[[v9]] + // CHECK-NEXT: %[[v11:.*]] = arith.addi 
%[[v10]], %arg1 + // CHECK-NEXT: %[[v12:.*]] = tensor_ext.rotate %[[v11]], %[[c3968]] + // CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v12]], %[[v2]] + // CHECK-NEXT: %[[v14:.*]] = arith.addi %[[v13]], %arg1 + // CHECK-NEXT: %[[v15:.*]] = tensor_ext.rotate %[[v14]], %[[c65]] + // CHECK-NEXT: secret.yield %[[v15]] + // CHECK-NEXT: } -> !secret.secret<tensor<4096xi16>> + // CHECK-NEXT: return %[[v0]] + func.func @box_blur(%arg0: tensor<4096xi16> {secret.secret}) -> tensor<4096xi16> { + %c4096 = arith.constant 4096 : index + %c64 = arith.constant 64 : index + %0 = affine.for %x = 0 to 64 iter_args(%arg0_x = %arg0) -> (tensor<4096xi16>) { + %1 = affine.for %y = 0 to 64 iter_args(%arg0_y = %arg0_x) -> (tensor<4096xi16>) { + %c0_si16 = arith.constant 0 : i16 + %2 = affine.for %j = -1 to 2 iter_args(%value_j = %c0_si16) -> (i16) { + %6 = affine.for %i = -1 to 2 iter_args(%value_i = %value_j) -> (i16) { + %7 = arith.addi %x, %i : index + %8 = arith.muli %7, %c64 : index + %9 = arith.addi %y, %j : index + %10 = arith.addi %8, %9 : index + %11 = arith.remui %10, %c4096 : index + %12 = tensor.extract %arg0[%11] : tensor<4096xi16> + %13 = arith.addi %value_i, %12 : i16 + affine.yield %13 : i16 + } + affine.yield %6 : i16 + } + %3 = arith.muli %c64, %x : index + %4 = arith.addi %3, %y : index + %5 = arith.remui %4, %c4096 : index + %6 = tensor.insert %2 into %arg0_y[%5] : tensor<4096xi16> + affine.yield %6 : tensor<4096xi16> + } + affine.yield %1 : tensor<4096xi16> + } + return %0 : tensor<4096xi16> + } +} diff --git a/tests/simd/hamming_distance.mlir b/tests/simd/hamming_distance.mlir new file mode 100644 index 000000000..10a8dfb94 --- /dev/null +++ b/tests/simd/hamming_distance.mlir @@ -0,0 +1,32 @@ +// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic --canonicalize --cse \ +// RUN: --full-loop-unroll --insert-rotate --cse --canonicalize \ +// RUN: %s | FileCheck %s + +// CHECK-LABEL: @hamming +// CHECK: secret.generic +// CHECK: arith.subi +// CHECK-NEXT: arith.muli +// 
CHECK-NEXT: tensor_ext.rotate +// CHECK-NEXT: arith.addi +// CHECK-NEXT: tensor_ext.rotate +// CHECK-NEXT: arith.addi +// CHECK-NEXT: arith.addi +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: secret.yield + +func.func @hamming(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> i16 { + %c0 = arith.constant 0 : index + %c0_si16 = arith.constant 0 : i16 + %0 = affine.for %arg2 = 0 to 4 iter_args(%arg6 = %c0_si16) -> i16 { + %1 = tensor.extract %arg0[%arg2] : tensor<4xi16> + %2 = tensor.extract %arg1[%arg2] : tensor<4xi16> + %3 = arith.subi %1, %2 : i16 + %4 = tensor.extract %arg0[%arg2] : tensor<4xi16> + %5 = tensor.extract %arg1[%arg2] : tensor<4xi16> + %6 = arith.subi %4, %5 : i16 + %7 = arith.muli %3, %6 : i16 + %8 = arith.addi %arg6, %7 : i16 + affine.yield %8 : i16 + } + return %0 : i16 +} diff --git a/tests/tensor_ext/canonicalize.mlir b/tests/tensor_ext/canonicalize.mlir new file mode 100644 index 000000000..3d4bf1e58 --- /dev/null +++ b/tests/tensor_ext/canonicalize.mlir @@ -0,0 +1,23 @@ +// RUN: heir-opt --canonicalize %s | FileCheck %s + +// CHECK-LABEL: @test_sum_rotation_indices +// CHECK: %[[c3:.*]] = arith.constant 3 : i32 +// CHECK: tensor_ext.rotate +// CHECK-SAME: %[[c3]] +func.func @test_sum_rotation_indices(%0: tensor<16xi32>) -> tensor<16xi32> { + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32 + %2 = tensor_ext.rotate %1, %c2 : tensor<16xi32>, i32 + return %2 : tensor<16xi32> +} + +// CHECK-LABEL: @test_normalize_negative +// CHECK: %[[c3:.*]] = arith.constant 3 : i32 +// CHECK: tensor_ext.rotate +// CHECK-SAME: %[[c3]] +func.func @test_normalize_negative(%0: tensor<16xi32>) -> tensor<16xi32> { + %c1 = arith.constant -13 : i32 + %1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32 + return %1 : tensor<16xi32> +} diff --git a/tests/tensor_ext/collapse_insertion_chains.mlir b/tests/tensor_ext/collapse_insertion_chains.mlir new file mode 100644 index 
000000000..5504be60a --- /dev/null +++ b/tests/tensor_ext/collapse_insertion_chains.mlir @@ -0,0 +1,22 @@ +// RUN: heir-opt --collapse-insertion-chains --canonicalize %s | FileCheck %s + +// CHECK-LABEL: @test_collapse_insertion_chains +// CHECK-SAME: (%[[in:.*]]: tensor<4xi32>, %[[out:.*]]: tensor<4xi32>) -> tensor<4xi32> { +// CHECK: %[[c2:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = tensor_ext.rotate %[[in]], %[[c2]] : tensor<4xi32>, i32 +// CHECK: return %[[res]] : tensor<4xi32> +func.func @test_collapse_insertion_chains(%in: tensor<4xi32>, %out: tensor<4xi32>) -> tensor<4xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %ex0 = tensor.extract %in[%c3] : tensor<4xi32> + %in0 = tensor.insert %ex0 into %out[%c1] : tensor<4xi32> + %ex1 = tensor.extract %in[%c0] : tensor<4xi32> + %in1 = tensor.insert %ex1 into %in0[%c2] : tensor<4xi32> + %ex2 = tensor.extract %in[%c1] : tensor<4xi32> + %in2 = tensor.insert %ex2 into %in1[%c3] : tensor<4xi32> + %ex3 = tensor.extract %in[%c2] : tensor<4xi32> + %in3 = tensor.insert %ex3 into %in2[%c0] : tensor<4xi32> + return %in3 : tensor<4xi32> +} diff --git a/tests/tensor_ext/insert_rotations.mlir b/tests/tensor_ext/insert_rotations.mlir new file mode 100644 index 000000000..97d625cfe --- /dev/null +++ b/tests/tensor_ext/insert_rotations.mlir @@ -0,0 +1,29 @@ +// RUN: heir-opt --insert-rotate --canonicalize --cse %s | FileCheck %s + +func.func @test_insert_rotation_for_add(%arg1: tensor<16xi32>) -> i32 { + %c4 = arith.constant 4 : index + %c11 = arith.constant 11 : index + %c15 = arith.constant 15 : index + + // These two ops are rotated both to align with each other, and so that the + // result aligns with the %c4 rotation in extracted_1. 
+ %extracted = tensor.extract %arg1[%c11] : tensor<16xi32> + %extracted_0 = tensor.extract %arg1[%c15] : tensor<16xi32> + %1 = arith.addi %extracted, %extracted_0 : i32 + + %extracted_1 = tensor.extract %arg1[%c4] : tensor<16xi32> + %2 = arith.addi %1, %extracted_1 : i32 + return %2 : i32 +} + +// CHECK-LABEL: func @test_insert_rotation_for_add +// CHECK-SAME: (%[[arg0:.*]]: tensor<16xi32>) -> i32 { +// CHECK-NEXT: %[[c11:.*]] = arith.constant 11 +// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 +// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 +// CHECK-NEXT: %[[v0:.*]] = tensor_ext.rotate %[[arg0]], %[[c7]] : tensor<16xi32>, index +// CHECK-NEXT: %[[v1:.*]] = tensor_ext.rotate %[[arg0]], %[[c11]] : tensor<16xi32>, index +// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %[[v1]] : tensor<16xi32> +// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v2]], %[[arg0]] : tensor<16xi32> +// CHECK-NEXT: %[[extracted:.*]] = tensor.extract %[[v3]][%[[c4]]] : tensor<16xi32> +// CHECK-NEXT: return diff --git a/tools/BUILD b/tools/BUILD index ada72f155..a2446bd13 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -54,6 +54,7 @@ cc_binary( "@heir//lib/Dialect/Secret/Transforms", "@heir//lib/Dialect/Secret/Transforms:DistributeGeneric", "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Dialect/TensorExt/Transforms", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", "@heir//lib/Transforms/ForwardStoreToLoad", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 0810bca22..14dade8cd 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -21,6 +21,7 @@ #include "include/Dialect/Secret/Transforms/DistributeGeneric.h" #include "include/Dialect/Secret/Transforms/Passes.h" #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "include/Dialect/TensorExt/Transforms/Passes.h" #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" #include 
"include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" @@ -292,6 +293,7 @@ int main(int argc, char **argv) { cggi::registerCGGIPasses(); lwe::registerLWEPasses(); secret::registerSecretPasses(); + tensor_ext::registerTensorExtPasses(); registerSecretizePasses(); registerFullLoopUnrollPasses(); registerForwardStoreToLoadPasses();