
Merge pull request #324 from waleedqk/ResponseOptions
Adding TGIS params to the caikit API
evaline-ju committed Apr 2, 2024
2 parents 3bf2a13 + 8c16dc4 commit 42c3075
Showing 5 changed files with 146 additions and 11 deletions.
16 changes: 16 additions & 0 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -202,6 +202,10 @@ def run(
stop_sequences: Optional[List[str]] = None,
seed: Optional[np.uint64] = None,
preserve_input_text: bool = False,
input_tokens: bool = False,
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -221,6 +225,10 @@
return self.tgis_generation_client.unary_generate(
text=verbalized_text,
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
@@ -256,6 +264,10 @@ def run_stream_out(
stop_sequences: Optional[List[str]] = None,
seed: Optional[np.uint64] = None,
preserve_input_text: bool = False,
input_tokens: bool = False,
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing against the model running in TGIS
@@ -275,6 +287,10 @@
return self.tgis_generation_client.stream_generate(
text=verbalized_text,
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
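For orientation, here is a hedged sketch of how a caller might exercise the new response-option flags on the unary path. The prompt and max_new_tokens value are illustrative, and `model` stands in for an already-loaded TGIS-backed module such as the prompt-tuned one changed above; this is a sketch, not code from the commit.

# Sketch only: `model` is assumed to be a loaded TGIS-backed text generation module.
result = model.run(
    "What is a tundra?",           # illustrative prompt
    max_new_tokens=20,             # illustrative generation limit
    input_tokens=True,             # also return details for the tokenized input
    generated_tokens=True,         # return each generated token (default True)
    token_logprobs=True,           # attach a logprob to every returned token (default True)
    token_ranks=True,              # attach a rank to every returned token (default True)
)
print(result.generated_text)
for token in result.tokens:        # populated because generated_tokens=True
    print(token.text, token.logprob, token.rank)
for token in result.input_tokens:  # populated because input_tokens=True
    print(token.text, token.logprob, token.rank)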
17 changes: 16 additions & 1 deletion caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -227,6 +227,10 @@ def run(
stop_sequences: Optional[List[str]] = None,
seed: Optional[np.uint64] = None,
preserve_input_text: bool = False,
input_tokens: bool = False,
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -236,11 +240,14 @@
GeneratedTextResult
Generated text result produced by TGIS.
"""

if self._model_loaded:
return self.tgis_generation_client.unary_generate(
text=text,
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
@@ -276,6 +283,10 @@ def run_stream_out(
stop_sequences: Optional[List[str]] = None,
seed: Optional[np.uint64] = None,
preserve_input_text: bool = False,
input_tokens: bool = False,
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing for text generation module.
@@ -289,6 +300,10 @@
return self.tgis_generation_client.stream_generate(
text=text,
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
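The streaming path gains the same four flags. A hedged sketch of consuming it, again assuming `model` is a loaded TGIS-backed module and not code from the commit itself:

# Sketch only: each chunk is a GeneratedTextStreamResult carrying the text delta
# plus per-token detail when the corresponding flags are set.
for chunk in model.run_stream_out(
    "What is a tundra?",
    max_new_tokens=20,
    input_tokens=True,
    generated_tokens=True,
    token_logprobs=True,
    token_ranks=True,
):
    print(chunk.generated_text)
    for token in chunk.tokens:
        print(" ", token.text, token.logprob, token.rank)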
96 changes: 87 additions & 9 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -37,9 +37,19 @@

GENERATE_FUNCTION_TGIS_ARGS = """
{}
preserve_input_text: str
preserve_input_text: bool
Whether or not the source string should be contained in the generated output,
e.g., as a prefix.
input_tokens: bool
Whether or not to include the list of input tokens.
generated_tokens: bool
Whether or not to include the list of individual generated tokens.
token_logprobs: bool
Whether or not to include the logprob for each returned token.
Applicable only if generated_tokens == true and/or input_tokens == true.
token_ranks: bool
Whether or not to include the rank of each returned token.
Applicable only if generated_tokens == true and/or input_tokens == true.
""".format(
GENERATE_FUNCTION_ARGS
)
@@ -48,6 +58,10 @@
def validate_inf_params(
text,
preserve_input_text,
input_tokens,
generated_tokens,
token_logprobs,
token_ranks,
eos_token,
max_new_tokens,
min_new_tokens,
@@ -74,6 +88,10 @@ def validate_inf_params(
)
error.type_check("<NLP65883535E>", str, text=text)
error.type_check("<NLP65883537E>", bool, preserve_input_text=preserve_input_text)
error.type_check("<NLP65883538E>", bool, input_tokens=input_tokens)
error.type_check("<NLP65883539E>", bool, generated_tokens=generated_tokens)
error.type_check("<NLP65883540E>", bool, token_logprobs=token_logprobs)
error.type_check("<NLP65883541E>", bool, token_ranks=token_ranks)
error.type_check("<NLP85452188E>", str, allow_none=True, eos_token=eos_token)
error.type_check(
"<NLP03860681E>",
@@ -174,6 +192,10 @@ def validate_inf_params(

def get_params(
preserve_input_text,
input_tokens,
generated_tokens,
token_logprobs,
token_ranks,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
@@ -211,10 +233,10 @@ def get_params(

res_options = generation_pb2.ResponseOptions(
input_text=preserve_input_text,
generated_tokens=True,
input_tokens=False,
token_logprobs=True,
token_ranks=True,
generated_tokens=generated_tokens,
input_tokens=input_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
)
stopping = generation_pb2.StoppingCriteria(
stop_sequences=stop_sequences,
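With this change the four flags flow straight from the caller into the TGIS ResponseOptions message instead of being hard-coded. A minimal sketch of the resulting proto with the default flag values from the run signatures above; the generation_pb2 import path is an assumption about what tgis_utils.py already imports, not something introduced by this commit:

# Sketch: ResponseOptions as get_params now builds it, using the default flag values.
from caikit_tgis_backend.protobufs import generation_pb2  # assumed import path

res_options = generation_pb2.ResponseOptions(
    input_text=False,       # preserve_input_text default
    input_tokens=False,     # input-token details stay off unless requested
    generated_tokens=True,  # generated-token details on by default
    token_logprobs=True,    # logprobs apply to whichever token lists are returned
    token_ranks=True,       # ranks apply to whichever token lists are returned
)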
@@ -268,6 +290,10 @@ def unary_generate(
self,
text,
preserve_input_text,
input_tokens,
generated_tokens,
token_logprobs,
token_ranks,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
@@ -305,6 +331,10 @@ def unary_generate(
validate_inf_params(
text=text,
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
eos_token=self.eos_token,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
@@ -325,6 +355,10 @@

params = get_params(
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
@@ -366,19 +400,43 @@ def unary_generate(
)
response = batch_response.responses[0]

token_list = []
if response.tokens is not None:
for token in response.tokens:
token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
)
)

input_token_list = []
if response.input_tokens is not None:
for token in response.input_tokens:
input_token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
)
)

return GeneratedTextResult(
generated_text=response.text,
generated_tokens=response.generated_token_count,
finish_reason=response.stop_reason,
producer_id=self.producer_id,
input_token_count=response.input_token_count,
seed=seed,
tokens=token_list,
input_tokens=input_token_list,
)

def stream_generate(
self,
text,
preserve_input_text,
input_tokens,
generated_tokens,
token_logprobs,
token_ranks,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
@@ -416,6 +474,10 @@ def stream_generate(
validate_inf_params(
text=text,
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
eos_token=self.eos_token,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
@@ -434,6 +496,10 @@

params = get_params(
preserve_input_text=preserve_input_text,
input_tokens=input_tokens,
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
@@ -476,13 +542,25 @@ def stream_generate(
input_token_count=stream_part.input_token_count,
)
token_list = []
for token in stream_part.tokens:
token_list.append(
GeneratedToken(text=token.text, logprob=token.logprob)
)
if stream_part.tokens is not None:
for token in stream_part.tokens:
token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
)
)
input_token_list = []
if stream_part.input_tokens is not None:
for token in stream_part.input_tokens:
input_token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
)
)
yield GeneratedTextStreamResult(
generated_text=stream_part.text,
tokens=token_list,
input_tokens=input_token_list,
details=details,
)

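The conversion loops in unary_generate and stream_generate follow the same shape: each TGIS token message becomes a caikit GeneratedToken carrying text, logprob, and rank. A compact, functionally equivalent sketch (not the committed code) that factors the pattern into one helper; the GeneratedToken import path is an assumption about what tgis_utils.py already uses:

# Sketch only: equivalent to the two conversion loops above.
from caikit.interfaces.nlp.data_model import GeneratedToken  # assumed import path

def _to_generated_tokens(tokens):
    """Convert TGIS token messages into caikit GeneratedToken objects."""
    return [
        GeneratedToken(text=t.text, logprob=t.logprob, rank=t.rank)
        for t in (tokens or [])
    ]

# `response` stands for the TGIS response object used in the surrounding code.
token_list = _to_generated_tokens(response.tokens)
input_token_list = _to_generated_tokens(response.input_tokens)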
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.14,<0.27.0",
"caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0",
"caikit-tgis-backend>=0.1.27,<0.2.0",
# TODO: loosen dependencies
"accelerate>=0.22.0",
26 changes: 26 additions & 0 deletions tests/fixtures/__init__.py
@@ -215,6 +215,16 @@ def unary_generate(request):
fake_result.generated_token_count = 1
fake_result.text = "moose"
fake_result.input_token_count = 1
token = mock.Mock()
token.text = "moose"
token.logprob = 0.2
token.rank = 1
fake_result.tokens = [token]
input_tokens = mock.Mock()
input_tokens.text = "moose"
input_tokens.logprob = 0.2
input_tokens.rank = 1
fake_result.input_tokens = [input_tokens]
fake_response.responses = [fake_result]
return fake_response

@@ -228,7 +238,13 @@ def stream_generate(request):
token = mock.Mock()
token.text = "moose"
token.logprob = 0.2
token.rank = 1
fake_stream.tokens = [token]
input_tokens = mock.Mock()
input_tokens.text = "moose"
input_tokens.logprob = 0.2
input_tokens.rank = 1
fake_stream.input_tokens = [input_tokens]
fake_stream.text = "moose"
for _ in range(3):
yield fake_stream
@@ -248,6 +264,12 @@ def validate_unary_generate_response(result):
assert result.generated_tokens == 1
assert result.finish_reason == 5
assert result.input_token_count == 1
assert result.tokens[0].text == "moose"
assert result.tokens[0].logprob == 0.2
assert result.tokens[0].rank == 1
assert result.input_tokens[0].text == "moose"
assert result.input_tokens[0].logprob == 0.2
assert result.input_tokens[0].rank == 1

@staticmethod
def validate_stream_generate_response(stream_result):
@@ -259,6 +281,10 @@ def validate_stream_generate_response(stream_result):
assert first_result.generated_text == "moose"
assert first_result.tokens[0].text == "moose"
assert first_result.tokens[0].logprob == 0.2
assert first_result.tokens[0].rank == 1
assert first_result.input_tokens[0].text == "moose"
assert first_result.input_tokens[0].logprob == 0.2
assert first_result.input_tokens[0].rank == 1
assert first_result.details.finish_reason == 5
assert first_result.details.generated_tokens == 1
assert first_result.details.seed == 10
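For completeness, a hedged sketch of how a test might combine the stubbed client with the validators above. The fixture name stub_tgis_model is hypothetical and stands in for a module wired to the mocked TGIS client; the actual tests in the repository may be structured differently.

# Sketch only: `stub_tgis_model` is a hypothetical fixture wrapping the mocked
# TGIS client whose unary_generate stub is defined above.
def test_run_returns_token_details(stub_tgis_model):
    result = stub_tgis_model.run(
        "moose text",
        input_tokens=True,
        generated_tokens=True,
        token_logprobs=True,
        token_ranks=True,
    )
    # The mocked response carries one "moose" token with logprob 0.2 and rank 1,
    # so the enriched fields can be asserted directly.
    assert result.tokens[0].text == "moose"
    assert result.tokens[0].rank == 1
    assert result.input_tokens[0].logprob == 0.2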
