[Inductor] Flex attention supports dynamic shape #125994
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125994
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure — as of commit 5014543 with merge base d7fe3c4:
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Should run some benchmarks too.
Yea, benchmarking is on the way.
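Not from this PR, but a minimal sketch of the kind of eager-vs-compiled comparison such a benchmark performs; `scaled_dot_product_attention` stands in for the flex attention entry point, and the shapes, dtype, and device are placeholders:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa
from torch.utils.benchmark import Timer


def bench(fn, q, k, v):
    # Median runtime (seconds) for one attention call.
    t = Timer(stmt="fn(q, k, v)", globals={"fn": fn, "q": q, "k": k, "v": v})
    return t.blocked_autorange().median


# Placeholder config: (batch, heads, seq_len, head_dim) in bfloat16 on GPU.
q, k, v = (torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

compiled = torch.compile(sdpa, dynamic=True)
compiled(q, k, v)  # warm up so compilation happens outside the timed region

print(f"speedup: {bench(sdpa, q, k, v) / bench(compiled, q, k, v):.3f}x")
```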
@@ -98,7 +99,7 @@ def generate_inputs(
    return query, key, value


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
Above in this file there is
torch._dynamo.config.automatic_dynamic_shapes = False
Does compile ignore this if dynamic=True?
Yes, dynamic=True means forcing dynamic shapes, so that config flag is effectively ignored.
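For illustration only (not code from this PR), a small example of how the two knobs interact: `automatic_dynamic_shapes` only governs whether Dynamo promotes sizes to dynamic after it observes a recompile, while `dynamic=True` forces symbolic sizes from the first compile:

```python
import torch

# automatic_dynamic_shapes=False disables the "recompile once, then go dynamic"
# promotion; it does not pin shapes when dynamic=True is passed explicitly.
torch._dynamo.config.automatic_dynamic_shapes = False


def f(x):
    return x * 2


compiled = torch.compile(f, dynamic=True)  # forces dynamic shapes up front
compiled(torch.randn(8))
compiled(torch.randn(16))  # expected to reuse the dynamic-shape graph, no recompile
```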
@@ -126,6 +126,19 @@ def score_mod(score, b, h, m, n):


class TestTemplatedSDPA(InductorTestCase):
    def _check_equal(self, golden_out, ref_out, compiled_out, dtype):
wrote something pretty similar lol:
https://github.com/pytorch/pytorch/pull/125515/files#diff-e3963412cc249e81fecfcf8774f5428b2b5e837ff3633ae13d8b7886ab5bc3b9R134
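For reference, a hedged sketch of what a check like this typically does; the standalone helper name, tolerance strategy, and fudge factor below are assumptions, not copied from either PR:

```python
import torch


def check_equal(golden_out, ref_out, compiled_out, dtype, fudge_factor=1.1):
    # Hypothetical helper: compare the compiled output against an fp64 "golden"
    # reference, allowing at most the eager (ref) error times a small fudge
    # factor rather than a fixed absolute tolerance.
    compiled_error = (golden_out - compiled_out.to(torch.float64)).abs().max()
    ref_error = (golden_out - ref_out.to(torch.float64)).abs().max()
    assert compiled_error <= ref_error * fudge_factor, (
        f"compiled error {compiled_error} exceeds eager error {ref_error} "
        f"* fudge {fudge_factor} for dtype {dtype}"
    )
```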
@@ -617,3 +617,7 @@ def is_from_defaults(source: Source):
    if isinstance(source, ChainedSource):
        return is_from_defaults(source.base)
    return False


def is_cell_contents(source: Source):
what is this doing out of curiosity?
This is part of the heuristic rules that determine whether we should wrap an int as a SymInt. Here we are saying that if the value comes from cell contents (a closure cell), we do not make it dynamic, since closure cells are usually constants. We define these heuristics based on the source.
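For illustration, a hedged sketch of such a source-based check; the `cell_contents` attribute check is an assumption about how closure cells show up as Dynamo sources, not the PR's exact code:

```python
from torch._dynamo.source import AttrSource, ChainedSource, Source


def is_cell_contents(source: Source) -> bool:
    # A value read out of a closure cell is accessed via its ``cell_contents``
    # attribute, so it appears as an AttrSource with that member name.
    if isinstance(source, AttrSource) and source.member == "cell_contents":
        return True
    # Unwrap chained sources (attribute chains, item accesses) down to the base.
    if isinstance(source, ChainedSource):
        return is_cell_contents(source.base)
    return False
```

A caller that sees this return True would keep the int static rather than wrapping it as a SymInt.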
@pytorchbot merge -f "No space left on device"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
## static shapes perf

| Type    | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype          |
|---------|---------|------------|-----------|-----------|-----------|----------|-----------|----------------|
| Average | 0.692   |            |           |           |           |          |           |                |
| Max     | 0.855   | 16         | 16        | 4096      | 4096      | 64       | head_bias | torch.bfloat16 |
| Min     | 0.419   | 8          | 16        | 512       | 512       | 256      | noop      | torch.bfloat16 |

## dynamic shapes perf

| Type    | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod     | dtype          |
|---------|---------|------------|-----------|-----------|-----------|----------|---------------|----------------|
| Average | 0.670   |            |           |           |           |          |               |                |
| Max     | 0.864   | 16         | 16        | 4096      | 4096      | 64       | relative_bias | torch.bfloat16 |
| Min     | 0.376   | 8          | 16        | 512       | 512       | 256      | relative_bias | torch.bfloat16 |

Pull Request resolved: pytorch#125994
Approved by: https://github.com/Chillee
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang