diff --git a/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td b/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td index 20563d1be..5226b9c9c 100644 --- a/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td +++ b/include/Dialect/TensorExt/IR/TensorExtCanonicalization.td @@ -11,6 +11,16 @@ defvar DefOverflow = ConstantEnumCase; def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; +def IndexType : NativeCodeCall<"$_builder.getIndexType()">; + +def MakeMatchingTypedZero + : NativeCodeCall< + "$_builder.getIntegerAttr(" + "$_builder.getIntegerType(" + "$0.getType().getIntOrFloatBitWidth()" + "), 0)">; + + def IsZero : Constraint< CPred<"llvm::cast($0).getValue().isZero()">>; @@ -20,6 +30,15 @@ def AreOpposites : CPred<"llvm::cast($0).getValue() " "== llvm::cast($1).getValue()">>; +def OutOfBoundsOfTensorDim : + Constraint< + CPred< + "llvm::cast($0).getValue().getSExtValue() < 0 " + "|| llvm::cast($0).getValue().getSExtValue() > " + "llvm::cast($1.getType()).getShape()[0]" + > + >; + // rotate %t, 0 -> %t def DropZeroRotation : Pat< (TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)), @@ -27,6 +46,20 @@ def DropZeroRotation : Pat< [(IsZero $c0)] >; +// rotate %t, x -> rotate %t, x mod size +def NormalizeRotationIndex : Pat< + (TensorExt_RotateOp $tensor, (Arith_ConstantOp:$shiftOp APIntAttr:$shiftAmount)), + (TensorExt_RotateOp $tensor, + (Arith_RemUIOp + $shiftOp, + // Only works for 1D tensors: index is taken modulo the tensor length, + // i.e., dim 0 + (Arith_IndexCastOp + (Tensor_DimOp $tensor, (Arith_ConstantOp ConstantAttr))) + )), + [(OutOfBoundsOfTensorDim $shiftAmount, $tensor)] +>; + // rotate %t, x + rotate %t y -> rotate %t (x+y) def SumOfConstantRotations : Pat< (TensorExt_RotateOp diff --git a/include/Dialect/TensorExt/IR/TensorExtDialect.td b/include/Dialect/TensorExt/IR/TensorExtDialect.td index 218d00d6e..cc8ae4502 100644 --- a/include/Dialect/TensorExt/IR/TensorExtDialect.td +++ b/include/Dialect/TensorExt/IR/TensorExtDialect.td @@ -1,5 +1,5 @@ -#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ -#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ +#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_ +#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_ include "mlir/IR/DialectBase.td" @@ -12,6 +12,9 @@ def TensorExt_Dialect : Dialect { }]; let cppNamespace = "::mlir::heir::tensor_ext"; + let dependentDialects = [ + "tensor::TensorDialect", + ]; } -#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_ +#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_ diff --git a/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp b/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp index 25f9e41c6..fd0dcfdf7 100644 --- a/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp +++ b/lib/Dialect/TensorExt/IR/TensorExtDialect.cpp @@ -1,5 +1,6 @@ #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project // NOLINTNEXTLINE(misc-include-cleaner): Required to define TensorExtOps diff --git a/tests/tensor_ext/canonicalize.mlir b/tests/tensor_ext/canonicalize.mlir index 6c0263b7f..06f775484 100644 --- a/tests/tensor_ext/canonicalize.mlir +++ b/tests/tensor_ext/canonicalize.mlir @@ -11,3 +11,13 @@ func.func @test_sum_rotation_indices(%0: tensor<16xi32>) -> tensor<16xi32> { %2 = tensor_ext.rotate %1, %c2 : tensor<16xi32>, i32 return %2 : tensor<16xi32> } + +// CHECK-LABEL: @test_sum_rotation_indices +// CHECK: %[[c3:.*]] = arith.constant 3 : i32 +// CHECK: tensor_ext.rotate +// CHECK-SAME: %[[c3]] +func.func @test_normalize_negative(%0: tensor<16xi32>) -> tensor<16xi32> { + %c1 = arith.constant -13 : i32 + %1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32 + return %1 : tensor<16xi32> +}