Converting PrimLoopOp to SCF does not properly convert tensor arguments #3029
Minimal replicating example:

```mlir
module {
  func.func @minimal_example() -> (!torch.vtensor<[2,3],f32>) {
    %true = torch.constant.bool true
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int5 = torch.constant.int 5
    %int6 = torch.constant.int 6
    %none = torch.constant.none
    %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %2 = torch.aten.ones %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %3:1 = torch.prim.Loop %int5, %true, init(%1) {
    ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[2,3],f32>):
      %4 = torch.aten.add.Tensor %arg2, %2, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      torch.prim.Loop.condition %true, iter(%4 : !torch.vtensor<[2,3],f32>)
    } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>)
    return %3#0 : !torch.vtensor<[2,3],f32>
  }
}
```
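For intuition, here is a minimal scalar model (plain C++, not the compiler; names are invented for illustration) of what this loop computes: the iter arg starts as a 2x3 zeros tensor and a ones tensor is added on each of the five trips, so every element of the result should be 5.0.

```cpp
#include <array>

// Toy model of the torch.prim.Loop above, flattening the 2x3 tensor to 6 floats.
std::array<float, 6> run_loop_model() {
  std::array<float, 6> acc{};  // torch.aten.zeros: value-initialized to 0.0f
  std::array<float, 6> ones;   // torch.aten.ones
  ones.fill(1.0f);
  for (int trip = 0; trip < 5; ++trip)          // torch.prim.Loop %int5
    for (std::size_t j = 0; j < acc.size(); ++j)
      acc[j] += ones[j];                        // torch.aten.add.Tensor in the body
  return acc;
}
```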
Confirmed that this is a
It looks like before the conversion, the loop body was first converting the arg using
but I'm having trouble locating exactly where that happens. It seems to be properly skipping tensor arguments. Weird.
Going through the places that use the TypeConverter instance. Candidates:

```cpp
// If the target type is non-torch type, then use TypeConverter to convert
// the type of the source.
if (targetType.isa<mlir::FloatType>()) {
  targetType = Torch::FloatType::get(op->getContext());
  torchArg = typeConverter->materializeSourceConversion(
      rewriter, scfForOp.getLoc(), targetType, {to});
} else if (targetType.isa<mlir::IntegerType>()) {
  unsigned bitWidth = targetType.getIntOrFloatBitWidth();
  if (bitWidth == 1)
    targetType = Torch::BoolType::get(op->getContext());
  else
    targetType = Torch::IntType::get(op->getContext());
  torchArg = typeConverter->materializeSourceConversion(
      rewriter, scfForOp.getLoc(), targetType, {to});
}
if (torchType.isa<Torch::BaseTensorType>()) {
  loopConditionIterArgs.push_back(torchArg);
  continue;
}
Value arg = typeConverter->materializeTargetConversion(
    rewriter, scfForOp.getLoc(),
    typeConverter->convertType(torchArg.getType()), {torchArg});
```
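As a rough sketch of that dispatch (a toy model in plain C++, not the actual MLIR API; the enum and strings here are invented for illustration): builtin float and integer loop-carried values get rematerialized as torch dialect types, while tensor-typed block arguments take neither branch, which matches the symptom reported here.

```cpp
#include <string>

// Toy model of the branch structure in the snippet above.
enum class Kind { F64, I1, I64, Tensor };

std::string materializeSource(Kind target) {
  if (target == Kind::F64) return "!torch.float";  // FloatType branch
  if (target == Kind::I1)  return "!torch.bool";   // 1-bit IntegerType branch
  if (target == Kind::I64) return "!torch.int";    // wider IntegerType branch
  return "<unconverted>";  // tensor args hit neither branch: the reported bug
}
```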
It's unlikely, but maybe it's this one right at the beginning? It's only supposed to be converting loop result types, though. Let's change it anyway and see what happens.
Nope. Skipping BaseTensorTypes doesn't work.
Currently stuck. I can't find any other uses of the TypeConverter. Current assumptions that got me stuck:
I wonder what's wrong.
Fix is uploaded to:
We will need to check if
The tests only round-tripped tensors; they never included any computation in the body. As a result, they never materialized the compatibility casts between the
This causes errors like
by outputting MLIR like this:
```mlir
%15 = "scf.for"(%12, %14, %13, %9) ({
^bb0(%arg0: index loc("./scratch/scfloop.mlir":14:12), %arg1: tensor<2x3xf32> loc("./scratch/scfloop.mlir":14:12)):
  %17 = "arith.index_cast"(%arg0) : (index) -> i64 loc("./scratch/scfloop.mlir":14:12)
  %18 = "torch_c.from_i64"(%17) : (i64) -> !torch.int loc("./scratch/scfloop.mlir":14:12)
  %19 = "torch_c.to_builtin_tensor"(%arg1) : (tensor<2x3xf32>) -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":16:12)
  %20 = "tensor.empty"() : () -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":16:12)
  %21 = "linalg.generic"(%19, %11, %20) <{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 2, 1>}> ({
  ^bb0(%arg2: f32 loc("./scratch/scfloop.mlir":16:12), %arg3: f32 loc("./scratch/scfloop.mlir":13:10), %arg4: f32 loc("./scratch/scfloop.mlir":16:12)):
    %23 = "arith.addf"(%arg2, %arg3) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32 loc("./scratch/scfloop.mlir":16:12)
    "linalg.yield"(%23) : (f32) -> () loc("./scratch/scfloop.mlir":16:12)
  }) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":16:12)
  %22 = "torch_c.from_builtin_tensor"(%21) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> loc("./scratch/scfloop.mlir":16:12)
  "scf.yield"(%22) : (!torch.vtensor<[2,3],f32>) -> () loc("./scratch/scfloop.mlir":14:12)
}) : (index, index, index, tensor<2x3xf32>) -> tensor<2x3xf32> loc("./scratch/scfloop.mlir":14:12)
```
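Note the type mismatch: `scf.yield` produces a `!torch.vtensor<[2,3],f32>` while the `scf.for` result type is `tensor<2x3xf32>`. A hypothetical sketch of how the end of a fixed loop body could look (SSA names are illustrative, not taken from the actual fix): the torch tensor computed in the body is cast back to a builtin tensor before the yield, so the yield operand matches the loop-carried type.

```mlir
// Illustrative only: add the compatibility cast before yielding.
%22 = "torch_c.from_builtin_tensor"(%21) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32>
%23 = "torch_c.to_builtin_tensor"(%22) : (!torch.vtensor<[2,3],f32>) -> tensor<2x3xf32>
"scf.yield"(%23) : (tensor<2x3xf32>) -> ()
```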