
Add new model config for smaller tests #450

Open · wants to merge 7 commits into base: main
Conversation

@jesus-orozco commented May 8, 2024

Adds a new model configuration for text experiments. The goal is an early-termination config for fuji-test to accelerate infrastructure validation.

@jiya-zhang (Contributor):

@jesus-orozco is still working on this PR, but it would be helpful to get some early feedback from @markblee - Thanks!

cfg.mesh_shape = mesh_shape_from_axes(data=-1, fsdp=4)
cfg.summary_writer.write_every_n_steps = eval_every_n_steps
cfg.checkpointer.save_policy = config_for_function(every_n_steps_policy).set(
    n=eval_every_n_steps
)
Contributor:

Is it possible to save checkpoints more frequently than evals? Something like: save a checkpoint every 500 steps, eval every 1500 steps. This would let us identify issues separately if the job hangs.

@jesus-orozco (Author):

Thanks for the feedback! Added a custom policy that saves checkpoints more often, decoupling the checkpoint cadence from the eval cadence.
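For reference, the decoupling idea can be sketched with a simplified stand-in for axlearn's `every_n_steps_policy` (the helper below and the step values 500/1500 are illustrative, not the PR's actual code):

```python
def every_n_steps_policy(n: int):
    """Simplified stand-in for axlearn's every_n_steps_policy:
    returns a predicate that fires on every n-th step."""
    def policy(step: int) -> bool:
        return step > 0 and step % n == 0
    return policy

# Hypothetical cadences from the review discussion: save more often than eval.
save_policy = every_n_steps_policy(500)
eval_policy = every_n_steps_policy(1500)

# Steps where a checkpoint is written but no eval runs; seeing these
# checkpoints land tells a hung job apart from an eval-time failure.
save_only_steps = [s for s in range(1, 3001) if save_policy(s) and not eval_policy(s)]
```

With these values, steps 500, 1000, 2000, and 2500 save a checkpoint without triggering an eval, which is the separation the reviewer asked for.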

@markblee (Contributor) left a comment:

Thanks @jesus-orozco !

axlearn/experiments/text/gpt/c4_trainer.py (outdated; resolved)
@jesus-orozco jesus-orozco marked this pull request as ready for review May 10, 2024 00:29
@@ -140,6 +140,29 @@ def get_trainer_kwargs(model_size: str, *, vocab_size: int, version: Version) ->
),
),
)
elif model_size == "simple":
Contributor:

Thanks! Does this need to be separate from "test" (which is itself intended to be the testing configuration)?

In particular, we can configure mesh_rules for the accelerator that you are testing on. This way, it'll run on both CPU and the target testing hardware.

The only other differences seem to be batch sizes and eval/saving more frequently, which seem tolerable as defaults. WDYT?

Contributor:

Re mesh_rules: yeah, sometimes we test on v4-8, and we need something like (-1, 1, 4, 1, 1).

However, re eval/saving/max steps, we do want a config that terminates training early. As long as training runs for a few thousand steps without problems, we know the JAX testing passes.
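A minimal sketch of the early-termination idea, assuming the trainer config exposes a `max_step` field (both the field name and the value here are assumptions for illustration, not the PR's actual diff):

```python
# Stop the smoke-test run after a few thousand steps; if it reaches
# max_step without hanging, the infrastructure/JAX setup is considered
# validated. (max_step and the value 3000 are assumed for this sketch.)
cfg.max_step = 3000
```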

@jesus-orozco (Author):

Thanks! That sounds reasonable; I'm adding the new defaults for frequent saving/early termination to the "test" configuration instead.
On mesh rules, I'll leave the default so it works on CPU, but can you clarify how we can configure the rules for specific accelerators? As Maggie mentioned, we'd be testing mainly on smaller TPU shapes like v4-8.

Contributor:

Left a comment inline -- since it's a simple case, we probably do not need mesh rules. You can think of mesh rules as overrides to the default mesh. E.g.

mesh_rules=(
    ("tpu-v4-8", mesh_shape_from_axes(fsdp=-1)),
)

means that if the instance type matches tpu-v4-8, we use (1, 1, 4, 1, 1) instead of the default mesh_shape. Let me know whether this makes sense.
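To make the override behavior concrete, here is a hedged sketch of first-match rule selection; the function name, the tuple meshes, and the fallback are illustrative stand-ins, not axlearn's actual implementation:

```python
import re

# Default mesh: data=-1 fills all remaining devices
# (axes sketched as data, expert, fsdp, seq, model).
DEFAULT_MESH = (-1, 1, 1, 1, 1)

# First rule whose pattern matches the instance type wins.
mesh_rules = (
    ("tpu-v4-8", (1, 1, 4, 1, 1)),  # fsdp=-1 resolves to 4 chips on a v4-8
)

def select_mesh(instance_type: str, rules=mesh_rules, default=DEFAULT_MESH):
    """Illustrative mesh-rule matching: override the default mesh
    when a rule pattern matches the instance type."""
    for pattern, mesh in rules:
        if re.fullmatch(pattern, instance_type):
            return mesh
    return default
```

Here `select_mesh("tpu-v4-8")` returns the override `(1, 1, 4, 1, 1)`, while any other instance type falls back to `DEFAULT_MESH`.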

@jesus-orozco (Author):

Thanks @markblee!
Committed the changes you suggested; it makes sense to add data=-1 to the default configuration.

Co-authored-by: Mark Lee <mmaarrkklleeee@gmail.com>
Copy link
Contributor

@markblee markblee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, you might need to run golden config updates: https://github.com/apple/axlearn/blob/main/docs/01-start.md#testing


3 participants