add more ctc-loss related summaries #457

Merged
merged 4 commits into from
May 14, 2024

yqwangustc
Contributor

Add more detailed loss summaries to ease debugging.

jnp.mean(source_lengths), batch_size
),
"input_stats/frame_packing_effiency": WeightedScalar(
jnp.sum(source_lengths) / input_batch["paddings"].size, input_batch["paddings"].size
Contributor

Guard against division by 0 here and below?

Contributor Author

I thought about this before: the denominator input_batch['paddings'].size is taken directly from the input. If it is zero, many other places will break well before we reach this point (e.g., various normalization layers). Nevertheless, I have added safeguards for them.

Contributor

Thanks! I don't think we should make assumptions about the other components.
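
For illustration, a minimal sketch of one way to guard this division. The helper name `frame_packing_efficiency` and the example inputs are hypothetical, and this is not necessarily the safeguard that was merged; it only assumes that `paddings` has shape `[batch_size, max_num_frames]`, with 1 marking padded frames.

```python
import jax.numpy as jnp


def frame_packing_efficiency(source_lengths, paddings):
    """Fraction of frames in the padded batch that are valid (non-padding)."""
    # `paddings.size` is a static Python int (batch_size * max_num_frames).
    # Clamp it to at least 1 so an empty batch yields 0.0 instead of NaN.
    total_frames = max(paddings.size, 1)
    return jnp.sum(source_lengths) / total_frames


# Two utterances padded to 4 frames, with 3 and 1 valid frames respectively.
paddings = jnp.array([[0, 0, 0, 1], [0, 1, 1, 1]])
source_lengths = jnp.sum(1 - paddings, axis=-1)
efficiency = frame_packing_efficiency(source_lengths, paddings)  # 4 valid / 8 total

# An empty batch: no examples, hence no frames at all.
empty = frame_packing_efficiency(jnp.zeros((0,)), jnp.zeros((0, 4)))
```

With the clamp in place, the summary no longer depends on other components rejecting empty batches first.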

),
}
# pytype: enable=attribute-error
return ret_dict
Contributor

Is there a specific reason to prefer returning the summaries instead of just adding them here?

Contributor Author

By returning a dictionary, we can reuse the summary values from the base class to build more specific summaries in a subclass. It also allows a subclass to override some of the summaries.
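
A rough sketch of the pattern described here, in plain Python. The class names, the summary keys, and the `frame_shift_seconds` attribute are hypothetical and only illustrate how a returned dictionary can be reused and overridden by a subclass.

```python
class BaseModel:
    """Hypothetical base class that returns its summaries as a dictionary."""

    def _input_stats_summaries(self, source_lengths):
        # Return the values rather than recording them here, so that
        # subclasses can reuse or replace individual entries.
        batch_size = max(len(source_lengths), 1)
        return {
            "input_stats/average_source_length": sum(source_lengths) / batch_size,
            "input_stats/max_source_length": max(source_lengths, default=0),
        }


class SubModel(BaseModel):
    """Hypothetical subclass that extends and overrides the base summaries."""

    frame_shift_seconds = 0.01  # Assumed frame shift, for illustration only.

    def _input_stats_summaries(self, source_lengths):
        ret_dict = super()._input_stats_summaries(source_lengths)
        # Reuse a base-class value to derive a more specific summary.
        ret_dict["input_stats/average_source_seconds"] = (
            ret_dict["input_stats/average_source_length"] * self.frame_shift_seconds
        )
        # Override an existing summary with a subclass-specific definition.
        ret_dict["input_stats/max_source_length"] = min(
            ret_dict["input_stats/max_source_length"], 1000
        )
        return ret_dict


base_summaries = BaseModel()._input_stats_summaries([100, 300])
sub_summaries = SubModel()._input_stats_summaries([100, 300])
```

The caller would then record every entry of the returned dictionary in one place, which keeps the reporting call out of the helper.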


def _input_stats_summary(
Contributor

Suggested change
- def _input_stats_summary(
+ def _input_stats_summaries(

or def _add_input_stats_summaries if we decide to inline the add, which may be more similar to other callsites in the repo.

Contributor Author

Done.

Contributor

Have the changes been pushed?

Contributor Author

Sorry, I was confused and pushed to the other repo. I have just pushed the change.

Comment on lines 309 to 311
per_frame_loss = total_ctc_loss / num_valid_frames
per_label_loss = total_ctc_loss / num_valid_labels
batch_size = per_example_weight.shape[0]
Contributor

Same here?

Contributor Author

Done.

@yqwangustc yqwangustc requested a review from markblee May 13, 2024 14:22
Yongqiang Wang added 2 commits May 13, 2024 21:10
Contributor

@markblee markblee left a comment

Thanks, a few nits.

Comment on lines +311 to +312
per_frame_loss = total_ctc_loss / num_valid_frames
per_label_loss = total_ctc_loss / num_valid_labels
Contributor

Here too?

Contributor Author

Done.

valid_label_mask = (1.0 - target_paddings) * per_example_weight[:, None]
num_valid_frames = jnp.sum(valid_frame_mask)
num_valid_labels = jnp.sum(valid_label_mask)
num_valid_examples = jnp.maximum(per_example_weight.sum(), 1.0)
Contributor

A couple nits -- since we sum over weights, 1.0 may not always be appropriate. We might also consider renaming num_valid_examples to total_example_weight.

Contributor Author

Thanks for the suggestion.
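
To make the concern about the `1.0` floor concrete, here is a small sketch with fractional example weights. The `safe_divide` helper and the input values are hypothetical, and the thread does not show which guard was finally adopted; the example only contrasts clamping the denominator with guarding on a positive total weight.

```python
import jax.numpy as jnp


def safe_divide(numerator, denominator):
    """Returns numerator / denominator, or 0 where the denominator is not positive."""
    is_valid = denominator > 0
    # Substitute 1 for invalid denominators so the division never produces NaN or inf.
    safe_denominator = jnp.where(is_valid, denominator, 1.0)
    return jnp.where(is_valid, numerator / safe_denominator, 0.0)


per_example_loss = jnp.array([2.0, 4.0])
per_example_weight = jnp.array([0.25, 0.25])  # Fractional weights summing to 0.5.
total_example_weight = per_example_weight.sum()
total_loss = jnp.sum(per_example_loss * per_example_weight)  # 0.5 + 1.0 = 1.5

# Clamping to 1.0 divides by 1.0 instead of 0.5, which understates the weighted mean.
clamped = total_loss / jnp.maximum(total_example_weight, 1.0)
# Guarding on a positive total keeps the true weighted mean.
guarded = safe_divide(total_loss, total_example_weight)
# An all-zero weight vector still yields 0 rather than NaN.
all_zero = safe_divide(jnp.array(0.0), jnp.array(0.0))
```

Here the clamped result is half of the guarded one, because the weights sum to less than 1.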

@yqwangustc yqwangustc requested review from zhiyun and markblee and removed request for zhiyun May 14, 2024 22:07
@yqwangustc yqwangustc added this pull request to the merge queue May 14, 2024
Merged via the queue into apple:main with commit d5f219f May 14, 2024
4 checks passed
@yqwangustc yqwangustc deleted the ctc_loss branch May 14, 2024 22:35

2 participants