[Bug] Graph optimization model compilation error involving Pad operator #16898

Open
shaoyuyoung opened this issue Apr 17, 2024 · 3 comments
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

Comments

@shaoyuyoung

I am trying to compile an ONNX model (graph below) using TVM.
[Image: original ONNX graph]

Of course, this is a complicated graph, but it can be simplified as shown below.
[Image: simplified ONNX graph]

These two graphs are equivalent. However, when I compile them with TVM, the original ONNX model fails while the simplified ONNX model passes. This is very strange!

The failure seems to come from the shape check in the Pad operator.

In theory, TVM should be able to compile the original ONNX model as exported, without any simplification. In practice, it cannot.

It seems that TVM only accepts the simplified model.

Expected behavior

ONNX compilation passes

Actual behavior

onnx fail
Traceback (most recent call last):
  18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: _ZN3tvm7runtime13PackedFun
  13: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  12: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
  11: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::relay::PadRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/root/anaconda3/conda-bld/tvm-package_1701590675822/work/src/relay/op/nn/pad.cc", line 131
InternalError: Check failed: (data->shape.size() == param->pad_width.size()) is false: There should be as many pad width pairs as shape dimensions but the shape has 5 dimensions and there are 4 pad width pairs.
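
The check that fails can be illustrated outside TVM. Below is a minimal pure-Python sketch of the invariant named in the error message (one pad-width pair per input dimension), assuming the ONNX `Pad` layout where `pads` is a flat list `[x1_begin, x2_begin, ..., x1_end, x2_end, ...]`. The helper names and the pad values are hypothetical and chosen only for illustration; the actual pad values in the attached model are not shown here, and this is not TVM's implementation.

```python
from typing import List, Sequence, Tuple


def onnx_pads_to_pairs(pads: Sequence[int]) -> List[Tuple[int, int]]:
    """Convert flat ONNX-style pads [b1..bn, e1..en] into (begin, end) pairs."""
    if len(pads) % 2 != 0:
        raise ValueError(f"pads must have an even length, got {len(pads)}")
    rank = len(pads) // 2
    return [(pads[i], pads[i + rank]) for i in range(rank)]


def padded_shape(shape: Sequence[int],
                 pad_pairs: Sequence[Tuple[int, int]]) -> List[int]:
    """Compute the padded shape, enforcing one pad pair per dimension."""
    if len(shape) != len(pad_pairs):
        raise ValueError(
            f"shape has {len(shape)} dimensions but there are "
            f"{len(pad_pairs)} pad width pairs"
        )
    return [dim + before + after
            for dim, (before, after) in zip(shape, pad_pairs)]


shape = [5, 5, 4, 2, 1]  # the shape of input 'v6_0' in the script

# Hypothetical pads with 8 entries (4 pairs): rank mismatch, as in the error.
try:
    padded_shape(shape, onnx_pads_to_pairs([0, 0, 1, 1, 0, 0, 1, 1]))
except ValueError as err:
    print(f"rejected: {err}")

# Hypothetical pads with 10 entries (5 pairs): accepted.
print(padded_shape(shape, onnx_pads_to_pairs([0, 0, 0, 1, 1, 0, 0, 0, 1, 1])))
```

If the 5-D tensor reaching Pad in the original graph is paired with a `pads` input sized for a 4-D tensor, a check of this kind would reject it, which matches the message above. Whether that is what actually happens inside the importer is still an open question.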

Environment

Operating System: Ubuntu 18
TVM: 0.15
Torch: 2.1.1
ONNX: 1.15.0

Steps to reproduce

ONNX file is here: onnx.zip

Here is the script

import onnx
import tvm
from onnxsim import simplify
from tvm import relay


def compile_onnx(onnx_model, shape):
    # Import the ONNX model into Relay with fixed input shapes.
    mod_from_onnx, params_onnx = relay.frontend.from_onnx(onnx_model,
                                                          shape=shape)
    # Build a graph executor for CPU; the failure is raised during this step.
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_onnx, tvm.cpu(), 'llvm', params_onnx
        ).evaluate()


model = onnx.load('./model.onnx')

# 1) Compile the original model (this is the case that fails).
try:
    compile_onnx(model, {'v0_0': [], 'v6_0': [5, 5, 4, 2, 1]})
except Exception as e:
    print(f"onnx fail\n{e}")

# 2) Simplify the model with onnx-simplifier and save it for inspection.
model_simp, check = simplify(model)

onnx.save(model_simp, "./model_simp.onnx")

assert check, "Simplified ONNX model could not be validated"

# 3) Compile the simplified model with the same input shapes (this passes).
try:
    compile_onnx(model_simp, {'v0_0': [], 'v6_0': [5, 5, 4, 2, 1]})
except Exception as e:
    print(f"onnx-simplify fail\n{e}")
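
To narrow down where the two models differ, it may help to list the Pad nodes and their inputs in each graph. The sketch below is only a debugging aid: `summarize_pad_nodes` is a hypothetical helper, and it relies only on the `node`, `op_type`, `name`, and `input` fields of an ONNX graph. The stand-in graph and its node names are made up so that the snippet runs without the model files.

```python
from types import SimpleNamespace  # only needed for the stand-in graph below


def summarize_pad_nodes(graph):
    """Return (name, inputs) for every Pad node in a GraphProto-like object."""
    return [
        (node.name, list(node.input))
        for node in graph.node
        if node.op_type == "Pad"
    ]


# Stand-in graph with made-up node names, shaped like onnx.GraphProto.
demo_graph = SimpleNamespace(node=[
    SimpleNamespace(name="reshape_0", op_type="Reshape", input=["x", "shape"]),
    SimpleNamespace(name="pad_0", op_type="Pad", input=["y", "pads", "value"]),
])

print(summarize_pad_nodes(demo_graph))
```

With the real files, the same helper could be called as `summarize_pad_nodes(onnx.load('./model.onnx').graph)` and again for `./model_simp.onnx`, and the two lists compared.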

Triage

  • needs-triage
@shaoyuyoung added the needs-triage and type: bug labels on Apr 17, 2024
@xhmelon
Contributor

xhmelon commented May 2, 2024

Could you provide the simplified model for debugging?

@shaoyuyoung
Author

Sorry for the late response.
The simplified model is attached below:
model-sim.zip
In fact, the simplified model seems to be correct, but the original model fails to compile.
BTW, the simplification tool I used is here: https://github.com/daquexian/onnx-simplifier

@xhmelon

@shaoyuyoung
Author

Hello, sorry to bother you again, but I am still confused about this bug. @xhmelon

Is there any new progress on this issue?

I understand that you may not have had time to investigate it because of your busy schedule.
