Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
5c5b7e3 by Alexander <alexander.viand@intel.com>:

add "tensor of tensor" dialect conversion helpers

Adds a utility function `addTensorOfTensorConversionPatterns` designed to be used as part of dialect conversion.
These support lowerings that would create "tensor of tensor" types, such as what happens to `tensor<..xpoly>` in `--polynomial-to-standard`.

This commit also enables this for `--polynomial-to-standard`, which is a first step towards supporting polynomial operations over tensors in that lowering.

COPYBARA_INTEGRATE_REVIEW=#508 from AlexanderViand-Intel:poly_tensor_support 5c5b7e3
PiperOrigin-RevId: 615421370
  • Loading branch information
AlexanderViand-Intel authored and Copybara-Service committed Mar 13, 2024
1 parent a96a312 commit 348c225
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 3 deletions.
5 changes: 5 additions & 0 deletions lib/Conversion/BUILD
Expand Up @@ -8,9 +8,14 @@ cc_library(
srcs = ["Utils.cpp"],
hdrs = ["Utils.h"],
deps = [
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
3 changes: 3 additions & 0 deletions lib/Conversion/PolynomialToStandard/BUILD
Expand Up @@ -12,6 +12,8 @@ cc_library(
deps = [
"@heir//include/Conversion/PolynomialToStandard:pass_inc_gen",
"@heir//lib/Conversion:Utils",
"@heir//lib/Dialect/Polynomial/IR:Polynomial",
"@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes",
"@heir//lib/Dialect/Polynomial/IR:PolynomialOps",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand All @@ -23,6 +25,7 @@ cc_library(
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
Expand Down
34 changes: 31 additions & 3 deletions lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp
@@ -1,18 +1,45 @@
#include "include/Conversion/PolynomialToStandard/PolynomialToStandard.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "include/Dialect/Polynomial/IR/Polynomial.h"
#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "include/Dialect/Polynomial/IR/PolynomialDialect.h"
#include "include/Dialect/Polynomial/IR/PolynomialOps.h"
#include "include/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "lib/Conversion/Utils.h"
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/include/mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

Expand Down Expand Up @@ -474,11 +501,12 @@ struct ConvertMul : public OpConversionPattern<MulOp> {

// 2N - 1 sized result tensor -> reduce modulo ideal to get a N sized tensor
func::FuncOp divMod = getFuncOpCallback(funcType, polyTy.getRing());
if (!divMod)
if (!divMod) {
return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
diag << "Missing software implementation for polynomial mod op of type"
<< funcType << " and for ring " << polyTy.getRing();
});
}

rewriter.replaceOpWithNewOp<func::CallOp>(op, divMod, polyMul.getResult(0));
return success();
Expand Down Expand Up @@ -739,8 +767,8 @@ void PolynomialToStandard::runOnOperation() {
ConvertConstant, ConvertMulScalar>(typeConverter, context);
patterns.add<ConvertMul>(typeConverter, patterns.getContext(), getDivmodOp);
addStructuralConversionPatterns(typeConverter, patterns, target);
addTensorOfTensorConversionPatterns(typeConverter, patterns, target);

// TODO(#143): Handle tensor of polys.
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
Expand Down
244 changes: 244 additions & 0 deletions lib/Conversion/Utils.cpp
@@ -1,8 +1,23 @@
#include "lib/Conversion/Utils.h"

#include <cstddef>
#include <cstdint>
#include <memory>

#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Region.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir {
Expand All @@ -12,6 +27,235 @@ using ::mlir::func::CallOp;
using ::mlir::func::FuncOp;
using ::mlir::func::ReturnOp;

// Generic fallback pattern: rebuilds any matched op with its operands,
// result types, and region block signatures run through the type converter.
// Matches every op (MatchAnyOpTypeTag) at benefit 1, so more specific
// patterns registered alongside it take precedence.
struct ConvertAny : public ConversionPattern {
  ConvertAny(const TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
                          /*benefit=*/1, context) {
    setDebugName("ConvertAny");
    setHasBoundedRewriteRecursion(true);
  }

  // generate a new op where all operands have been replaced with their
  // materialized/typeconverted versions
  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Converting the operand types acts as a guard: fail the match early if
    // any operand type is not convertible. (The converted `operands` passed
    // in by the framework are used for the new op below.)
    SmallVector<Type> newOperandTypes;
    if (failed(getTypeConverter()->convertTypes(op->getOperandTypes(),
                                                newOperandTypes)))
      return failure();

    SmallVector<Type> newResultTypes;
    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
                                                newResultTypes)))
      return failure();

    SmallVector<std::unique_ptr<Region>, 1> regions;
    IRMapping mapping;
    for (auto &r : op->getRegions()) {
      // Take ownership of the cloned region immediately so it is not leaked
      // on the convertRegionTypes failure path below.
      auto newRegion = std::make_unique<Region>();
      rewriter.cloneRegionBefore(r, *newRegion, newRegion->end(), mapping);
      if (failed(
              rewriter.convertRegionTypes(newRegion.get(), *this->typeConverter)))
        return failure();
      regions.push_back(std::move(newRegion));
    }

    // Recreate the op generically (by name) with converted results/regions;
    // OperationState takes ownership of the region unique_ptrs.
    Operation *newOp = rewriter.create(OperationState(
        op->getLoc(), op->getName().getStringRef(), operands, newResultTypes,
        op->getAttrs(), op->getSuccessors(), regions));

    rewriter.replaceOp(op, newOp);
    return success();
  }
};

struct ConvertExtract : public OpConversionPattern<tensor::ExtractOp> {
  ConvertExtract(mlir::MLIRContext *context)
      : OpConversionPattern<tensor::ExtractOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  // Convert a tensor.extract that would type-convert to extracting a tensor to
  // a tensor.extract_slice operation instead. Specifically, this targets
  // extracting SourceType from tensor<...xSourceType> when SourceType would be
  // type converted to tensor<...>.
  LogicalResult matchAndRewrite(
      tensor::ExtractOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Replace `tensor.extract %t[%i]` on tensor<shape x SourceType> with an
    // equivalent tensor.extract_slice on the combined-shape tensor.
    auto sourceShape = op.getTensor().getType().getShape();
    auto sliceType = getTypeConverter()
                         ->convertType(op.getResult().getType())
                         .cast<RankedTensorType>();
    auto sliceShape = sliceType.getShape();

    // Offsets: the op's original indices, padded with one zero per dimension
    // of the converted element type.
    SmallVector<OpFoldResult> offsets(op.getIndices().begin(),
                                      op.getIndices().end());
    offsets.append(sliceShape.size(), rewriter.getIndexAttr(0));

    // Sizes: a unit extent for each outer (source) dimension, followed by the
    // full extents of the converted element type.
    SmallVector<OpFoldResult> sizes(sourceShape.size(),
                                    rewriter.getIndexAttr(1));
    for (int64_t extent : sliceShape)
      sizes.push_back(rewriter.getIndexAttr(extent));

    // Strides: unit stride in every combined dimension.
    SmallVector<OpFoldResult> strides(sourceShape.size() + sliceShape.size(),
                                      rewriter.getIndexAttr(1));

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        op, sliceType, adaptor.getTensor(), offsets, sizes, strides);

    return success();
  }
};

struct ConvertInsert : public OpConversionPattern<tensor::InsertOp> {
  ConvertInsert(mlir::MLIRContext *context)
      : OpConversionPattern<tensor::InsertOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  // Convert a tensor.insert that would type-convert to inserting a tensor to
  // a tensor.insert_slice operation instead. Specifically, this targets
  // inserting SourceType into tensor<...xSourceType> when SourceType would be
  // type converted to tensor<...>.
  LogicalResult matchAndRewrite(
      tensor::InsertOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Replace `tensor.insert %s into %t[%i]` on tensor<shape x SourceType>
    // with an equivalent tensor.insert_slice on the combined-shape tensor.
    auto destShape = op.getDest().getType().getShape();
    auto sliceType = getTypeConverter()
                         ->convertType(op.getScalar().getType())
                         .cast<RankedTensorType>();
    auto sliceShape = sliceType.getShape();

    // Offsets: the op's original indices, padded with one zero per dimension
    // of the converted element type.
    SmallVector<OpFoldResult> offsets(op.getIndices().begin(),
                                      op.getIndices().end());
    offsets.append(sliceShape.size(), rewriter.getIndexAttr(0));

    // Sizes: a unit extent for each destination dimension, followed by the
    // full extents of the converted element type.
    SmallVector<OpFoldResult> sizes(destShape.size(),
                                    rewriter.getIndexAttr(1));
    for (int64_t extent : sliceShape)
      sizes.push_back(rewriter.getIndexAttr(extent));

    // Strides: unit stride in every combined dimension.
    SmallVector<OpFoldResult> strides(destShape.size() + sliceShape.size(),
                                      rewriter.getIndexAttr(1));

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op, adaptor.getScalar(), adaptor.getDest(), offsets, sizes, strides);

    return success();
  }
};

struct ConvertFromElements
    : public OpConversionPattern<tensor::FromElementsOp> {
  ConvertFromElements(mlir::MLIRContext *context)
      : OpConversionPattern<tensor::FromElementsOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  // Converts a tensor.from_elements %s0, %s1, ... : tensor<...xSourceType>
  // where SourceType would be type-converted to tensor<...> to
  // a concatenation of the converted operands (with appropriate reshape)
  LogicalResult matchAndRewrite(
      tensor::FromElementsOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Expand each of the (converted) operands:
    SmallVector<Value> newOperands;
    for (auto o : adaptor.getElements()) {
      // extend tensor<...xT> to tensor<1x...xT>
      // (the leading unit dimension becomes the concatenation axis below)
      if (auto tensorType = o.getType().dyn_cast<RankedTensorType>()) {
        auto shape = tensorType.getShape();
        // newShape = [1, shape...]
        SmallVector<int64_t> newShape(1, 1);
        newShape.append(shape.begin(), shape.end());

        // Create a dense constant for targetShape
        // (tensor.reshape takes its target shape as a 1-D index tensor
        // operand, so the shape must be materialized as a constant)
        auto shapeOp = rewriter.create<arith::ConstantOp>(
            op.getLoc(),
            RankedTensorType::get(newShape.size(), rewriter.getIndexType()),
            rewriter.getIndexTensorAttr(newShape));

        auto reshapeOp = rewriter.create<tensor::ReshapeOp>(
            op.getLoc(),
            RankedTensorType::get(newShape, tensorType.getElementType()), o,
            shapeOp);
        newOperands.push_back(reshapeOp);
      } else {
        // Operand was not converted to a tensor; pass it through unchanged.
        newOperands.push_back(o);
      }
    }
    // Create the final tensor.concat operation
    // (concatenates along dimension 0, the unit dimension added above)
    rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, 0, newOperands);

    return success();
  }
};

// Registers type conversions and patterns needed when a type converter maps
// an element type to a tensor, producing "tensor of tensor" types. The
// resulting tensor<outer x (inner...)> shape is flattened to
// tensor<outer... x inner...>, and tensor/affine ops are marked legal only
// once their operand types are fully converted.
void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         ConversionTarget &target) {
  target.addDynamicallyLegalDialect<tensor::TensorDialect>([&](Operation *op) {
    return typeConverter.isLegal(op->getOperandTypes());
  });

  typeConverter.addConversion([&](TensorType type) -> Type {
    // Tensors whose element type is already legal are left untouched.
    if (typeConverter.isLegal(type.getElementType())) return type;

    // If the element type converts to a ranked tensor, fold its shape into
    // the outer tensor: tensor<AxBxElem> -> tensor<AxBxCxDxT> when
    // Elem -> tensor<CxDxT>.
    if (auto convertedType = typeConverter.convertType(type.getElementType())) {
      if (auto convertedTensorType =
              convertedType.dyn_cast<RankedTensorType>()) {
        auto innerShape = convertedTensorType.getShape();
        auto outerShape = type.getShape();
        SmallVector<int64_t, 4> combinedShape(outerShape.begin(),
                                              outerShape.end());
        combinedShape.append(innerShape.begin(), innerShape.end());
        return RankedTensorType::get(combinedShape,
                                     convertedTensorType.getElementType());
      }
    }
    // Element type is illegal but does not convert to a ranked tensor;
    // return the type unchanged (matching the previous fallthrough behavior).
    return type;
  });

  target.addDynamicallyLegalDialect<affine::AffineDialect>([&](Operation *op) {
    return typeConverter.isLegal(op->getOperandTypes());
  });

  patterns.add<ConvertAny, ConvertExtract, ConvertInsert, ConvertFromElements>(
      typeConverter, patterns.getContext());
}

void addStructuralConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target) {
Expand Down
7 changes: 7 additions & 0 deletions lib/Conversion/Utils.h
@@ -1,11 +1,18 @@
#ifndef LIB_CONVERSION_UTILS_H_
#define LIB_CONVERSION_UTILS_H_

#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir {
namespace heir {

// Adds conversion patterns that deal with tensor<..xsource_type>
// when source_type will be type converted to tensor<...>, too
void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

// Adds the standard set of conversion patterns for
// converting types involved in func, cf, etc., which
// don't depend on the logic of the dialect beyond the
Expand Down

0 comments on commit 348c225

Please sign in to comment.