Graph break: Unsupported: torch.* op returned non-Tensor bool call_function <built-in function _is_any_autocast_enabled> #126026

Closed
ezyang opened this issue May 12, 2024 · 0 comments
Labels: good first issue, high priority, module: dynamo, oncall: pt2, triaged


ezyang commented May 12, 2024

🐛 Describe the bug

Internal xref: fb.workplace.com/groups/1075192433118967/posts/1426726281298912

Repro script: gist.github.com/ezyang/ff84e2d4b90bb6f9d83171b1582a237c

One log from TORCH_LOGS=graph_breaks:

DEBUG:torch._dynamo.symbolic_convert.__graph_breaks:Graph break: from user code at:
  File "/data/users/ezyang/b/transformers/src/transformers/models/bert/modeling_bert.py", line 1562, in forward
    outputs = self.bert(
  File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ezyang/b/transformers/src/transformers/models/bert/modeling_bert.py", line 1022, in forward
    encoder_outputs = self.encoder(
  File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ezyang/b/transformers/src/transformers/models/bert/modeling_bert.py", line 612, in forward
    layer_outputs = layer_module(
  File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ezyang/local/b/pytorch-env/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py", line 292, in forward
    if not self.training and not torch._C._is_any_autocast_enabled():

and

  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/torch.py", line 754, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 1585, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 1858, in wrap_fx_proxy_cls
    unimplemented(
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/exc.py", line 216, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function <built-in function _is_any_autocast_enabled>
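The failing condition in the BetterTransformer frame above can likely be reduced to a single function. The sketch below is an assumption about what such a repro could look like, not the linked gist: the function name `layer` and the `x + 1` / `x - 1` branches are placeholders, and `backend="eager"` is chosen so that only Dynamo tracing is exercised. With `fullgraph=True`, a graph break surfaces as an exception instead of a silent split, so an affected build should print the `Unsupported` message, while a build that handles the call should trace cleanly.

```python
import torch
import torch._dynamo  # ensures torch._dynamo.exc is importable for the except clause


def layer(x: torch.Tensor, training: bool = False) -> torch.Tensor:
    # Same condition as the optimum BetterTransformer frame in the log.
    if not training and not torch._C._is_any_autocast_enabled():
        return x + 1  # placeholder for the fused fast path
    return x - 1      # placeholder for the fallback path


x = torch.randn(4)
# fullgraph=True turns a graph break into a hard error; backend="eager"
# skips Inductor so the repro only exercises Dynamo tracing.
compiled = torch.compile(layer, backend="eager", fullgraph=True)
try:
    out = compiled(x)
    print("traced without a graph break:", torch.equal(out, layer(x)))
except torch._dynamo.exc.Unsupported as e:
    print("graph break:", e)
```

Dropping `fullgraph=True` and running with `TORCH_LOGS=graph_breaks` set should instead log the break, similar to the output above.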

Making a minimal repro should be easy. It should also be easy to fix by constant folding.
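To illustrate what constant folding means here, the following is a toy, pure-Python model and not Dynamo's implementation; every name in it (`ToyTracer`, `Unsupported`, `AUTOCAST_STATE`, `is_any_autocast_enabled`) is hypothetical. A call whose result is not a tensor cannot become a graph node, so the tracer gives up, mirroring the `torch.* op returned non-Tensor bool` error. If the function is instead on a constant-fold allowlist, the tracer evaluates it at trace time and uses the resulting bool as an ordinary Python constant in control flow. Because the folded value depends on global state, the sketch also records a guard that must be rechecked before the trace is reused; whether a real fix guards in exactly this way is an assumption.

```python
from typing import Any, Callable, List, Tuple


class Unsupported(Exception):
    """Hypothetical stand-in for a tracer's graph-break signal."""


# Hypothetical global state standing in for the autocast flags.
AUTOCAST_STATE = {"enabled": False}


def is_any_autocast_enabled() -> bool:
    # Hypothetical stand-in for torch._C._is_any_autocast_enabled():
    # no tensor inputs, reads global state, returns a plain bool.
    return AUTOCAST_STATE["enabled"]


class ToyTracer:
    def __init__(self, constant_fold_fns=()):
        self.constant_fold_fns = set(constant_fold_fns)
        self.graph: List[Tuple[str, tuple]] = []  # recorded tensor ops
        self.guards: List[Tuple[Callable[..., Any], tuple, Any]] = []

    def call_function(self, fn, *args, returns_tensor=False):
        if fn in self.constant_fold_fns:
            # Constant folding: run the call now and bake in its result,
            # recording a guard so a later state change forces a retrace.
            value = fn(*args)
            self.guards.append((fn, args, value))
            return value
        if not returns_tensor:
            # A non-tensor result cannot be a graph node: graph break.
            raise Unsupported(f"op returned non-Tensor call_function {fn.__name__}")
        # Tensor-producing ops are recorded into the graph, not executed.
        self.graph.append((fn.__name__, args))
        return f"node_{len(self.graph) - 1}"

    def guards_pass(self) -> bool:
        # The trace may be reused only if every folded value is unchanged.
        return all(fn(*args) == value for fn, args, value in self.guards)


def forward(tracer, training=False):
    # Mirrors: if not self.training and not torch._C._is_any_autocast_enabled():
    if not training and not tracer.call_function(is_any_autocast_enabled):
        return "fast_path"
    return "slow_path"


try:
    forward(ToyTracer())
except Unsupported as err:
    print("graph break:", err)

tracer = ToyTracer(constant_fold_fns=[is_any_autocast_enabled])
print(forward(tracer))       # fast_path
AUTOCAST_STATE["enabled"] = True
print(tracer.guards_pass())  # False: autocast state changed, so retrace
```

In this model, folding trades a graph break for a guard: the branch is resolved once at trace time, and the trace is reused only while the autocast state still matches the folded value.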

Versions

main

cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78

bdhirsh added the triaged label and removed the triage review label May 14, 2024
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024