add more ctc-loss related summaries #457
Conversation
axlearn/audio/asr_decoder.py
Outdated
jnp.mean(source_lengths), batch_size
),
"input_stats/frame_packing_effiency": WeightedScalar(
    jnp.sum(source_lengths) / input_batch["paddings"].size, input_batch["paddings"].size
Guard against division by 0 here and below?
I thought about this before: the denominator input_batch['paddings'].size is directly pulled from the input. If it is zero, many other places will become problematic well before we hit here (e.g., various normalization layers). Nevertheless, I have added safeguards for them.
Thanks! I don't think we should make assumptions about the other components.
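The guard being requested could be sketched as follows; `safe_ratio` is a hypothetical helper (not a function in axlearn), and the toy shapes are made up for illustration:

```python
import jax.numpy as jnp

def safe_ratio(numerator, denominator):
    """Division that returns 0 instead of inf/nan when the denominator is 0.

    jnp.where evaluates both branches under jit, so the inner jnp.maximum
    keeps the untaken branch from producing inf/nan.
    """
    denominator = jnp.asarray(denominator, dtype=jnp.float32)
    return jnp.where(denominator > 0, numerator / jnp.maximum(denominator, 1e-8), 0.0)

# Toy inputs mirroring the summary above: paddings is [batch, time], 0 = valid.
paddings = jnp.zeros((2, 4))
source_lengths = jnp.array([3, 4])
efficiency = safe_ratio(jnp.sum(source_lengths), paddings.size)  # 7 / 8 = 0.875
```

This way the summary stays finite even for a degenerate empty input, without assuming other layers reject it first.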
),
}
# pytype: enable=attribute-error
return ret_dict
Is there a specific reason to prefer returning the summaries instead of just adding them here?
By returning a dictionary, we can reuse the summary values from the base class to build more specific summaries in the subclass. It also allows us to override some summaries in the subclass.
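The pattern being described might look like the following sketch; BaseDecoder, SubDecoder, and the summary keys are hypothetical names for illustration, not the actual axlearn classes:

```python
import jax.numpy as jnp

class BaseDecoder:
    def _input_stats_summaries(self, source_lengths):
        # Returning a dict (instead of adding summaries in place) lets
        # subclasses reuse, extend, or override individual entries.
        return {"input_stats/average_source_length": jnp.mean(source_lengths)}

class SubDecoder(BaseDecoder):
    def _input_stats_summaries(self, source_lengths):
        summaries = dict(super()._input_stats_summaries(source_lengths))
        # Reuse the base summaries, then add (or overwrite) entries.
        summaries["input_stats/max_source_length"] = jnp.max(source_lengths)
        return summaries

summaries = SubDecoder()._input_stats_summaries(jnp.array([3, 4, 5]))
```

With in-place `add_summary` calls, a subclass could only append; the dict gives it the last word on every key before anything is logged.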
axlearn/audio/asr_decoder.py
Outdated
def _input_stats_summary(
def _input_stats_summary(
def _input_stats_summaries(
or def _add_input_stats_summaries if we decide to inline the add, which may be more similar to other callsites in the repo.
Done.
Have the changes been pushed?
Sorry, I was confused and pushed to the other repo. Just pushed the change.
axlearn/audio/asr_decoder.py
Outdated
per_frame_loss = total_ctc_loss / num_valid_frames
per_label_loss = total_ctc_loss / num_valid_labels
batch_size = per_example_weight.shape[0]
Same here?
Done.
Thanks, a few nits.
per_frame_loss = total_ctc_loss / num_valid_frames
per_label_loss = total_ctc_loss / num_valid_labels
Here too?
Done.
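The guarded form of these two divisions might look like this sketch; `safe_div` is a hypothetical helper, and the values are made up. Note that both counts can legitimately be 0 when every frame or label in the batch is padded:

```python
import jax.numpy as jnp

def safe_div(numerator, denominator):
    # Returns 0 rather than nan/inf when the denominator is 0; the inner
    # jnp.maximum keeps the untaken jnp.where branch finite under jit.
    return jnp.where(denominator > 0, numerator / jnp.maximum(denominator, 1e-8), 0.0)

total_ctc_loss = jnp.asarray(12.0)
num_valid_frames = jnp.asarray(0.0)  # e.g. a batch where every frame is padded
num_valid_labels = jnp.asarray(6.0)

per_frame_loss = safe_div(total_ctc_loss, num_valid_frames)  # 0.0 instead of inf
per_label_loss = safe_div(total_ctc_loss, num_valid_labels)  # 2.0
```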
axlearn/audio/asr_decoder.py
Outdated
valid_label_mask = (1.0 - target_paddings) * per_example_weight[:, None]
num_valid_frames = jnp.sum(valid_frame_mask)
num_valid_labels = jnp.sum(valid_label_mask)
num_valid_examples = jnp.maximum(per_example_weight.sum(), 1.0)
A couple nits -- since we sum over weights, 1.0 may not always be appropriate. We might also consider renaming num_valid_examples to total_example_weight.
Thanks for the suggestion.
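The reviewer's point can be illustrated with fractional weights: clamping the weight sum to 1.0 silently rescales the weighted mean whenever the total weight falls in (0, 1). A small sketch with made-up values:

```python
import jax.numpy as jnp

per_example_weight = jnp.array([0.25, 0.25])  # fractional weights, sum = 0.5
per_example_loss = jnp.array([2.0, 4.0])

weighted_loss_sum = jnp.sum(per_example_loss * per_example_weight)  # 1.5
total_example_weight = per_example_weight.sum()

# Clamping with jnp.maximum(..., 1.0) distorts the mean when 0 < sum < 1:
clamped_mean = weighted_loss_sum / jnp.maximum(total_example_weight, 1.0)  # 1.5
# Guarding only against an exact-zero weight sum preserves the true mean:
guarded_mean = jnp.where(
    total_example_weight > 0,
    weighted_loss_sum / jnp.maximum(total_example_weight, 1e-8),
    0.0,
)  # 3.0
```

The rename to total_example_weight also makes it explicit that the quantity is a weight sum, not an integer count.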
Add more detailed loss summaries to ease debugging.