Add new model config for smaller tests #450
base: main
Conversation
@jesus-orozco is still working on this PR, but it would be helpful to get some early feedback from @markblee - Thanks!
```python
cfg.mesh_shape = mesh_shape_from_axes(data=-1, fsdp=4)
cfg.summary_writer.write_every_n_steps = eval_every_n_steps
cfg.checkpointer.save_policy = config_for_function(every_n_steps_policy).set(
    n=eval_every_n_steps
)
```
Is it possible to save checkpoints more frequently than eval? Something like: save a ckpt every 500 steps, eval every 1500 steps. This lets us identify issues separately if the job hangs.
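A quick way to see how the two cadences interact (a toy sketch of the idea only; the helper name and step counts are illustrative, not axlearn API):

```python
def steps_triggering(every_n: int, max_step: int) -> list[int]:
    """Steps at which an every-N-steps policy fires (toy model, not axlearn)."""
    return [step for step in range(1, max_step + 1) if step % every_n == 0]

# Checkpoint every 500 steps, eval every 1500 steps over a 3000-step run.
save_steps = steps_triggering(500, 3000)   # [500, 1000, 1500, 2000, 2500, 3000]
eval_steps = steps_triggering(1500, 3000)  # [1500, 3000]
# Because 500 divides 1500, every eval step is also a save step, and a hang
# between checkpoints can be localized without waiting for the next eval.
```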
Thanks for the feedback! Added the custom policy to save checkpoints more often, to decouple them from eval steps.
Thanks @jesus-orozco !
axlearn/experiments/text/gpt/fuji.py (Outdated)

```
@@ -140,6 +140,29 @@ def get_trainer_kwargs(model_size: str, *, vocab_size: int, version: Version) ->
            ),
        ),
    )
    elif model_size == "simple":
```
Thanks! Does this need to be separate from "test" (which is itself intended to be the testing configuration)? In particular, we can configure `mesh_rules` for the accelerator that you are testing on. This way, it'll run on both CPU and the target testing hardware. The only other differences seem to be batch sizes and eval/saving more frequently, which seem tolerable as defaults. WDYT?
Re `mesh_rules`: yeah, sometimes we test on v4-8, and we need something like (-1, 1, 4, 1, 1).
However, re eval/saving/max step, we do want a config that terminates training early. As long as training runs for a few thousand steps without problems, we know the JAX testing passes.
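A sketch of what such an early-termination test config might look like. The `save_policy` field names follow the snippet earlier in this PR's diff; the `max_step` field placement is an assumption, not verified against axlearn:

```python
# Hypothetical test-config overrides; any field not shown in this PR's diff
# is an assumption.
cfg.max_step = 3000  # terminate after a few thousand steps
cfg.checkpointer.save_policy = config_for_function(every_n_steps_policy).set(
    n=500  # checkpoint more often than eval, to localize hangs
)
```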
Thanks! That sounds reasonable; I'm adding the new defaults for frequent saving/early termination to the "test" configuration instead.
On mesh rules, I'll leave the default so it works on CPU, but can you clarify how we can configure the rules for specific accelerators? As Maggie mentioned, we'd be testing mainly on smaller TPU shapes like v4-8.
Left a comment inline; since it's a simple case, we probably do not need mesh rules. You can think of mesh rules as overrides to the default mesh. E.g.

```python
mesh_rules=(
    ("tpu-v4-8", mesh_shape_from_axes(fsdp=-1)),
)
```

means that if the instance type matches `tpu-v4-8`, we instead use `(1, 1, 4, 1, 1)` rather than the default `mesh_shape`. Let me know whether this makes sense.
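The override lookup described above can be pictured as a first-match scan over the rules. This is a standalone sketch: `select_mesh_shape` is a hypothetical name, and axlearn's real matching logic is not reproduced here:

```python
def select_mesh_shape(instance_type, default_shape, mesh_rules):
    """Return the mesh shape of the first matching rule, else the default."""
    for pattern, shape in mesh_rules:
        if pattern == instance_type:  # real impl may match regex patterns
            return shape
    return default_shape

rules = (("tpu-v4-8", (1, 1, 4, 1, 1)),)
select_mesh_shape("tpu-v4-8", (1, 1, 1, 1, 8), rules)    # matches the override
select_mesh_shape("tpu-v5e-16", (1, 1, 1, 1, 8), rules)  # falls back to default
```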
Thanks @markblee!
Committed the changes you suggested; it makes sense to add `data=-1` to the default configuration.
Co-authored-by: Mark Lee <mmaarrkklleeee@gmail.com>
FYI, you might need to run golden config updates: https://github.com/apple/axlearn/blob/main/docs/01-start.md#testing
Adding a new model configuration for text experiments. The goal is an early-termination config for fuji-test to accelerate infrastructure validation.