
all_gather with gloo backend does not work in inference mode #126032

Open
youkaichao opened this issue May 12, 2024 · 1 comment
Labels
module: c10d (Issues/PRs related to collective communications and process groups)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@youkaichao
Collaborator

youkaichao commented May 12, 2024

🐛 Describe the bug

A minimal reproducible example:

import torch
import torch.distributed as dist
dist.init_process_group(backend='gloo')
# dist.init_process_group(backend='nccl')
# torch.cuda.set_device(dist.get_rank())
with torch.inference_mode():
    data = [torch.ones((3, 3))] * dist.get_world_size()
    obj = data[dist.get_rank()]
    dist.all_gather(data, obj)
    # dist.broadcast(obj, src=0)

The error is:

E RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See pytorch/rfcs#17 for more details.

It looks strange: the nccl backend works in this case, and broadcast works too. Only all_gather fails.
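The error message suggests that all_gather's in-place write targets an inference tensor from outside an InferenceMode context. A possible workaround, sketched below and not verified as the intended fix, is to allocate the output tensors outside of torch.inference_mode() so they are ordinary (non-inference) tensors:

import torch
import torch.distributed as dist

dist.init_process_group(backend='gloo')

# Output buffers created outside inference mode are normal tensors, so the
# in-place writes performed by all_gather are allowed even if they happen
# on a thread where inference mode is not enabled.
data = [torch.empty((3, 3)) for _ in range(dist.get_world_size())]

with torch.inference_mode():
    obj = torch.ones((3, 3))  # local contribution; only read by the collective
    dist.all_gather(data, obj)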

Versions

PyTorch 2.3.0

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@mikaylagawarecki added the oncall: distributed label on May 14, 2024
@yf225 added the module: c10d label on May 20, 2024
@wconstab
Contributor

wconstab commented Jun 5, 2024

I can reproduce this issue.

It may be that gloo implements all_gather on another CPU thread, and the thread-local inference-mode state is not attached to that thread.

cc @albanD, do you know if it is the case and if we can easily attach the inference mode context onto the other thread?

Note: I also tried adding .wait() on the all_gather op to ensure the operation completes before the inference mode context exits; it did not help.
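As an illustration of that hypothesis (assuming this is indeed the mechanism, which is not confirmed here), inference mode is a thread-local setting, so a worker thread does not see the context manager entered on the main thread:

import threading
import torch

def check():
    # Runs on a separate thread: is_inference_mode_enabled() reports this
    # thread's own state (disabled by default), not the caller's context.
    print("worker thread:", torch.is_inference_mode_enabled())  # False

with torch.inference_mode():
    print("main thread:", torch.is_inference_mode_enabled())  # True
    t = threading.Thread(target=check)
    t.start()
    t.join()

If gloo runs the collective on such a worker thread, the in-place write into an inference tensor would happen with inference mode disabled on that thread, which matches the error above.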

@wconstab added the triaged label on Jun 5, 2024