Skip to content

Commit

Permalink
Merge pull request #471 from j2kun:simd-packing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 614777518
  • Loading branch information
Copybara-Service committed Mar 11, 2024
2 parents a00de65 + 2689afb commit 3ffb592
Show file tree
Hide file tree
Showing 27 changed files with 915 additions and 9 deletions.
20 changes: 20 additions & 0 deletions include/Dialect/TensorExt/IR/BUILD
Expand Up @@ -11,12 +11,14 @@ exports_files(
[
"TensorExtDialect.h",
"TensorExtOps.h",
"TensorExtPatterns.h",
],
)

td_library(
name = "td_files",
srcs = [
"TensorExtCanonicalization.td",
"TensorExtDialect.td",
"TensorExtOps.td",
],
Expand Down Expand Up @@ -77,3 +79,21 @@ gentbl_cc_library(
"@heir//include/Dialect/Polynomial/IR:td_files",
],
)

gentbl_cc_library(
name = "canonicalize_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"TensorExtCanonicalization.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorExtCanonicalization.td",
deps = [
":ops_inc_gen",
":td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:TensorOpsTdFiles",
],
)
88 changes: 88 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtCanonicalization.td
@@ -0,0 +1,88 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_

include "TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#515): refactor these helpers to a common file with InsertRotate.td
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

def CreateIndexCastOp : NativeCodeCall<
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">;

def IsZero :
Constraint<
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>;

def OutOfBoundsOfTensorDim :
Constraint<
CPred<
"llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() < 0 "
"|| llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() > "
"llvm::cast<mlir::ShapedType>($1.getType()).getShape()[0]"
>
>;

// rotate %t, 0 -> %t
def DropZeroRotation : Pat<
(TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)),
(replaceWithValue $tensor),
[(IsZero $c0)]
>;

// rotate %t, x -> rotate %t, x mod size
def NormalizeRotationIndex : Pat<
(TensorExt_RotateOp $tensor, (Arith_ConstantOp:$shiftOp APIntAttr:$shiftAmount)),
(TensorExt_RotateOp $tensor,
(Arith_RemUIOp
$shiftOp,
// Only works for 1D tensors: index is taken modulo the tensor length,
// i.e., dim 0
(CreateIndexCastOp
(Tensor_DimOp $tensor, (Arith_ConstantOp ConstantAttr<IndexAttr, "0">)),
$shiftOp))
),
[(OutOfBoundsOfTensorDim $shiftAmount, $tensor)]
>;

// %0 = rotate %t, x
// %1 = rotate %0, y
// ---> rotate %t (x+y)
def CombineSequentialRotates : Pat<
(TensorExt_RotateOp
(TensorExt_RotateOp $tensor, (Arith_ConstantOp:$xOp APIntAttr:$x)),
(Arith_ConstantOp:$yOp APIntAttr:$y)),
(TensorExt_RotateOp $tensor, (Arith_AddIOp $xOp, $yOp, DefOverflow)),
[]
>;

// A rotation followed by extraction can be extracted directly from the
// original tensor.
def RotatePlusExtractToIndexedExtract : Pat<
(Tensor_ExtractOp
(TensorExt_RotateOp $tensor, $shift),
(variadic $index)),
(Tensor_ExtractOp
$tensor,
(MakeSingleResultVariadic (Arith_AddIOp $shift, $index, DefOverflow)))
>;

// Rotating two tensors by the same amount can be converted to a single
// post-rotation. This can result in eliminating either the rotation (because
// it can be combined with a later rotation) or the arith op itself, if it is
// is identical to an existing arith op applied before the rotation.
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
def FactorParallelRotationsThroughOp_#ArithOp : Pat<
(ArithOp
(TensorExt_RotateOp $t1, $i),
(TensorExt_RotateOp $t2, $i),
$ovf),
(TensorExt_RotateOp (ArithOp $t1, $t2, $ovf), $i)
>;
}

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_
9 changes: 6 additions & 3 deletions include/Dialect/TensorExt/IR/TensorExtDialect.td
@@ -1,5 +1,5 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_

include "mlir/IR/DialectBase.td"

Expand All @@ -12,6 +12,9 @@ def TensorExt_Dialect : Dialect {
}];

let cppNamespace = "::mlir::heir::tensor_ext";
let dependentDialects = [
"tensor::TensorDialect",
];
}

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_
6 changes: 3 additions & 3 deletions include/Dialect/TensorExt/IR/TensorExtOps.h
@@ -1,5 +1,5 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_

#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
Expand All @@ -8,4 +8,4 @@
#define GET_OP_CLASSES
#include "include/Dialect/TensorExt/IR/TensorExtOps.h.inc"

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_
1 change: 1 addition & 0 deletions include/Dialect/TensorExt/IR/TensorExtOps.td
Expand Up @@ -35,6 +35,7 @@ def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor",
let arguments = (ins AnyTensor:$tensor, SignlessIntegerLike:$shift);
let results = (outs AnyTensor:$output);
let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)";
let hasCanonicalizer = 1;
}

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
55 changes: 55 additions & 0 deletions include/Dialect/TensorExt/Transforms/BUILD
@@ -0,0 +1,55 @@
# InsertRotate tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=TensorExt",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"TensorExtPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

gentbl_cc_library(
name = "insert_rotate_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"InsertRotate.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "InsertRotate.td",
deps = [
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@heir//include/Dialect/TensorExt/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:TensorOpsTdFiles",
],
)

exports_files([
"Passes.h",
"CollapseInsertionChains.h",
"InsertRotate.h",
])
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h
@@ -0,0 +1,17 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_DECL_COLLAPSEINSERTIONCHAINS
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/Transforms/InsertRotate.h
@@ -0,0 +1,17 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_DECL_INSERTROTATE
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_
84 changes: 84 additions & 0 deletions include/Dialect/TensorExt/Transforms/InsertRotate.td
@@ -0,0 +1,84 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_

include "include/Dialect/TensorExt/IR/TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#512): Support target slot selection when the downstream op is an insert.

// The patterns in this file are intended to align with the automatic-SIMD
// batching heuristics from the HECO project. See section 4.4 of
// https://arxiv.org/abs/2202.01649 and the hir2hir passes in
// https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/

defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

// To understand why this is needed, see
// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

// Match an arith op that extracts scalar values from two tensors, and replace
// it with rotations to align slots and apply the same op in SIMD. Other
// patterns in this file will find better alignment of adjacent rotations, and
// canonicalization patterns will remove duplicated rotations.
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
def InsertRotations_#ArithOp : Pattern<
(ArithOp
(Tensor_ExtractOp $t1, (variadic $i1)),
(Tensor_ExtractOp $t2, (variadic $i2)),
$overflow),
[
(TensorExt_RotateOp:$r1 $t1, $i1),
(TensorExt_RotateOp:$r2 $t2, $i2),
(ArithOp:$opResult $r1, $r2, $overflow),
(Tensor_ExtractOp
$opResult,
(MakeSingleResultVariadic (Arith_ConstantOp ConstantAttr<IndexAttr, "0">))),
]
>;
}

// Pre-align the first op's operands to the index that the result is
// used for in a subsequent op.
// TODO(#514): handle OuterOp with two different InnerOps on the LHS and RHS
foreach InnerOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
foreach OuterOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
// Left associated grouping handles (add (add (rotate t1 i1) (rotate t2 i2)) (rotate t3 i3))
def AlignRotations_LeftAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern<
(OuterOp
(InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1),
(TensorExt_RotateOp $t3, $i3),
$ovf2),
[
(TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)),
(TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)),
(InnerOp:$addResult $r1, $r2, $ovf1),
(OuterOp:$output $addResult, $t3, $ovf2),
// Downstream ops are not updated by this pass, so we need to preserve the original
// rotation and then clean it up in a separate canonicalization pattern.
(TensorExt_RotateOp $output, $i3),
]
>;

// Right associated grouping handles (add (rotate t1 i1) (add (rotate t2 i2) (rotate t3 i3)))
def AlignRotations_RightAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern<
(OuterOp
(TensorExt_RotateOp $t3, $i3),
(InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1),
$ovf2),
[
(TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)),
(TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)),
(InnerOp:$addResult $r1, $r2, $ovf1),
(OuterOp:$output $addResult, $t3, $ovf2),
// Downstream ops are not updated by this pass, so we need to preserve the original
// rotation and then clean it up in a separate canonicalization pattern.
(TensorExt_RotateOp $output, $i3),
]
>;
}
}

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_
19 changes: 19 additions & 0 deletions include/Dialect/TensorExt/Transforms/Passes.h
@@ -0,0 +1,19 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_

#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h"
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h"

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_REGISTRATION
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_

0 comments on commit 3ffb592

Please sign in to comment.