[dtensor] refactor view ops to use OpStrategy #126011
Conversation
Previously, the rule-based view op sharding prop adjusted a non-tensor arg `local_out_shape` within the rule itself. That approach is not viable in strategy-based sharding prop, so this PR adds a new option `non_tensor_arg_suggestions` to `PlacementStrategy` to address the problem. It also benefits the new factory ops: we no longer need to compute their local shape and stride in `sharding_prop.py` in a customized way. Instead, we compute all such expected **tensor and non-tensor args** in `tensor_ops.py` and `view_ops.py`, which keeps `sharding_prop.py` clean.
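To make the `non_tensor_arg_suggestions` idea concrete, here is a minimal sketch; aside from that option name (which comes from this PR), the class shape and helper are assumptions for illustration, not the actual DTensor code.

```python
# Illustrative sketch only: the real DTensor PlacementStrategy has more fields,
# and everything here except the `non_tensor_arg_suggestions` name is assumed.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class PlacementStrategySketch:
    # placements chosen for the output and for each tensor input
    output_placements: Tuple[Any, ...]
    input_placements: Tuple[Any, ...]
    # optional per-arg-index suggestions for *non-tensor* args,
    # e.g. the per-rank `local_out_shape` a view op should receive
    non_tensor_arg_suggestions: Optional[Dict[int, Any]] = None


def apply_suggestions(
    args: Tuple[Any, ...], strategy: PlacementStrategySketch
) -> Tuple[Any, ...]:
    """Swap suggested local values into the op's argument list, so the
    generic sharding-prop layer never special-cases shapes itself."""
    if not strategy.non_tensor_arg_suggestions:
        return args
    new_args = list(args)
    for idx, value in strategy.non_tensor_arg_suggestions.items():
        new_args[idx] = value
    return tuple(new_args)


# e.g. a view to global shape [8, 16] on a rank that holds 1/4 of dim 0:
strategy = PlacementStrategySketch((), (), {1: [2, 16]})
print(apply_suggestions(("tensor_placeholder", [8, 16]), strategy))
# -> ('tensor_placeholder', [2, 16])
```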
First pass: this looks pretty good after an initial look. I do want to see if we can simplify `propagate_shape_and_sharding`, and why we need to call it twice.

Looks like CI is failing with circular deps, please fix.
As titled. Some ops require adjustment of the output shape argument. In rule-based sharding prop, the global output shape was inferred in the rule (in `view_ops.py`). In strategy-based sharding prop, it is now obtained from the propagated out_tensor_meta (in `sharding_prop.py`).
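For illustration, a rough sketch of the kind of local-shape adjustment a view op needs once the global output shape is known; the helper name and the even-sharding assumption are mine, not the DTensor implementation.

```python
# Hypothetical helper, not the DTensor implementation: given the global output
# shape (from out_tensor_meta) and the dims sharded on the mesh, compute the
# per-rank shape argument a local `view`/`reshape` call should actually receive.
from typing import List, Sequence, Tuple


def local_view_shape(
    global_out_shape: Sequence[int],
    shard_dims: Sequence[Tuple[int, int]],  # (tensor_dim, mesh_dim_size) pairs
) -> List[int]:
    local = list(global_out_shape)
    for dim, mesh_size in shard_dims:
        # this sketch assumes evenly divisible (balanced) sharding
        assert local[dim] % mesh_size == 0, "uneven sharding not handled here"
        local[dim] //= mesh_size
    return local


# a global reshape to (8, 16), sharded on dim 0 over 4 ranks:
# each rank's local op sees an out-shape argument of [2, 16]
assert local_view_shape((8, 16), [(0, 4)]) == [2, 16]
```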
This looks great! Have some minor comments
  self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
  expected_view_as_real_rule = (
      InputDim(0),
      Split(InputDim(1), (13, 2), 0),
      Split(InputDim(1), (13, 2), 1),
  )
- view_as_real_rule = ops[torch.view_as_real].dim_map(intermediate)
+ view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
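For readers unfamiliar with these rules: each entry in a dim map describes one output dimension in terms of the input dimensions. The stand-in namedtuples below only mirror how the real `InputDim`/`Split` helpers are used in the snippet above; they are mocks for illustration, not imported from torch.

```python
# Mock stand-ins for illustration; the real InputDim/Split live in view_ops.py.
from collections import namedtuple

InputDim = namedtuple("InputDim", ["input_dim"])
Split = namedtuple("Split", ["input_dim_rule", "group_shape", "split_id"])

# Read: output dim 0 is input dim 0 unchanged, while output dims 1 and 2 are
# the two pieces of input dim 1 split into a (13, 2) group.  Roughly speaking,
# an output dim that maps straight to an InputDim can keep that input's Shard
# placement, while split dims need a sharding decision.
view_as_real_rule = (
    InputDim(0),
    Split(InputDim(1), (13, 2), 0),
    Split(InputDim(1), (13, 2), 1),
)
print(view_as_real_rule)
```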
I think we'll need to improve some of the existing test cases in this file to work with CommDebugMode, to make sure the communication that happens is expected. This can be done in a follow-up PR.
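A rough sketch of the kind of follow-up check suggested here, assuming CommDebugMode's context-manager interface; exact import paths and counter methods may differ across versions, and this is meant to run inside an already-initialized distributed test (e.g. a `DTensorTestBase` test with world size 4).

```python
# Rough sketch of the suggested follow-up, assuming CommDebugMode's
# context-manager API (import path / counter methods may differ by version).
# Intended to run inside an initialized process group, e.g. a DTensor test.
import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cpu", (4,))
dt = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])

with CommDebugMode() as comm_mode:
    out = dt.view(8, 4)  # a shape-preserving view should need no collectives

# check that the collectives that actually happened are the expected ones
print(comm_mode.get_comm_counts())
assert comm_mode.get_total_counts() == 0
```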
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#126011. Approved by: https://github.com/wanchaol, https://github.com/XilunWu
Stack from ghstack (oldest at bottom):

- `output_` prefix from OpStrategy properties #126359

As titled. Some ops require adjustment of the output shape argument. In rule-based sharding prop, the global output shape was inferred in the rule (in `view_ops.py`). In strategy-based sharding prop, it is now obtained from the propagated out_tensor_meta (in `sharding_prop.py`).

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @wconstab @yf225 @chauhang @d4l3k