expose max_decode_len and eos_token_id in decoding #328
base: main
Conversation
@@ -157,6 +158,7 @@ def beam_search_decode(
    input_batch: a dict with a minimum of the following entries:
        prefix: Prompt IDs representing a Tensor of shape [batch, max_sequence_length].
        num_decodes: the number of beams to decode.
+   eos_token_id: The end of sentence token id. If not set, will use cfg.eos_token_id.
Any reason not to just set cfg.eos_token_id directly?
Setting cfg.eos_token_id might not be suitable on the inference side, especially if we have different requests running against the same instance with different eos token ids.
Is the motivation to customize when to stop decoding? If so, have you considered adding stop_decoding_condition, similar to sample_decode?
Hi Guoli, what's the use case? Should we first discuss in an internal PR?
sg, let's discuss in an internal PR first.
Unit tests for both parameters have already been added in decoding_test.py. This change exposes the parameters to the causal_lm and decoder modules.
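The thread above contrasts fixing the EOS token at config time (cfg.eos_token_id) with overriding it per decode call. Below is a minimal, self-contained sketch of why the per-call override matters; the function and names here are hypothetical illustrations, not the actual AXLearn beam_search_decode API.

```python
# Hypothetical sketch (not the real AXLearn API): one decoder instance
# serving requests that use different EOS conventions.

def greedy_decode(step_fn, prefix, max_decode_len, eos_token_id):
    """Appends tokens from step_fn until eos_token_id is produced
    or max_decode_len tokens have been emitted."""
    tokens = list(prefix)
    while len(tokens) < max_decode_len:
        nxt = step_fn(tokens)
        tokens.append(nxt)
        if nxt == eos_token_id:
            break
    return tokens

# A toy "model": emits token 7 when the sequence length is even, else 2.
def toy_step(tokens):
    return 7 if len(tokens) % 2 == 0 else 2

# Same model instance, different per-request EOS ids and lengths:
out_a = greedy_decode(toy_step, [5], max_decode_len=10, eos_token_id=2)
out_b = greedy_decode(toy_step, [5], max_decode_len=6, eos_token_id=99)
```

With a config-level EOS only, out_b's request (which treats token 2 as ordinary content) could not be served by the same instance; the per-call eos_token_id, falling back to cfg.eos_token_id when unset, supports both.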