From 34bd9180be82f7ae254f0673e07022b32be14adb Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 11 Mar 2024 12:21:16 -0700 Subject: [PATCH] Centralize DRR helpers --- include/DRR/BUILD | 16 +++++++++++ include/DRR/Utils.td | 27 +++++++++++++++++++ include/Dialect/TensorExt/IR/BUILD | 1 + .../TensorExt/IR/TensorExtCanonicalization.td | 15 ++--------- include/Dialect/TensorExt/Transforms/BUILD | 1 + .../TensorExt/Transforms/InsertRotate.td | 7 +---- 6 files changed, 48 insertions(+), 19 deletions(-) create mode 100644 include/DRR/BUILD create mode 100644 include/DRR/Utils.td diff --git a/include/DRR/BUILD b/include/DRR/BUILD new file mode 100644 index 000000000..42dc67ee6 --- /dev/null +++ b/include/DRR/BUILD @@ -0,0 +1,16 @@ +# Tablegen helpers for declarative rewrite (DRR) patterns +load("@llvm-project//mlir:tblgen.bzl", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +td_library( + name = "DRR", + srcs = [ + "Utils.td", + ], + # include from the heir-root to enable fully-qualified include-paths + includes = ["../../.."], +) diff --git a/include/DRR/Utils.td b/include/DRR/Utils.td new file mode 100644 index 000000000..8d7ba558e --- /dev/null +++ b/include/DRR/Utils.td @@ -0,0 +1,27 @@ + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/PatternBase.td" + +// When using DRR to define new artih ops, one must include an attribute +// describing the overflow semantics. When the new arith op is built from an +// existing op, one should match the overflow attribute and propagate it, but +// when the op is built from scratch, one needs to pass an overflow choice, and +// if there is no reason to pick a particular overflow type, use this helper as +// a default. +defvar DefOverflow = ConstantEnumCase; + +// When constructing a new variadic op, and you want to pass it a single argument +// as its variadic input, use this helper. To understand why this is needed, see +// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385 +def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; + +// Determine if an input integer attr is zero +def IsZeroIntAttr : + Constraint< + CPred<"llvm::cast($0).getValue().isZero()">>; + +// Construct an index.cast op, which is necessary here because the default +// builder is not compatible with DRR as of 2024-03-11. +def CreateIndexCastOp : NativeCodeCall< + "$_builder.create($0.getLoc(), $1.getType(), $0)">; diff --git a/include/Dialect/TensorExt/IR/BUILD b/include/Dialect/TensorExt/IR/BUILD index f2ac37599..9b617eea2 100644 --- a/include/Dialect/TensorExt/IR/BUILD +++ b/include/Dialect/TensorExt/IR/BUILD @@ -25,6 +25,7 @@ td_library( # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], deps = [ + "@heir//include/DRR", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", diff --git a/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td b/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td index 5b308d13e..5948931c9 100644 --- a/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td +++ b/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td @@ -2,22 +2,11 @@ #define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_ include "TensorExtOps.td" +include "include/DRR/Utils.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" include "mlir/IR/PatternBase.td" -// TODO(#515): refactor these helpers to a common file with InsertRotate.td -defvar DefOverflow = ConstantEnumCase; - -def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; - -def CreateIndexCastOp : NativeCodeCall< - "$_builder.create($0.getLoc(), $1.getType(), $0)">; - -def IsZero : - Constraint< - CPred<"llvm::cast($0).getValue().isZero()">>; - def OutOfBoundsOfTensorDim : Constraint< CPred< @@ -31,7 +20,7 @@ def OutOfBoundsOfTensorDim : def DropZeroRotation : Pat< (TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)), (replaceWithValue $tensor), - [(IsZero $c0)] + [(IsZeroIntAttr $c0)] >; // rotate %t, x -> rotate %t, x mod size diff --git a/include/Dialect/TensorExt/Transforms/BUILD b/include/Dialect/TensorExt/Transforms/BUILD index 574ad102f..814553ce4 100644 --- a/include/Dialect/TensorExt/Transforms/BUILD +++ b/include/Dialect/TensorExt/Transforms/BUILD @@ -41,6 +41,7 @@ gentbl_cc_library( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "InsertRotate.td", deps = [ + "@heir//include/DRR", "@heir//include/Dialect/TensorExt/IR:ops_inc_gen", "@heir//include/Dialect/TensorExt/IR:td_files", "@llvm-project//mlir:ArithOpsTdFiles", diff --git a/include/Dialect/TensorExt/Transforms/InsertRotate.td b/include/Dialect/TensorExt/Transforms/InsertRotate.td index 6ffc2a539..7e977f1fb 100644 --- a/include/Dialect/TensorExt/Transforms/InsertRotate.td +++ b/include/Dialect/TensorExt/Transforms/InsertRotate.td @@ -1,6 +1,7 @@ #ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ #define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ +include "include/DRR/Utils.td" include "include/Dialect/TensorExt/IR/TensorExtOps.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" @@ -13,12 +14,6 @@ include "mlir/IR/PatternBase.td" // https://arxiv.org/abs/2202.01649 and the hir2hir passes in // https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/ -defvar DefOverflow = ConstantEnumCase; - -// To understand why this is needed, see -// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385 -def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; - // Match an arith op that extracts scalar values from two tensors, and replace // it with rotations to align slots and apply the same op in SIMD. Other // patterns in this file will find better alignment of adjacent rotations, and