Skip to content

Commit

Permalink
refactor get1DExtractionIndex into util file
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 13, 2024
1 parent a96a312 commit cf89039
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 19 deletions.
16 changes: 16 additions & 0 deletions include/Dialect/BUILD
Expand Up @@ -10,6 +10,7 @@ package(
exports_files(
[
"HEIRInterfaces.h",
"Utils.h",
],
)

Expand Down Expand Up @@ -43,3 +44,18 @@ gentbl_cc_library(
":td_files",
],
)

cc_library(
name = "Utils",
srcs = [
"Utils.h",
],
hdrs = [
"Utils.h",
],
deps = [
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
12 changes: 12 additions & 0 deletions lib/Dialect/BUILD
Expand Up @@ -18,3 +18,15 @@ cc_library(
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "Utils",
srcs = [
"Utils.h",
],
deps = [
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
1 change: 1 addition & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
Expand Up @@ -45,6 +45,7 @@ cc_library(
],
deps = [
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Dialect:Utils",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down
20 changes: 1 addition & 19 deletions lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp
Expand Up @@ -4,6 +4,7 @@
#include <utility>

#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
Expand All @@ -27,25 +28,6 @@ namespace tensor_ext {
#define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"

template <typename Op>
FailureOr<int64_t> get1DExtractionIndex(Op op) {
auto insertIndices = op.getIndices();
if (insertIndices.size() != 1) return failure();

// Each index must be constant; this may require running --canonicalize or
// -sccp before this pass to apply folding rules (use -sccp if you need to
// fold constants through control flow).
Value insertIndex = *insertIndices.begin();
auto insertIndexConstOp = insertIndex.getDefiningOp<arith::ConstantIndexOp>();
if (!insertIndexConstOp) return failure();

auto insertOffsetAttr =
llvm::dyn_cast<IntegerAttr>(insertIndexConstOp.getValue());
if (!insertOffsetAttr) return failure();

return insertOffsetAttr.getInt();
}

/// A pattern that searches for sequences of extract + insert, where the
/// indices extracted and inserted have the same offset, and replaced them with
/// a single rotate operation.
Expand Down
37 changes: 37 additions & 0 deletions lib/Dialect/Utils.h
@@ -0,0 +1,37 @@
#ifndef INCLUDE_DIALECT_UTILS_H_
#define INCLUDE_DIALECT_UTILS_H_

#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

namespace mlir {
namespace heir {

/// Given a tensor::InsertOp or tensor::ExtractOp, and assuming the shape
/// of the input tensor is 1-dimensional and the input index is constant,
/// return the constant index value. If any of these conditions are not
/// met, return a failure.
template <typename Op>
FailureOr<int64_t> get1DExtractionIndex(Op op) {
auto insertIndices = op.getIndices();
if (insertIndices.size() != 1) return failure();

// Each index must be constant; this may require running --canonicalize or
// -sccp before this pass to apply folding rules (use -sccp if you need to
// fold constants through control flow).
Value insertIndex = *insertIndices.begin();
auto insertIndexConstOp = insertIndex.getDefiningOp<arith::ConstantIndexOp>();
if (!insertIndexConstOp) return failure();

auto insertOffsetAttr =
llvm::dyn_cast<IntegerAttr>(insertIndexConstOp.getValue());
if (!insertOffsetAttr) return failure();

return insertOffsetAttr.getInt();
}

} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_UTILS_H_

0 comments on commit cf89039

Please sign in to comment.