onnx2torch incorrectly omits one weight layer when converting ONNX into PyTorch #208

Open · SuhwanSong opened this issue on Mar 29, 2024

Description

onnx2torch incorrectly omits one weight layer when converting an ONNX model into PyTorch. This mis-conversion leads to `IndexError: Dimension out of range`, raised in `node_converters/global_average_pool.py`, line 41, in `<lambda>`.
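The error itself can be reproduced without onnx2torch. The following is a minimal sketch. It assumes the converter fixed the reduction axes to `(2, 3)` for a 4-D NCHW input. The traceback shows the call `torch.mean(input_tensor, dim=self._x_dims, keepdim=True)` but not the actual value of `self._x_dims`. The names `global_average_pool` and `x_dims` are hypothetical and used only here.

```python
import torch


def global_average_pool(x: torch.Tensor, x_dims) -> torch.Tensor:
    # Average over the given spatial axes and keep them as size-1 dims,
    # mirroring the torch.mean(..., dim=..., keepdim=True) call in the traceback.
    return torch.mean(x, dim=x_dims, keepdim=True)


# Assumption: spatial axes of a 4-D NCHW tensor, fixed ahead of time.
x_dims = (2, 3)

ok = global_average_pool(torch.randn(1, 3, 16, 16), x_dims)
print(ok.shape)  # torch.Size([1, 3, 1, 1])

# If the tensor that actually reaches the node is 3-D (one axis fewer than
# expected), axis 3 does not exist and torch.mean raises IndexError.
caught = None
try:
    global_average_pool(torch.randn(1, 3, 16), x_dims)
except IndexError as err:
    caught = err
    print(err)
```

The error range `[-3, 2]` in the traceback shows that the incoming tensor has three dimensions. Axis 3 only exists for a tensor with at least four dimensions.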

Steps to Reproduce

poc.zip

  1. Download and unzip `poc.zip` to obtain `poc.onnx`.
  2. Run the following script with `poc.onnx` in the working directory.
```python
import onnx
import onnxruntime
import torch
from onnx2torch import convert

poc_onnx_model_path = 'poc.onnx'

# Load the ONNX model
onnx_model = onnx.load(poc_onnx_model_path)

# Check model validity (passes)
onnx.checker.check_model(onnx_model)

# Random input; 'input' is the model's input name
input_torch = torch.randn(1, 3, 512, 512)
input_ort = {'input': input_torch.numpy()}

# ONNX Runtime runs the model without error
ort_session = onnxruntime.InferenceSession(poc_onnx_model_path)
output_ort = ort_session.run(None, input_ort)

# Convert the ONNX model to PyTorch
torch_model = convert(onnx_model)

# Error occurs here
torch_model(input_torch)
```
```
Traceback (most recent call last):
  File "/home/suhwan/grad_course/metadl/report.py", line 23, in <module>
    torch_model(input_torch)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__
    raise e
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.0", line 73, in forward
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/suhwan/.local/lib/python3.10/site-packages/onnx2torch/node_converters/global_average_pool.py", line 46, in forward
    return forward_lambda()
  File "/home/suhwan/.local/lib/python3.10/site-packages/onnx2torch/node_converters/global_average_pool.py", line 41, in <lambda>
    forward_lambda = lambda: torch.mean(input_tensor, dim=self._x_dims, keepdim=True)
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```
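The traceback shows that the converted model is a `torch.fx.GraphModule`. You can find where the tensor loses a dimension by running the model node by node with `torch.fx.Interpreter` and recording each intermediate shape. The following is a minimal diagnostic sketch. `ShapeRecorder` and the `Tiny` demo model are hypothetical and are not part of onnx2torch.

```python
import torch
import torch.fx


class ShapeRecorder(torch.fx.Interpreter):
    """Hypothetical helper: run an FX GraphModule node by node and record tensor shapes."""

    def __init__(self, module: torch.fx.GraphModule):
        super().__init__(module)
        # (node name, op kind, target, shape) in execution order
        self.shapes = []

    def run_node(self, n: torch.fx.Node):
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            self.shapes.append((n.name, n.op, str(n.target), tuple(result.shape)))
        return result


# Self-contained demo on a tiny traced model (hypothetical, not poc.onnx).
class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3)

    def forward(self, x):
        return torch.mean(self.conv(x), dim=(2, 3), keepdim=True)


recorder = ShapeRecorder(torch.fx.symbolic_trace(Tiny()))
out = recorder.run(torch.randn(1, 3, 8, 8))
for row in recorder.shapes:
    print(row)
```

To apply this to the converted model, wrap `ShapeRecorder(torch_model).run(input_torch)` in a `try`/`except IndexError` block. Then print the last few entries of `shapes` to see the tensors computed just before the failing GlobalAveragePool node. You can compare these shapes against the ONNX graph to locate the layer that went missing.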

System Configuration

  • onnx2torch: 1.5.13
  • Python: 3.10.12
  • onnx: 1.16.0
  • onnxruntime: 1.17.1