
torch.compile does not support strided NestedTensor #126025

Open
ezyang opened this issue May 12, 2024 · 2 comments
Labels
module: nestedtensor NestedTensor tag see issue #25032 oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


@ezyang
Contributor

ezyang commented May 12, 2024

🐛 Describe the bug

Internal xref: https://fb.workplace.com/groups/1075192433118967/posts/1426726281298912/

Repro script: https://gist.github.com/ezyang/ff84e2d4b90bb6f9d83171b1582a237c

One log from TORCH_LOGS=graph_breaks:

    self._load_attr(inst)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 1345, in _load_attr
    result = BuiltinVariable(getattr).call_function(
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builtin.py", line 946, in call_function
    return handler(tx, args, kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builtin.py", line 712, in <lambda>
    tx, [v.realize() for v in args], kwargs
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builtin.py", line 712, in <listcomp>
    tx, [v.realize() for v in args], kwargs
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/lazy.py", line 58, in realize
    self._cache.realize()
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/lazy.py", line 24, in realize
    self.vt = VariableBuilder(tx, self.source)(self.value)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 302, in __call__
    vt = self._wrap(value)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 444, in _wrap
    return type_dispatch(self, value)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 1273, in wrap_tensor
    unimplemented("torch.compile does not support strided NestedTensor")
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/exc.py", line 216, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.compile does not support strided NestedTensor

from user code:
  File "/home/ezyang/local/b/pytorch-env/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py", line 293, in torch_dynamo_resume_in_forward_at_292
    if hidden_states.is_nested:

I wonder if it's actually exercising the nested codepath at all, though lol. Maybe only the query is failing?

Versions

main

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @msaroufim @bdhirsh @anijain2305 @chauhang

@ezyang
Contributor Author

ezyang commented May 12, 2024

Oh, it's actually trying to use nested tensors:

DEBUG:torch._dynamo.symbolic_convert.__graph_breaks:Graph break: from user code at: 
  File "/home/ezyang/local/b/pytorch-env/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py", line 301, in torch_dynamo_resume_in_forward_at_292
    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
Traceback (most recent call last):
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/functions.py", line 644, in call_function
    unimplemented(msg)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/exc.py", line 216, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'skip function _VariableFunctionsClass._nested_tensor_from_mask in file Builtin _nested_tensor_from_mask'
/home/ezyang/local/b/pytorch-env/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py:301: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at /data/users/ezyang/b/pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.) 
  hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
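For reference, the core bookkeeping behind turning a padded batch plus mask into a nested tensor can be sketched in plain Python. This assumes True in the mask marks a real token (consistent with the call passing ~attention_mask) and that real tokens are left-aligned in each row; mask_to_offsets is a hypothetical helper, not the ATen implementation. The resulting B + 1 offsets are the per-row boundaries a jagged layout keeps next to its packed values buffer.

```python
from typing import List


def mask_to_offsets(keep: List[List[bool]]) -> List[int]:
    """Hypothetical helper: turn a (B, T) keep-mask into jagged offsets.

    Assumes True marks a real token and that each row's real tokens form a
    prefix (left-aligned padding). Returns B + 1 offsets into a packed buffer.
    """
    offsets = [0]
    for row in keep:
        length = sum(row)  # number of real tokens in this row
        # If any True appears after the first `length` slots, the real tokens
        # are not a prefix of the row.
        if any(row[length:]):
            raise ValueError(f"mask row is not left-aligned: {row}")
        offsets.append(offsets[-1] + length)
    return offsets


keep = [
    [True, True, True, False],   # 3 real tokens
    [True, True, False, False],  # 2 real tokens
]
print(mask_to_offsets(keep))  # [0, 3, 5]
```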

@jbschlosser
Contributor

jbschlosser commented May 13, 2024

They're doing their own BetterTransformer fast-path thing, constructing NSTs (strided-layout NestedTensors) via torch._nested_tensor_from_mask() to pass to torch._transformer_encoder_layer_fwd(), which is the fast-path fusion op.

The solution is to use NJT (jagged-layout NestedTensor) instead, which torch.compile supports.

On our end, this requires NJT support for:

  • torch._nested_tensor_from_mask()
    • I'd expect this to error out if the masked output doesn't have a single ragged dim
  • torch._transformer_encoder_layer_fwd()
  • NJT.to_padded_tensor()

On their end, they should specify the jagged layout: torch._nested_tensor_from_mask(..., layout=torch.jagged).

@jbschlosser jbschlosser added the module: nestedtensor NestedTensor tag see issue #25032 label May 13, 2024
@bdhirsh bdhirsh added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 13, 2024

3 participants