Merge pull request #337 from markstur/compatible_encode
Make encode() in wrapped model compatible with super encode()
evaline-ju committed Mar 19, 2024
2 parents f55b082 + fc7d81e commit ce34b1c
Showing 4 changed files with 168 additions and 26 deletions.
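For orientation before the diff: an illustrative sketch (not part of the change itself) of the compatibility this PR establishes. Here loaded_model stands in for the test fixture used at the bottom of this page, whose .model attribute is the wrapped SentenceTransformerWithTruncate.

import numpy as np

# encode() is now signature-compatible with SentenceTransformer.encode()
# and returns a numpy array by default, just like the super class:
vecs = loaded_model.model.encode(["hello world"])
assert isinstance(vecs, np.ndarray)

# The token-count extension is opt-in via the new return_token_count kwarg:
vecs, input_token_count = loaded_model.model.encode(
    ["hello world"], return_token_count=True
)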
2 changes: 2 additions & 0 deletions caikit_nlp/config/config.yml
@@ -42,6 +42,8 @@ embedding:
retries: 0
# Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
batch_size: 0
# Should implicit truncation (with truncate_input_tokens=0) raise an error when truncation occurs (default: true), or be silently allowed (false)
implicit_truncation_errors: true
# Attempt to optimize with PyTorch compile()
pt2_compile: false
# Use IPEX optimize. Works best when used with autocast (bfloat16) below.
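An illustrative behavior sketch for the new implicit_truncation_errors flag, mirroring the tests later in this diff; model and too_long are stand-ins for a wrapped SentenceTransformerWithTruncate and an input longer than the model max.

# Default (implicit_truncation_errors: true): implicit truncation is rejected.
model.encode(sentences=[too_long], truncate_input_tokens=0)  # raises ValueError

# With the errors disabled, the input is quietly truncated at the model max.
model.encode(
    sentences=[too_long],
    truncate_input_tokens=0,
    implicit_truncation_errors=False,
)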
112 changes: 89 additions & 23 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -66,13 +66,11 @@
sentence_transformers = importlib.import_module("sentence_transformers")
# Third Party
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import batch_to_device, cos_sim, dot_score
from sentence_transformers.util import (
batch_to_device,
cos_sim,
dot_score,
normalize_embeddings,
semantic_search,
normalize_embeddings as normalize, # avoid parameter shadowing
)
from sentence_transformers.util import semantic_search
except ModuleNotFoundError:
# When it is not available, create a dummy that raises an error on attempted init()
class SentenceTransformerNotAvailable:
@@ -89,6 +87,9 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile"))
RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0)
BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0)
NO_IMPLICIT_TRUNCATION = env_val_to_bool(
val=embedding_cfg.get("implicit_truncation_errors", True)
)
DEVICE = embedding_cfg.get("device", "")

RT = TypeVar("RT") # return type
@@ -271,7 +272,9 @@ def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT:
exception=first_exception,
)

def _encode_with_retry(self, *args, **kwargs) -> EmbeddingResultTuple:
def _encode_with_retry(
self, *args, **kwargs
) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
"""All encode calls should use this for consistent param adding and retry loop"""

# Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE
@@ -281,6 +284,23 @@ def _encode_with_retry(self, *args, **kwargs) -> EmbeddingResultTuple:
if "batch_size" not in kwargs:
kwargs["batch_size"] = BATCH_SIZE

if isinstance(self.model, SentenceTransformerWithTruncate):
kwargs[
"implicit_truncation_errors"
] = NO_IMPLICIT_TRUNCATION # config/env overrides default
return self._with_retry(self.model.encode, *args, **kwargs)

# Else: the module may have been initialized with a model that doesn't have the
# added kwargs, e.g. a plain SentenceTransformer or other transformer model.
# Remove those kwargs! This is not the normal use case, but at least don't pass
# invalid kwargs to encode(), and don't return the unexpected tuple
# (with the added token count).
if "truncate_input_tokens" in kwargs:
del kwargs["truncate_input_tokens"]
if "return_token_count" in kwargs:
del kwargs["return_token_count"]
if "implicit_truncation_errors" in kwargs:
del kwargs["implicit_truncation_errors"]
return self._with_retry(self.model.encode, *args, **kwargs)
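The effect of the two paths above, sketched against test_encode_extensions at the bottom of this diff (loaded_model and BOOTSTRAPPED_MODEL are the test fixtures; an illustration, not part of the diff):

# Wrapped model (from load()): extended kwargs pass through, tuple comes back.
emb, token_count = loaded_model._encode_with_retry("text here", return_token_count=True)

# Plain SentenceTransformer (from bootstrap()): extended kwargs are stripped,
# so the result is a plain numpy array, exactly like super().encode().
emb = BOOTSTRAPPED_MODEL._encode_with_retry("text here", return_token_count=True)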

@EmbeddingTask.taskmethod()
@@ -306,7 +326,9 @@ def run_embedding(
error.type_check("<NLP27491611E>", str, text=text)

embeddings, input_token_count = self._encode_with_retry(
text, truncate_input_tokens=truncate_input_tokens
text,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
)
return EmbeddingResult(
result=Vector1D.from_vector(embeddings),
@@ -341,7 +363,9 @@ def run_embeddings(
texts = [texts]

embeddings, input_token_count = self._encode_with_retry(
texts, truncate_input_tokens=truncate_input_tokens
texts,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
)
vectors = [Vector1D.from_vector(e) for e in embeddings]

@@ -375,10 +399,14 @@ def run_sentence_similarity(
"""

source_embedding, source_token_count = self._encode_with_retry(
source_sentence, truncate_input_tokens=truncate_input_tokens
source_sentence,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
)
embeddings, sentences_token_count = self._encode_with_retry(
sentences, truncate_input_tokens=truncate_input_tokens
sentences,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
)

input_token_count = source_token_count + sentences_token_count
@@ -415,10 +443,14 @@ def run_sentence_similarities(
"""

source_embedding, source_token_count = self._encode_with_retry(
source_sentences, truncate_input_tokens=truncate_input_tokens
source_sentences,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
)
embeddings, sentences_token_count = self._encode_with_retry(
sentences, truncate_input_tokens=truncate_input_tokens
sentences,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
)

input_token_count = source_token_count + sentences_token_count
@@ -582,16 +614,18 @@ def get_text(doc):
doc_embeddings, doc_token_count = self._encode_with_retry(
doc_texts,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
convert_to_tensor=True,
)
doc_embeddings = normalize_embeddings(doc_embeddings.to(self.model.device))
doc_embeddings = normalize(doc_embeddings.to(self.model.device))

query_embeddings, query_token_count = self._encode_with_retry(
queries,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
convert_to_tensor=True,
)
query_embeddings = normalize_embeddings(query_embeddings.to(self.model.device))
query_embeddings = normalize(query_embeddings.to(self.model.device))

res = semantic_search(
query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score
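A side note on why this path normalizes once and then uses dot_score with semantic_search: the dot product of unit-length vectors is exactly cosine similarity, so normalizing up front avoids re-normalizing inside every similarity computation. A small self-contained check (illustrative, not part of the diff):

import torch
from sentence_transformers.util import cos_sim, dot_score, normalize_embeddings

a, b = torch.randn(3, 8), torch.randn(5, 8)
a_n, b_n = normalize_embeddings(a), normalize_embeddings(b)
assert torch.allclose(dot_score(a_n, b_n), cos_sim(a, b), atol=1e-5)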
@@ -754,7 +788,10 @@ def sum_token_count(

class SentenceTransformerWithTruncate(SentenceTransformer):
def _truncate_input_tokens(
self, truncate_input_tokens: int, texts: List[str]
self,
truncate_input_tokens: int,
texts: List[str],
implicit_truncation_errors: bool = True,
) -> TruncatedTokensTuple:
"""Truncate input tokens
Args:
@@ -766,6 +803,8 @@ def _truncate_input_tokens(
Otherwise, we take this usable truncation limit to truncate the input tokens.
texts: List[str]
Input texts to be checked and optionally truncated.
implicit_truncation_errors: bool
Whether implicit truncation should be rejected with an error (True, the default) or allowed silently.
Returns:
Tuple containing a dictionary of lists/arrays/tensors returned by the tokenizer, with
proper truncation ('input_ids', 'attention_mask', etc.), and the input_token_count int.
@@ -781,7 +820,7 @@ def _truncate_input_tokens(
okay_to_truncate = True
max_length = truncate_input_tokens
else:
okay_to_truncate = False
okay_to_truncate = not implicit_truncation_errors
max_length = max_tokens

assert len(texts) > 0, "Cannot truncate nothing"
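The truncation rules above, condensed into an illustrative helper (resolve_truncation is a hypothetical name, not in the diff; the logic restates the docstring and branch above):

def resolve_truncation(
    truncate_input_tokens: int, max_tokens: int, implicit_truncation_errors: bool = True
):
    """Return (okay_to_truncate, max_length) per the rules above."""
    if truncate_input_tokens < 0:
        return True, max_tokens  # truncation left to the tokenizer default (model max)
    if 0 < truncate_input_tokens <= max_tokens:
        return True, truncate_input_tokens  # usable explicit truncation limit
    # zero, or greater than the model max: test-only; error unless errors are disabled
    return not implicit_truncation_errors, max_tokens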
@@ -837,32 +876,53 @@ def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
device: Optional[str] = None,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
device: str = None,
normalize_embeddings: bool = False,
truncate_input_tokens: int = 0,
) -> EmbeddingResultTuple:
return_token_count: bool = False,
implicit_truncation_errors: bool = True,
) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
"""
Computes sentence embeddings
:param sentences: the sentences to embed
:param batch_size: the batch size used for the computation
:param device: Which torch.device to use for the computation
:param show_progress_bar: Ignored here. Added for compatibility with super API.
:param output_value: Ignored here. Added for compatibility with super API.
:param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list
of pytorch tensors.
:param convert_to_tensor: If true, you get one large tensor as return. Overwrites any
setting from convert_to_numpy
:param device: Which torch.device to use for the computation
:param normalize_embeddings: Ignored here. Added for compatibility with super API.
:param truncate_input_tokens: Truncation length for input tokens.
If less than zero, this truncation is left up to the tokenizer default (model max).
If zero or greater than the model's maximum, then this is used as a test
to see if truncation is needed. If truncation is needed, an exception is thrown,
unless implicit_truncation_errors=False (see below).
Otherwise, we take this usable truncation limit to truncate the input tokens.
:param return_token_count: If true, a tuple is returned to add the input token count.
:param implicit_truncation_errors: If true (default), implicit truncation throws an error.
If false, the model's default behavior is used.
:return:
If return_token_count is False, the embedding is returned as a numpy matrix.
If return_token_count is True, a tuple is returned with both the embedding and
the input token count.
"""

# These args are for API compatibility, but are currently ignored in our version of encode()
_ = (
show_progress_bar,
output_value,
normalize_embeddings,
)

self.eval()

if convert_to_tensor:
@@ -899,7 +959,9 @@ def encode(
for start_index in range(0, len(list_of_sentences), batch_size):
sentences_batch = sentences_sorted[start_index : start_index + batch_size]
features, token_count = self._truncate_input_tokens(
truncate_input_tokens, sentences_batch
truncate_input_tokens,
sentences_batch,
implicit_truncation_errors=implicit_truncation_errors,
)
input_token_count += token_count

@@ -931,4 +993,8 @@ def encode(
if input_was_string:
all_embeddings = all_embeddings[0]

return EmbeddingResultTuple(all_embeddings, input_token_count)
return (
EmbeddingResultTuple(all_embeddings, input_token_count)
if return_token_count
else all_embeddings
)
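A combined-usage sketch of the extended encode() parameters (model is a stand-in for a SentenceTransformerWithTruncate instance; the values are illustrative):

# Explicit, usable truncation limit plus opt-in token counting:
emb, n_tokens = model.encode(
    ["a long input ..."],
    truncate_input_tokens=128,   # truncate to 128 tokens (assumed within the model max)
    return_token_count=True,     # EmbeddingResultTuple(embedding, input_token_count)
)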
2 changes: 2 additions & 0 deletions runtime_config.yaml
@@ -48,6 +48,8 @@ embedding:
retries: 0
# Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
batch_size: 0
# Should implicit truncation (with truncate_input_tokens=0) raise an error when truncation occurs (default: true), or be silently allowed (false)
implicit_truncation_errors: true
# Attempt to optimize with PyTorch compile()
pt2_compile: false
# Use IPEX optimize. Works best when used with autocast (bfloat16) below.
78 changes: 75 additions & 3 deletions tests/modules/text_embedding/test_embedding.py
@@ -1,7 +1,7 @@
"""Tests for text embedding module"""

# Standard
from typing import List
from typing import List, Tuple
import os
import tempfile

@@ -170,8 +170,15 @@ def _assert_valid_scores(scores, type_tests={}):
return type_tests


def test_bootstrap_reuse():
assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap reuse error"
def test_bootstrap_model(loaded_model):
assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap model type"
assert (
BOOTSTRAPPED_MODEL.model.__class__.__name__ == "SentenceTransformer"
), "bootstrap model class name"
# worth noting that bootstrap does not wrap, but load does
assert (
loaded_model.model.__class__.__name__ == "SentenceTransformerWithTruncate"
), "loaded model class name"


def test_save_load_and_run():
@@ -488,6 +495,49 @@ def test__truncate_input_tokens_raises(truncate_input_tokens, loaded_model):
loaded_model.model.encode(
sentences=[too_long], truncate_input_tokens=truncate_input_tokens
)
# Same behavior when implicit_truncation_errors is True (the default)
with pytest.raises(ValueError, match=f"({over} > {model_max})"):
loaded_model.model.encode(
sentences=[too_long],
truncate_input_tokens=truncate_input_tokens,
implicit_truncation_errors=True,
)
# Different behavior when implicit_truncation_errors is False -- no error raised!
loaded_model.model.encode(
sentences=[too_long],
truncate_input_tokens=truncate_input_tokens,
implicit_truncation_errors=False,
)


def test__implicit_truncation(loaded_model):
"""Test that implicit truncation happens (when allowed)"""
model_max = loaded_model.model.max_seq_length

too_long = "x " * (model_max - 1) # This will go over a little
extra_long = (
too_long
+ "more clever words that surely change the meaning of this text"
* (model_max - 1)
)

# Allowed truncation using default tokens (0) and config to disable the error.
res = loaded_model.model.encode(
sentences=[too_long], truncate_input_tokens=0, implicit_truncation_errors=False
)
# Allowed truncation using model max
res_extra_max = loaded_model.model.encode(
sentences=[extra_long], truncate_input_tokens=loaded_model.model.max_seq_length
)
# Allowed truncation using -1 to just let the model do its thing
res_extra_neg = loaded_model.model.encode(
sentences=[extra_long], truncate_input_tokens=-1
)

# Demonstrates that when implicit truncation is allowed, sentence-transformers quietly truncates at the model max.
# The simple too_long string of x's is equivalent to the string with significantly different extra text (after truncation).
assert np.allclose(res, res_extra_max)
assert np.allclose(res, res_extra_neg)


def test_not_too_many_tokens(loaded_model):
@@ -930,3 +980,25 @@ def test_get_sample_start_indexes(mapping, expected):
"overflow_to_sample_mapping": torch.Tensor(mapping).type(torch.int8)
}
assert get_sample_start_indexes(mock_tokenized) == expected


def test_encode_extensions(loaded_model):
# loaded model can return_token_count
ret = loaded_model._encode_with_retry("text here", return_token_count=True)
assert isinstance(ret, Tuple)
assert isinstance(ret[0], np.ndarray)
assert isinstance(ret[1], int)
ret = loaded_model._encode_with_retry("text here", return_token_count=False)
assert isinstance(ret, np.ndarray)

# Make sure an un-wrapped SentenceTransformer model is unaffected by the extended params or the token-count return
ret = BOOTSTRAPPED_MODEL._encode_with_retry(
"text here",
return_token_count=True,
truncate_input_tokens=123,
implicit_truncation_errors=False,
)
assert isinstance(ret, np.ndarray)
BOOTSTRAPPED_MODEL._encode_with_retry(
"text here"
) # and no KeyError trying to remove non-existing keys
