Skip to content

Commit

Permalink
Add pattern to normalize rotation indices
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 7, 2024
1 parent ccb7fc4 commit 45ca376
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 3 deletions.
35 changes: 35 additions & 0 deletions include/Dialect/TensorExt/IR/TensorExtCanonicalization.td
Expand Up @@ -11,6 +11,17 @@ defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">;

def IndexType : NativeCodeCall<"$_builder.getIndexType()">;

def MakeMatchingTypedZero : NativeCodeCall<
"$_builder.getIntegerAttr("
"$_builder.getIntegerType("
"$0.getType().getIntOrFloatBitWidth()"
"), 0)">;

def CreateIndexCastOp : NativeCodeCall<
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">;

def IsZero :
Constraint<
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>;
Expand All @@ -20,13 +31,37 @@ def AreOpposites :
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue() "
"== llvm::cast<mlir::IntegerAttr>($1).getValue()">>;

def OutOfBoundsOfTensorDim :
Constraint<
CPred<
"llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() < 0 "
"|| llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() > "
"llvm::cast<mlir::ShapedType>($1.getType()).getShape()[0]"
>
>;

// rotate %t, 0 -> %t
def DropZeroRotation : Pat<
(TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)),
(replaceWithValue $tensor),
[(IsZero $c0)]
>;

// rotate %t, x -> rotate %t, x mod size
def NormalizeRotationIndex : Pat<
(TensorExt_RotateOp $tensor, (Arith_ConstantOp:$shiftOp APIntAttr:$shiftAmount)),
(TensorExt_RotateOp $tensor,
(Arith_RemUIOp
$shiftOp,
// Only works for 1D tensors: index is taken modulo the tensor length,
// i.e., dim 0
(CreateIndexCastOp
(Tensor_DimOp $tensor, (Arith_ConstantOp ConstantAttr<IndexAttr, "0">)),
$shiftOp))
),
[(OutOfBoundsOfTensorDim $shiftAmount, $tensor)]
>;

// rotate %t, x + rotate %t y -> rotate %t (x+y)
def SumOfConstantRotations : Pat<
(TensorExt_RotateOp
Expand Down
9 changes: 6 additions & 3 deletions include/Dialect/TensorExt/IR/TensorExtDialect.td
@@ -1,5 +1,5 @@
#ifndef HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#define HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_

include "mlir/IR/DialectBase.td"

Expand All @@ -12,6 +12,9 @@ def TensorExt_Dialect : Dialect {
}];

let cppNamespace = "::mlir::heir::tensor_ext";
let dependentDialects = [
"tensor::TensorDialect",
];
}

#endif // HEIR_INCLUDE_DIALECT_TensorExt_IR_TensorExtDIALECT_TD_
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTDIALECT_TD_
1 change: 1 addition & 0 deletions lib/Dialect/TensorExt/IR/BUILD
Expand Up @@ -20,6 +20,7 @@ cc_library(
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:TensorDialect",
],
)

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TensorExt/IR/TensorExtDialect.cpp
@@ -1,5 +1,6 @@
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h"

#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// NOLINTNEXTLINE(misc-include-cleaner): Required to define TensorExtOps
Expand Down
10 changes: 10 additions & 0 deletions tests/tensor_ext/canonicalize.mlir
Expand Up @@ -11,3 +11,13 @@ func.func @test_sum_rotation_indices(%0: tensor<16xi32>) -> tensor<16xi32> {
%2 = tensor_ext.rotate %1, %c2 : tensor<16xi32>, i32
return %2 : tensor<16xi32>
}

// CHECK-LABEL: @test_normalize_negative
// CHECK: %[[c3:.*]] = arith.constant 3 : i32
// CHECK: tensor_ext.rotate
// CHECK-SAME: %[[c3]]
func.func @test_normalize_negative(%0: tensor<16xi32>) -> tensor<16xi32> {
%c1 = arith.constant -13 : i32
%1 = tensor_ext.rotate %0, %c1 : tensor<16xi32>, i32
return %1 : tensor<16xi32>
}

0 comments on commit 45ca376

Please sign in to comment.