[Core][Distributed] remove graph mode function #4818
Conversation
```diff
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_mode():
+    with graph_capture():
```
Here, how do we make sure it's not using custom all reduce?
Actually, even before this PR, we could not guarantee that custom allreduce is not used. It happens to hold in CI only because our CI machines do not have custom allreduce.
To solve this problem, another refactor is needed. We need to expose a new function to create tp groups with different communicators. That's my next PR to come!
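A hedged sketch of what that follow-up refactor could look like: a factory that builds a tensor-parallel group with an explicitly chosen communicator, so tests no longer depend on whether the machine supports custom allreduce. None of these names (`TensorParallelGroup`, `init_tp_group`) exist in vLLM; they only illustrate the idea.

```python
# Purely illustrative sketch of the proposed refactor: choose the
# communicator explicitly at group-creation time instead of letting the
# environment decide implicitly.
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class TensorParallelGroup:
    ranks: List[int]
    communicator: str  # e.g. "nccl" or "custom_allreduce" (illustrative names)


def init_tp_group(ranks: Sequence[int], use_custom_allreduce: bool) -> TensorParallelGroup:
    # An explicit flag makes tests deterministic: the communicator no longer
    # depends on whether the CI machine happens to support custom allreduce.
    comm = "custom_allreduce" if use_custom_allreduce else "nccl"
    return TensorParallelGroup(ranks=list(ranks), communicator=comm)
```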
```python
@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream
```
How does this work for non-CUDA backends?
For XPU, this will be torch.xpu.Stream
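A minimal sketch of that point, not vLLM's actual code: the context object merely stores whatever stream object the active backend provides, so supporting a non-CUDA backend amounts to widening the annotation and dispatching on device type. The helper `make_capture_stream` below is hypothetical, added only for illustration.

```python
# Sketch (assumed names): GraphCaptureContext holds a backend-specific stream.
from dataclasses import dataclass
from typing import Any


@dataclass
class GraphCaptureContext:
    # torch.cuda.Stream on CUDA, torch.xpu.Stream on XPU; typed as Any here
    # so this sketch runs without GPU hardware.
    stream: Any


def make_capture_stream(device_type: str) -> Any:
    # Hypothetical dispatch by device type; torch is imported lazily so the
    # sketch itself does not require a GPU build to load.
    if device_type == "cuda":
        import torch
        return torch.cuda.Stream()
    if device_type == "xpu":
        import torch
        return torch.xpu.Stream()
    raise ValueError(f"no stream type known for device: {device_type}")
```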
LGTM! Thanks for addressing my comments!
Users only need to use `with graph_capture()` to manage the context when they capture the graph, before it can be replayed. Inside the capture, graph mode needs to be turned on; outside the capture, there is no need to call graph mode. Therefore, the two functions can be merged into one.