Skip to content

Commit

Permalink
Merge pull request #316 from markstur/better_retry
Browse files Browse the repository at this point in the history
Embeddings: Fix retry error handling to return first exception. Default to zero retries.
  • Loading branch information
evaline-ju committed Feb 6, 2024
2 parents 18c4c55 + ecbd350 commit 6a749bd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
30 changes: 19 additions & 11 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ def env_var_to_int(name, default):
# Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
BATCH_SIZE = env_var_to_int("BATCH_SIZE", default=0)

# Retry count for catching sporadic encode() or tokenize() errors (in case if they come back)
RETRY_COUNT = env_var_to_int("RETRY_COUNT", default=5)


@module(
"eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
Expand All @@ -109,6 +106,9 @@ def env_var_to_int(name, default):
)
class EmbeddingModule(ModuleBase):

# Number of retries after a failed call; 0 (the default) disables retrying (was for thread contention errors)
RETRY_COUNT = max(env_var_to_int("RETRY_COUNT", default=0), 0)

_ARTIFACTS_PATH_KEY = "artifacts_path"
_ARTIFACTS_PATH_DEFAULT = "artifacts"

Expand Down Expand Up @@ -238,17 +238,25 @@ def _optimize(model, ipex, device):
logger.warning(warn_msg, exc_info=True)
return model

@staticmethod
def _with_retry(fn, *args, **kwargs):
retries = max(RETRY_COUNT, 0)
for count in range(1 + retries): # try once plus retries (if needed)
def _with_retry(self, fn, *args, **kwargs):
first_exception = None
for count in range(1 + self.RETRY_COUNT): # try once plus retries (if needed)
try:
return fn(*args, **kwargs)
except Exception as e: # pylint: disable=broad-exception-caught
warn_msg = f"Retry {fn} due to: {e}"
logger.warning(warn_msg, exc_info=True)
time.sleep(0.1 * (count * 2))
error.log_raise("<NLP31069292E>", RuntimeError(f"Too many retries of fn={fn}"))
if first_exception is None:
first_exception = e
if self.RETRY_COUNT > 0:
warn_msg = f"Try {count + 1}: {fn} failed due to: {e}"
logger.warning("<NLP54902271W>", warn_msg, exc_info=True)
if count + 1 < self.RETRY_COUNT:
time.sleep(0.1 * (count * 2))

# If above return did not happen, raise the first exception
error.log_raise(
log_code="<NLP13096081E>",
exception=first_exception,
)

def _encode_with_retry(self, *args, **kwargs):
"""All encode calls should use this for consistent param adding and retry loop"""
Expand Down
40 changes: 33 additions & 7 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,19 +727,45 @@ def test__with_retry_happy_path(loaded_model):
loaded_model._with_retry(print, "hello", "world", sep="<:)>", end="!!!\n")


def test__with_retry_fail(loaded_model):
"""fn never works, loops then raises RuntimeError"""
def test__with_retry_fail(loaded_model, monkeypatch):
"""fn never works, loops then raises the exception"""

def fn():
assert 0
raise (ValueError("always fails with ValueError"))

with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
loaded_model._with_retry(fn)


def test__with_retry_fail_fail_win(loaded_model):
def test__with_retry_fail_fail(loaded_model, monkeypatch):
"""fn needs a few tries, tries twice and fails."""

monkeypatch.setattr(loaded_model, "RETRY_COUNT", 1) # less than 3 tries

def generate_ints():
yield from range(9) # More than enough for retry loop

ints = generate_ints()

def fail_fail_win():
for i in ints:
if i < 2: # fail, fail
raise (ValueError(f"fail {i}"))
else: # win and return 3
return i + 1

# Without a third try raises first exception
with pytest.raises(ValueError) as e:
loaded_model._with_retry(fail_fail_win)

assert e.value.args[0] == "fail 0", "expected first exception 'fail 0'"


def test__with_retry_fail_fail_win(loaded_model, monkeypatch):
"""fn needs a few tries, logs, loops and succeeds"""

monkeypatch.setattr(loaded_model, "RETRY_COUNT", 6) # test needs at least 3 tries

def generate_ints():
yield from range(9) # More than enough for retry loop

Expand All @@ -748,8 +774,8 @@ def generate_ints():
def fail_fail_win():
for i in ints:
if i < 2: # fail, fail
assert 0
else: # win
raise (ValueError("fail, fail"))
else: # win and return 3
return i + 1

# Third try did not raise an exception. Returns 3.
Expand Down

0 comments on commit 6a749bd

Please sign in to comment.