
torch_mlir.compile failing on RetinaNET #3096

Open
afanasyev-ilya opened this issue Apr 2, 2024 · 0 comments

Hello everyone,

I want to convert the RetinaNet model into MLIR using the following code:

        import torch
        from torchvision import models
        import torch_mlir

        import udlc_models  # project-internal helper module

        model = models.detection.retinanet_resnet50_fpn(weights=None)
        url = "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth"

        weights = torch.load(
            udlc_models.utils.download_weights(model_name, model_category, url)
        )

        model.load_state_dict(weights)
        model.eval()

        # inference example
        # TODO: remove this after the MLIR conversion problem is fixed
        input = torch.rand(1, 3, 300, 400)
        predictions = model(input)
        print(predictions)

        # convert to MLIR; currently fails inside torch_mlir.compile (see traceback below)
        self.label = model_label
        self.model = model
        self.mlir = torch_mlir.compile(self.model, input, output_type="stablehlo", use_tracing=True, verbose=True)

While eager PyTorch inference works correctly, torch_mlir.compile fails with a strange error:

  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/_main/models/python/udlc_models/vision/_impl/retinanet.py", line 45, in __init__
    mlir = torch_mlir.compile(model, input, output_type="stablehlo", use_tracing=True, verbose=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torch_mlir/site-packages/torch_mlir/__init__.py", line 414, in compile
    scripted = torch.jit.trace_module(
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torch/site-packages/torch/jit/_trace.py", line 1074, in trace_module
    module._c._create_method_from_trace(
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torch/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torch/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torch/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torchvision/site-packages/torchvision/models/detection/retinanet.py", line 663, in forward
    detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torchvision/site-packages/torchvision/models/detection/retinanet.py", line 556, in postprocess_detections
    keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torchvision/site-packages/torchvision/ops/boxes.py", line 75, in batched_nms
    return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torch/site-packages/torch/jit/_trace.py", line 1243, in wrapper
    return compiled_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torchvision/site-packages/torchvision/ops/boxes.py", line 94, in _batched_nms_coordinate_trick
    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
           ~~~ <--- HERE
    return keep
  File "/home/i.afanasyev/.cache/bazel/_bazel_i.afanasyev/15f743c13dff5dd647c8fd3b03305b9d/execroot/_main/bazel-out/k8-opt/bin/udlc.venv.runfiles/rules_python~0.31.0~pip~pypi_311_torchvision/site-packages/torchvision/ops/boxes.py", line 41, in nms
        _log_api_usage_once(nms)
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
           ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: Could not run 'torchvision::nms' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'torchvision::nms' is only available for these backends: [Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastXLA, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at /dev/null:440 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradMeta: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at /__w/vision/vision/pytorch/vision/torchvision/csrc/ops/autocast/nms_kernel.cpp:34 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastXLA: registered at /__w/vision/vision/pytorch/vision/torchvision/csrc/ops/autocast/nms_kernel.cpp:27 [kernel]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]
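For context on the frame where tracing dies: `_batched_nms_coordinate_trick` shifts each class's boxes by a per-class offset larger than any box coordinate, so a single class-agnostic NMS call can never suppress boxes across classes. A plain-Python sketch of that trick (illustrative only, not torchvision's actual code):

```python
def offset_boxes(boxes, idxs):
    """Shift each box by class_index * (max_coordinate + 1).

    boxes: list of [x1, y1, x2, y2]; idxs: class index per box.
    After shifting, boxes of different classes cannot overlap,
    so one class-agnostic NMS pass handles all classes at once.
    """
    max_coordinate = max(c for box in boxes for c in box)
    offsets = [i * (max_coordinate + 1) for i in idxs]
    return [[c + off for c in box] for box, off in zip(boxes, offsets)]

boxes = [[0, 0, 10, 10], [1, 1, 9, 9]]
# class 0 is untouched; class 1 is shifted by max_coordinate + 1 = 11
shifted = offset_boxes(boxes, [0, 1])
```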

What could be the reason for this? Should I try torch with another backend, for example CUDA? I thought the MLIR conversion was target-independent.
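In case it helps diagnosis: the NMS step itself is simple enough to express without the `torchvision::nms` custom op, so one possible workaround is running NMS outside the traced graph. A hedged, plain-Python sketch of IoU-based NMS (illustrative only, not torchvision's implementation):

```python
def iou(a, b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
    union = area(a) + area(b) - inter
    return inter / union if union else 0.0

def nms(boxes, scores, iou_threshold):
    """Greedy NMS: keep the highest-scoring box, drop boxes that
    overlap it above the threshold, repeat with the remainder."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        i = order.pop(0)
        keep.append(i)
        order = [j for j in order if iou(boxes[i], boxes[j]) <= iou_threshold]
    return keep

# box 1 duplicates box 0 and is suppressed; box 2 does not overlap
keep = nms([[0, 0, 10, 10], [0, 0, 10, 10], [20, 20, 30, 30]],
           [0.9, 0.8, 0.7], 0.5)  # → [0, 2]
```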
