Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor DRR helpers into a Utils.td file #518

Merged
merged 1 commit into from Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/DRR/BUILD
@@ -0,0 +1,16 @@
# Tablegen helpers for declarative rewrite (DRR) patterns
load("@llvm-project//mlir:tblgen.bzl", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

td_library(
name = "DRR",
srcs = [
"Utils.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../.."],
)
27 changes: 27 additions & 0 deletions include/DRR/Utils.td
@@ -0,0 +1,27 @@

include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/PatternBase.td"

// When using DRR to define new artih ops, one must include an attribute
// describing the overflow semantics. When the new arith op is built from an
// existing op, one should match the overflow attribute and propagate it, but
// when the op is built from scratch, one needs to pass an overflow choice, and
// if there is no reason to pick a particular overflow type, use this helper as
// a default.
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

// When constructing a new variadic op, and you want to pass it a single argument
// as its variadic input, use this helper. To understand why this is needed, see
// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

// Determine if an input integer attr is zero
def IsZeroIntAttr :
Constraint<
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>;

// Construct an index.cast op, which is necessary here because the default
// builder is not compatible with DRR as of 2024-03-11.
def CreateIndexCastOp : NativeCodeCall<
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">;
1 change: 1 addition & 0 deletions include/Dialect/TensorExt/IR/BUILD
Expand Up @@ -25,6 +25,7 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//include/DRR",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
15 changes: 2 additions & 13 deletions include/Dialect/TensorExt/IR/TensorExtCanonicalization.td
Expand Up @@ -2,22 +2,11 @@
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_

include "TensorExtOps.td"
include "include/DRR/Utils.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#515): refactor these helpers to a common file with InsertRotate.td
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

def CreateIndexCastOp : NativeCodeCall<
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">;

def IsZero :
Constraint<
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>;

def OutOfBoundsOfTensorDim :
Constraint<
CPred<
Expand All @@ -31,7 +20,7 @@ def OutOfBoundsOfTensorDim :
def DropZeroRotation : Pat<
(TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)),
(replaceWithValue $tensor),
[(IsZero $c0)]
[(IsZeroIntAttr $c0)]
>;

// rotate %t, x -> rotate %t, x mod size
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/TensorExt/Transforms/BUILD
Expand Up @@ -41,6 +41,7 @@ gentbl_cc_library(
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "InsertRotate.td",
deps = [
"@heir//include/DRR",
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@heir//include/Dialect/TensorExt/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
Expand Down
7 changes: 1 addition & 6 deletions include/Dialect/TensorExt/Transforms/InsertRotate.td
@@ -1,6 +1,7 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_

include "include/DRR/Utils.td"
include "include/Dialect/TensorExt/IR/TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
Expand All @@ -13,12 +14,6 @@ include "mlir/IR/PatternBase.td"
// https://arxiv.org/abs/2202.01649 and the hir2hir passes in
// https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/

defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

// To understand why this is needed, see
// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

// Match an arith op that extracts scalar values from two tensors, and replace
// it with rotations to align slots and apply the same op in SIMD. Other
// patterns in this file will find better alignment of adjacent rotations, and
Expand Down