Skip to content

Commit

Permalink
Merge pull request #518 from j2kun:drr-refactor
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615040440
  • Loading branch information
Copybara-Service committed Mar 12, 2024
2 parents bd49e73 + 34bd918 commit a5fcef5
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 19 deletions.
16 changes: 16 additions & 0 deletions include/DRR/BUILD
@@ -0,0 +1,16 @@
# Tablegen helpers for declarative rewrite (DRR) patterns
load("@llvm-project//mlir:tblgen.bzl", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

td_library(
name = "DRR",
srcs = [
"Utils.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../.."],
)
27 changes: 27 additions & 0 deletions include/DRR/Utils.td
@@ -0,0 +1,27 @@

include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/PatternBase.td"

// When using DRR to define new artih ops, one must include an attribute
// describing the overflow semantics. When the new arith op is built from an
// existing op, one should match the overflow attribute and propagate it, but
// when the op is built from scratch, one needs to pass an overflow choice, and
// if there is no reason to pick a particular overflow type, use this helper as
// a default.
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

// When constructing a new variadic op, and you want to pass it a single argument
// as its variadic input, use this helper. To understand why this is needed, see
// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

// Determine if an input integer attr is zero
def IsZeroIntAttr :
Constraint<
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>;

// Construct an index.cast op, which is necessary here because the default
// builder is not compatible with DRR as of 2024-03-11.
def CreateIndexCastOp : NativeCodeCall<
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">;
1 change: 1 addition & 0 deletions include/Dialect/TensorExt/IR/BUILD
Expand Up @@ -25,6 +25,7 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//include/DRR",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
15 changes: 2 additions & 13 deletions include/Dialect/TensorExt/IR/TensorExtCanonicalization.td
Expand Up @@ -2,22 +2,11 @@
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_

include "TensorExtOps.td"
include "include/DRR/Utils.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#515): refactor these helpers to a common file with InsertRotate.td
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

def CreateIndexCastOp : NativeCodeCall<
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">;

def IsZero :
Constraint<
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>;

def OutOfBoundsOfTensorDim :
Constraint<
CPred<
Expand All @@ -31,7 +20,7 @@ def OutOfBoundsOfTensorDim :
def DropZeroRotation : Pat<
(TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)),
(replaceWithValue $tensor),
[(IsZero $c0)]
[(IsZeroIntAttr $c0)]
>;

// rotate %t, x -> rotate %t, x mod size
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/TensorExt/Transforms/BUILD
Expand Up @@ -41,6 +41,7 @@ gentbl_cc_library(
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "InsertRotate.td",
deps = [
"@heir//include/DRR",
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@heir//include/Dialect/TensorExt/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
Expand Down
7 changes: 1 addition & 6 deletions include/Dialect/TensorExt/Transforms/InsertRotate.td
@@ -1,6 +1,7 @@
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_

include "include/DRR/Utils.td"
include "include/Dialect/TensorExt/IR/TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
Expand All @@ -13,12 +14,6 @@ include "mlir/IR/PatternBase.td"
// https://arxiv.org/abs/2202.01649 and the hir2hir passes in
// https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/

defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

// To understand why this is needed, see
// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

// Match an arith op that extracts scalar values from two tensors, and replace
// it with rotations to align slots and apply the same op in SIMD. Other
// patterns in this file will find better alignment of adjacent rotations, and
Expand Down

0 comments on commit a5fcef5

Please sign in to comment.