Converting PrimLoopOp to SCF does not properly convert tensor arguments #3029

Open · renxida opened this issue Mar 15, 2024 · 7 comments
renxida commented Mar 15, 2024

This causes errors like

```
./scratch/scfloop.mlir:16:12: error: 'torch_c.to_builtin_tensor' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got 'tensor<2x3xf32>'
```

by outputting MLIR like this (generic-form dump, locations omitted):

```mlir
%15 = "scf.for"(%12, %14, %13, %9) ({
^bb0(%arg0: index, %arg1: tensor<2x3xf32>):
  %17 = "arith.index_cast"(%arg0) : (index) -> i64
  %18 = "torch_c.from_i64"(%17) : (i64) -> !torch.int
  %19 = "torch_c.to_builtin_tensor"(%arg1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %20 = "tensor.empty"() : () -> tensor<2x3xf32>
  %21 = "linalg.generic"(%19, %11, %20) <{
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>],
      operandSegmentSizes = array<i32: 2, 1>}> ({
  ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
    %23 = "arith.addf"(%arg2, %arg3) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%23) : (f32) -> ()
  }) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  %22 = "torch_c.from_builtin_tensor"(%21) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32>
  "scf.yield"(%22) : (!torch.vtensor<[2,3],f32>) -> ()
}) : (index, index, index, tensor<2x3xf32>) -> tensor<2x3xf32>
```

Note the mismatch: the `scf.for` iter_arg `%arg1` is a builtin `tensor<2x3xf32>` (so `torch_c.to_builtin_tensor` is fed an already-builtin tensor), while `scf.yield` yields a `!torch.vtensor<[2,3],f32>`.
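For contrast, here is a hedged sketch of what a consistent lowering should look like (value names borrowed from the dump above; the exact shape is an assumption, not actual pass output): the `scf.for` carries the builtin `tensor<2x3xf32>` through its iter_args, and the `torch_c` casts are materialized inside the body so that the types agree at the region boundary.

```mlir
// Sketch of an assumed correct lowering, not actual pass output.
%15 = scf.for %arg0 = %12 to %14 step %13
    iter_args(%arg1 = %9) -> (tensor<2x3xf32>) {
  // Cast the builtin tensor iter_arg back to a torch tensor for the body...
  %t = torch_c.from_builtin_tensor %arg1 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
  // ... torch-level computation on %t producing %result ...
  // ...and cast the result back to a builtin tensor before yielding,
  // so the yielded type matches the iter_arg type.
  %next = torch_c.to_builtin_tensor %result : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  scf.yield %next : tensor<2x3xf32>
}
```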
renxida commented Mar 15, 2024

Minimal reproducing example:

```mlir
module {
  func.func @minimal_example() -> (!torch.vtensor<[2,3],f32>) {
    %true = torch.constant.bool true
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int5 = torch.constant.int 5
    %int6 = torch.constant.int 6
    %none = torch.constant.none
    %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %2 = torch.aten.ones %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %3:1 = torch.prim.Loop %int5, %true, init(%1) {
    ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[2,3],f32>):
      %4 = torch.aten.add.Tensor %arg2, %2, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      torch.prim.Loop.condition %true, iter(%4 : !torch.vtensor<[2,3],f32>)
    } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>)
    return %3#0 : !torch.vtensor<[2,3],f32>
  }
}
```

renxida commented Mar 15, 2024

Confirmed that this is a ConvertTorchPrimLoopForLikeOp issue, because:

  1. the error happened right after ConvertTorchPrimLoopForLikeOp ran
  2. All the way up to the ConvertTorchPrimLoopForLikeOp run, the loop had %arg1: !torch.vtensor<[2,3],f32>

It looks like, before the conversion, the loop body was already converting the argument with torch_c.to_builtin_tensor and converting it back for the next iteration. ConvertTorchPrimLoopForLikeOp then introduced a double conversion by also converting the loop block argument types.

But I have trouble locating exactly where this happens. The pattern seems to skip tensor arguments properly. Weird.

renxida commented Mar 15, 2024

Going through the things that use the TypeConverter instance.

Things that could be it:

  1. This shouldn't be it, because the vtensor should match neither `mlir::FloatType` nor `mlir::IntegerType`:
      // If the target type is non-torch type, then use TypeConverter to convert
      // the type of the source.
      if (targetType.isa<mlir::FloatType>()) {
        targetType = Torch::FloatType::get(op->getContext());
        torchArg = typeConverter->materializeSourceConversion(
            rewriter, scfForOp.getLoc(), targetType, {to});
      } else if (targetType.isa<mlir::IntegerType>()) {
        unsigned bitWidth = targetType.getIntOrFloatBitWidth();
        if (bitWidth == 1)
          targetType = Torch::BoolType::get(op->getContext());
        else
          targetType = Torch::IntType::get(op->getContext());
        torchArg = typeConverter->materializeSourceConversion(
            rewriter, scfForOp.getLoc(), targetType, {to});
      }
  2. This shouldn't be it either, because the `isa<Torch::BaseTensorType>` check would skip tensor arguments:
          if (torchType.isa<Torch::BaseTensorType>()) {
            loopConditionIterArgs.push_back(torchArg);
            continue;
          }
          Value arg = typeConverter->materializeTargetConversion(
              rewriter, scfForOp.getLoc(),
              typeConverter->convertType(torchArg.getType()), {torchArg});

renxida commented Mar 15, 2024

It's unlikely but maybe it's this one right at the beginning?

    if (failed(
            typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
      return rewriter.notifyMatchFailure(
          op, "could not convert PrimLoopOp outputs");

But it's only supposed to convert the loop result types. Let's change it anyway and see what happens.

renxida commented Mar 15, 2024

Nope. Skipping `BaseTensorType`s there doesn't work.

renxida commented Mar 15, 2024

Currently stuck. Can't find any other uses of the TypeConverter. Current assumptions that got me stuck:

  1. This is a ConvertTorchPrimLoopForLikeOp issue.
  2. The issue is caused by a base tensor being converted when it's not supposed to be.
  3. The faulty conversion happens through the TypeConverter.
  4. The vtensor should match neither `isa<mlir::FloatType>` nor `isa<mlir::IntegerType>`.
  5. The vtensor should match `isa<Torch::BaseTensorType>`.

I wonder what's wrong.

rsuderman commented Mar 15, 2024

Fix is uploaded to:
https://github.com/rsuderman/torch-mlir/tree/torch_scf

We will need to check if while or if have similar issues.

The tests only round-tripped tensors; they never included any computation in the body, so they never materialized the compatibility casts between the torch and MLIR tensor types. As long as the loop was just round-tripping the tensor, things worked; once the body contains any computation, the conversion starts producing incorrect types at the region boundaries.
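A hedged illustration of that coverage gap (value names hypothetical): a loop body that merely forwards its iter arg lowers cleanly with no casts at all, so a roundtrip-only test cannot expose the bug.

```mlir
// Roundtrip-only body: the builtin tensor iter_arg is yielded unchanged,
// so no torch_c compatibility casts are ever materialized at the
// region boundary and the type mismatch never surfaces.
%r = scf.for %iv = %lb to %ub step %step
    iter_args(%arg = %init) -> (tensor<2x3xf32>) {
  scf.yield %arg : tensor<2x3xf32>
}
```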

@rsuderman rsuderman reopened this Mar 15, 2024
@renxida renxida changed the title Converting PrimLoopOp to SCF double-converts tensor arguments Converting PrimLoopOp to SCF does not properly convert tensor arguments Mar 19, 2024