Skip to content

Commit

Permalink
Merge pull request #45 from joerunde/fix-purge
Browse files Browse the repository at this point in the history
🐛 add backend purge
  • Loading branch information
joerunde committed Nov 30, 2023
2 parents 53ab63d + 4ca6537 commit 51e9cac
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
8 changes: 8 additions & 0 deletions caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,14 @@ def load_prompt_artifacts(self, model_id: str, prompt_id: str, *prompt_artifacts
)
conn.load_prompt_artifacts(prompt_id, *prompt_artifacts)

def unload_prompt_artifacts(self, model_id: str, *prompt_ids: str):
"""Unload all the artifacts for the prompt ids provided with base model model_id"""
conn = self.get_connection(model_id)
error.value_check(
"<TGB99822514E>", conn is not None, "Unknown model {}", model_id
)
conn.unload_prompt_artifacts(*prompt_ids)

@property
def local_tgis(self) -> bool:
return self._local_tgis
Expand Down
4 changes: 2 additions & 2 deletions caikit_tgis_backend/tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def load_prompt_artifacts(self, prompt_id: str, *artifact_paths: List[str]):
log.debug3("Copying %s -> %s", artifact_path, target_file)
shutil.copyfile(artifact_path, target_file)

def unload_prompt_artifacts(self, *prompt_ids: List[str]):
def unload_prompt_artifacts(self, *prompt_ids: str):
"""Unload the given prompts from TGIS
As implemented, this simply removes the prompt artifacts for these IDs
Expand All @@ -235,7 +235,7 @@ def unload_prompt_artifacts(self, *prompt_ids: List[str]):
accept that it's already deleted.
Args:
*prompt_ids (List[str]): The IDs to unload
*prompt_ids (str): The IDs to unload
"""
error.value_check(
"<TGB07970365E>",
Expand Down
15 changes: 15 additions & 0 deletions tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,21 @@ def test_tgis_backend_config_load_prompt_artifacts():
os.path.join(bar_prompt_dir, prompt_id2, source_fnames[1])
)

# piggy-back to test unloading prompt artifacts
tgis_be.unload_prompt_artifacts("bar", prompt_id1, prompt_id2)
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id1, source_fnames[0])
)
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id2, source_fnames[1])
)
assert not os.path.exists(
os.path.join(bar_prompt_dir, prompt_id1, source_fnames[0])
)
assert not os.path.exists(
os.path.join(bar_prompt_dir, prompt_id2, source_fnames[1])
)

# Make sure non-prompt models raise
with pytest.raises(ValueError):
tgis_be.load_prompt_artifacts("baz", prompt_id1, source_files[0])
Expand Down

0 comments on commit 51e9cac

Please sign in to comment.