Initial port of HECO auto-SIMD passes for BGV #471

Merged: 1 commit, Mar 11, 2024
20 changes: 20 additions & 0 deletions include/Dialect/TensorExt/IR/BUILD
@@ -11,12 +11,14 @@ exports_files(
[
"TensorExtDialect.h",
"TensorExtOps.h",
"TensorExtPatterns.h",
],
)

td_library(
name = "td_files",
srcs = [
"TensorExtCanonicalization.td",
"TensorExtDialect.td",
"TensorExtOps.td",
],
@@ -77,3 +79,21 @@ gentbl_cc_library(
"@heir//include/Dialect/Polynomial/IR:td_files",
],
)

gentbl_cc_library(
name = "canonicalize_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"TensorExtCanonicalization.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorExtCanonicalization.td",
deps = [
":ops_inc_gen",
":td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:TensorOpsTdFiles",
],
)
88 changes: 88 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtCanonicalization.td
@@ -0,0 +1,88 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_

include "TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#515): refactor these helpers to a common file with InsertRotate.td
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

def CreateIndexCastOp : NativeCodeCall<
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">;

def IsZero :
Constraint<
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>;

def OutOfBoundsOfTensorDim :
Constraint<
CPred<
"llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() < 0 "
"|| llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() > "
"llvm::cast<mlir::ShapedType>($1.getType()).getShape()[0]"
>
>;

// rotate %t, 0 -> %t
def DropZeroRotation : Pat<
(TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)),
(replaceWithValue $tensor),
[(IsZero $c0)]
>;
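// Illustrative example (not from the original file; types chosen arbitrarily):
//   %c0 = arith.constant 0 : i32
//   %0 = tensor_ext.rotate %t, %c0 : tensor<16xi32>, i32
// is replaced by uses of %t directly, since a rotation by zero is a no-op.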

// rotate %t, x -> rotate %t, x mod size
def NormalizeRotationIndex : Pat<
(TensorExt_RotateOp $tensor, (Arith_ConstantOp:$shiftOp APIntAttr:$shiftAmount)),
(TensorExt_RotateOp $tensor,
(Arith_RemUIOp
$shiftOp,
// Only works for 1D tensors: index is taken modulo the tensor length,
// i.e., dim 0
(CreateIndexCastOp
(Tensor_DimOp $tensor, (Arith_ConstantOp ConstantAttr<IndexAttr, "0">)),
$shiftOp))
),
[(OutOfBoundsOfTensorDim $shiftAmount, $tensor)]
>;
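// Illustrative example (types chosen arbitrarily): for %t : tensor<16xi32> and
//   %c18 = arith.constant 18 : i32
//   %0 = tensor_ext.rotate %t, %c18 : tensor<16xi32>, i32
// the shift becomes `arith.remui %c18, <index_cast of tensor.dim %t, 0>`, which
// constant-folds to 2, i.e. the rotation amount is reduced to 18 mod 16.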

// %0 = rotate %t, x
// %1 = rotate %0, y
// ---> rotate %t (x+y)
def CombineSequentialRotates : Pat<
(TensorExt_RotateOp
(TensorExt_RotateOp $tensor, (Arith_ConstantOp:$xOp APIntAttr:$x)),
(Arith_ConstantOp:$yOp APIntAttr:$y)),
(TensorExt_RotateOp $tensor, (Arith_AddIOp $xOp, $yOp, DefOverflow)),
[]
>;
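// Illustrative example (types chosen arbitrarily):
//   %0 = tensor_ext.rotate %t, %c1 : tensor<16xi32>, i32
//   %1 = tensor_ext.rotate %0, %c2 : tensor<16xi32>, i32
// becomes a single rotation of %t by %c1 + %c2, which constant-folds to 3.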

// A rotation followed by extraction can be extracted directly from the
// original tensor.
def RotatePlusExtractToIndexedExtract : Pat<
(Tensor_ExtractOp
(TensorExt_RotateOp $tensor, $shift),
(variadic $index)),
(Tensor_ExtractOp
$tensor,
(MakeSingleResultVariadic (Arith_AddIOp $shift, $index, DefOverflow)))
>;
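// Illustrative example (types chosen arbitrarily), with %c3 and %c5 of index type:
//   %0 = tensor_ext.rotate %t, %c5 : tensor<16xi32>, index
//   %1 = tensor.extract %0[%c3] : tensor<16xi32>
// becomes an extraction from the original %t at index %c5 + %c3.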

// Rotating two tensors by the same amount can be converted to a single
// post-rotation. This can result in eliminating either the rotation (because
// it can be combined with a later rotation) or the arith op itself, if it
// is identical to an existing arith op applied before the rotation.
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
def FactorParallelRotationsThroughOp_#ArithOp : Pat<
(ArithOp
(TensorExt_RotateOp $t1, $i),
(TensorExt_RotateOp $t2, $i),
$ovf),
(TensorExt_RotateOp (ArithOp $t1, $t2, $ovf), $i)
>;
}
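// Illustrative example (types chosen arbitrarily), for a shared shift %c2:
//   %0 = tensor_ext.rotate %t1, %c2 : tensor<16xi32>, i32
//   %1 = tensor_ext.rotate %t2, %c2 : tensor<16xi32>, i32
//   %2 = arith.addi %0, %1 : tensor<16xi32>
// becomes an elementwise addi of the unrotated %t1 and %t2 followed by a single
// rotation of the sum by %c2.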

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_
9 changes: 6 additions & 3 deletions include/Dialect/TensorExt/IR/TensorExtDialect.td
@@ -1,5 +1,5 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_

include "mlir/IR/DialectBase.td"

@@ -12,6 +12,9 @@ def TensorExt_Dialect : Dialect {
}];

let cppNamespace = "::mlir::heir::tensor_ext";
let dependentDialects = [
"tensor::TensorDialect",
];
}

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_
6 changes: 3 additions & 3 deletions include/Dialect/TensorExt/IR/TensorExtOps.h
@@ -1,5 +1,5 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_

#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
@@ -8,4 +8,4 @@
#define GET_OP_CLASSES
#include "include/Dialect/TensorExt/IR/TensorExtOps.h.inc"

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtOPS_H_
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_H_
1 change: 1 addition & 0 deletions include/Dialect/TensorExt/IR/TensorExtOps.td
@@ -35,6 +35,7 @@ def TensorExt_RotateOp : TensorExt_Op<"rotate", [Pure, AllTypesMatch<["tensor",
let arguments = (ins AnyTensor:$tensor, SignlessIntegerLike:$shift);
let results = (outs AnyTensor:$output);
let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)";
let hasCanonicalizer = 1;
}

#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
55 changes: 55 additions & 0 deletions include/Dialect/TensorExt/Transforms/BUILD
@@ -0,0 +1,55 @@
# TensorExt transform passes: tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=TensorExt",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"TensorExtPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

gentbl_cc_library(
name = "insert_rotate_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"InsertRotate.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "InsertRotate.td",
deps = [
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@heir//include/Dialect/TensorExt/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:TensorOpsTdFiles",
],
)

exports_files([
"Passes.h",
"CollapseInsertionChains.h",
"InsertRotate.h",
])
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h
@@ -0,0 +1,17 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_DECL_COLLAPSEINSERTIONCHAINS
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_
17 changes: 17 additions & 0 deletions include/Dialect/TensorExt/Transforms/InsertRotate.h
@@ -0,0 +1,17 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_DECL_INSERTROTATE
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_
84 changes: 84 additions & 0 deletions include/Dialect/TensorExt/Transforms/InsertRotate.td
@@ -0,0 +1,84 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_TD_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_TD_

include "include/Dialect/TensorExt/IR/TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#512): Support target slot selection when the downstream op is an insert.

// The patterns in this file are intended to align with the automatic-SIMD
// batching heuristics from the HECO project. See section 4.4 of
// https://arxiv.org/abs/2202.01649 and the hir2hir passes in
// https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/

defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

// To understand why this is needed, see
// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

// Match an arith op that extracts scalar values from two tensors, and replace
// it with rotations to align slots and apply the same op in SIMD. Other
// patterns in this file will find better alignment of adjacent rotations, and
// canonicalization patterns will remove duplicated rotations.
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
def InsertRotations_#ArithOp : Pattern<
(ArithOp
(Tensor_ExtractOp $t1, (variadic $i1)),
(Tensor_ExtractOp $t2, (variadic $i2)),
$overflow),
[
(TensorExt_RotateOp:$r1 $t1, $i1),
(TensorExt_RotateOp:$r2 $t2, $i2),
(ArithOp:$opResult $r1, $r2, $overflow),
(Tensor_ExtractOp
$opResult,
(MakeSingleResultVariadic (Arith_ConstantOp ConstantAttr<IndexAttr, "0">))),
]
>;
}
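// Illustrative example (not from the original file; names and types are
// placeholders), with %c4 and %c7 of index type:
//   %a = tensor.extract %t1[%c4] : tensor<16xi32>
//   %b = tensor.extract %t2[%c7] : tensor<16xi32>
//   %sum = arith.addi %a, %b : i32
// is rewritten to rotate %t1 by %c4 and %t2 by %c7 so the two operands share a
// slot, apply arith.addi to the whole tensors, and extract slot 0 of the result.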

// Pre-align the first op's operands to the index that the result is
// used for in a subsequent op.
// TODO(#514): handle OuterOp with two different InnerOps on the LHS and RHS
foreach InnerOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
foreach OuterOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
// Left associated grouping handles (add (add (rotate t1 i1) (rotate t2 i2)) (rotate t3 i3))
def AlignRotations_LeftAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern<
(OuterOp
(InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1),
(TensorExt_RotateOp $t3, $i3),
$ovf2),
[
(TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)),
(TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)),
(InnerOp:$addResult $r1, $r2, $ovf1),
(OuterOp:$output $addResult, $t3, $ovf2),
// Downstream ops are not updated by this pass, so we need to preserve the original
// rotation and then clean it up in a separate canonicalization pattern.
(TensorExt_RotateOp $output, $i3),
]
>;

// Right associated grouping handles (add (rotate t1 i1) (add (rotate t2 i2) (rotate t3 i3)))
def AlignRotations_RightAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern<
(OuterOp
(TensorExt_RotateOp $t3, $i3),
(InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1),
$ovf2),
[
(TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)),
(TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)),
(InnerOp:$addResult $r1, $r2, $ovf1),
(OuterOp:$output $addResult, $t3, $ovf2),
// Downstream ops are not updated by this pass, so we need to preserve the original
// rotation and then clean it up in a separate canonicalization pattern.
(TensorExt_RotateOp $output, $i3),
]
>;
}
}
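// Illustrative example (names and types are placeholders), for the
// left-associated case with addi as both the inner and outer op:
//   %r1 = tensor_ext.rotate %t1, %i1 : tensor<16xi32>, index
//   %r2 = tensor_ext.rotate %t2, %i2 : tensor<16xi32>, index
//   %r3 = tensor_ext.rotate %t3, %i3 : tensor<16xi32>, index
//   %s0 = arith.addi %r1, %r2 : tensor<16xi32>
//   %s1 = arith.addi %s0, %r3 : tensor<16xi32>
// is rewritten to rotate %t1 by %i1 - %i3 and %t2 by %i2 - %i3, add them, add the
// unrotated %t3, and finally rotate the sum by %i3 so downstream users still see
// the original alignment; canonicalization can later remove redundant rotations.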

#endif  // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_TD_
19 changes: 19 additions & 0 deletions include/Dialect/TensorExt/Transforms/Passes.h
@@ -0,0 +1,19 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_

#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h"
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h"

namespace mlir {
namespace heir {
namespace tensor_ext {

#define GEN_PASS_REGISTRATION
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

} // namespace tensor_ext
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_