Skip to content

Commit

Permalink
Merge pull request #53 from mynhardtburger/51_Do-not-overwrite-prompt…
Browse files Browse the repository at this point in the history
…-artifacts

Do not overwrite prompt artifacts
  • Loading branch information
evaline-ju committed Feb 14, 2024
2 parents f7b8a76 + 26e4f27 commit 43f8174
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 10 deletions.
55 changes: 46 additions & 9 deletions caikit_tgis_backend/tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
# limitations under the License.
"""Encapsulate the creation of a TGIS Connection"""

# Future
from __future__ import annotations

# Standard
from collections.abc import Container
from dataclasses import dataclass
from typing import List, Optional
from pathlib import Path
from typing import TYPE_CHECKING, Optional
import os
import shutil

Expand All @@ -30,6 +35,10 @@
from .load_balancing_proxy import GRPCLoadBalancerProxy
from .protobufs import generation_pb2, generation_pb2_grpc

if TYPE_CHECKING:
# Third Party
from _typeshed import StrPath

log = alog.use_channel("TGCONN")
error = error_handler.get(log)

Expand All @@ -43,7 +52,6 @@ class TLSFilePair:
# pylint: disable=too-many-instance-attributes
@dataclass
class TGISConnection:

#################
# Class members #
#################
Expand All @@ -57,7 +65,7 @@ class TGISConnection:
# Paths to client key/cert pair when TGIS requires mTLS
client_tls: Optional[TLSFilePair] = None
# TLS HN override
tls_hostname_override: str = None
tls_hostname_override: Optional[str] = None
# Mounted directory where TGIS will look for prompt vector artifacts
prompt_dir: Optional[str] = None
# Load balancing policy
Expand Down Expand Up @@ -196,7 +204,7 @@ def tls_enabled(self) -> bool:
def mtls_enabled(self) -> bool:
return None not in [self.ca_cert_file, self.client_tls]

def load_prompt_artifacts(self, prompt_id: str, *artifact_paths: List[str]):
def load_prompt_artifacts(self, prompt_id: str, *artifact_paths: str):
"""Load the given artifact paths to this TGIS connection
As implemented, this is a simple copy to the TGIS instance's prompt dir,
Expand All @@ -208,7 +216,7 @@ def load_prompt_artifacts(self, prompt_id: str, *artifact_paths: List[str]):
Args:
prompt_id (str): The ID that this prompt should use
*artifact_paths (List[str]): The paths to the artifacts to laod
*artifact_paths (Tuple[str]): The paths to the artifacts to load
"""
error.value_check(
"<TGB07970356E>",
Expand All @@ -221,13 +229,30 @@ def load_prompt_artifacts(self, prompt_id: str, *artifact_paths: List[str]):
str,
artifact_paths=artifact_paths,
)
target_dir = os.path.join(self.prompt_dir, prompt_id)

target_dir = Path(self.prompt_dir) / prompt_id
os.makedirs(target_dir, exist_ok=True)
for artifact_path in artifact_paths:

# Don't copy files which are already in the target_dir
existing_artifact_names = {f.name for f in target_dir.iterdir()}
new_artifacts = {
Path(f)
for f in artifact_paths
if file_or_swp_not_in_listing(Path(f).name, existing_artifact_names)
}

for artifact_path in new_artifacts:
error.file_check("<TGB14818050E>", artifact_path)
target_file = os.path.join(target_dir, os.path.basename(artifact_path))
target_file = target_dir / artifact_path.name
swp_file = target_file.with_name(target_file.name + ".swp")

# Copy file as a swap file
log.debug3("Copying %s -> %s", artifact_path, target_file)
shutil.copyfile(artifact_path, target_file)
shutil.copyfile(artifact_path, swp_file)

# Rename on completion of copy using replace
# Replace silently overrides the destination irrespective of OS
os.replace(swp_file, target_file)

def unload_prompt_artifacts(self, *prompt_ids: str):
"""Unload the given prompts from TGIS
Expand Down Expand Up @@ -344,3 +369,15 @@ def _load_tls_file(file_path: Optional[str]) -> Optional[bytes]:
)
with open(file_path, "rb") as handle:
return handle.read()


def file_or_swp_not_in_listing(
    filename: "StrPath", file_listing: Container[str], swap_extension: str = ".swp"
) -> bool:
    """Determine that neither the file nor a swap variant of it is in the listing.

    Only the final path component of ``filename`` is compared against the
    listing. Two swap spellings are recognised:

    * ``<name><swap_extension>`` (e.g. ``bar.pt.swp``): the extension is
      appended to the full file name. This is the form produced by
      ``load_prompt_artifacts`` while a copy is in flight.
    * ``<stem><swap_extension>`` (e.g. ``bar.swp``): the extension replaces
      the file's own suffix. Kept for backward compatibility.

    Args:
        filename (StrPath): Path or bare name of the file to look for
        file_listing (Container[str]): File names to check against
        swap_extension (str): Extension that marks an in-progress copy. Must
            start with a "." (required by ``Path.with_suffix``)

    Returns:
        bool: True if neither the file name nor any of its swap variants is
            in ``file_listing``; False otherwise
    """
    name = Path(filename).name

    # The file itself, or the appended-extension swap file, already exists
    if name in file_listing or name + swap_extension in file_listing:
        return False

    # Legacy swap spelling where the extension replaces the original suffix
    return Path(name).with_suffix(swap_extension).name not in file_listing
89 changes: 88 additions & 1 deletion tests/test_tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
# Standard
from contextlib import contextmanager
from pathlib import Path
import os
import tempfile

Expand Down Expand Up @@ -142,7 +143,6 @@ def test_load_prompt_artifacts_bad_source_file():
"""
with tempfile.TemporaryDirectory() as source_dir:
with tempfile.TemporaryDirectory() as prompt_dir:

# Make the connection with the prompt dir
conn = TGISConnection.from_config(
"",
Expand Down Expand Up @@ -184,6 +184,93 @@ def test_load_prompt_artifacts_no_prompt_dir():
conn.load_prompt_artifacts(prompt_id, *source_files)


def tests_load_prompt_artifacts_dont_copy_existing_files():
    """Make sure that files already present in the prompt dir are not copied again"""
    with tempfile.TemporaryDirectory() as source_dir:
        with tempfile.TemporaryDirectory() as prompt_dir:
            prompt_id = "some-prompt-id"
            artifact_names = ["foo.pt", "bar.pt"]

            # Source artifacts hold the "new" content
            source_files = [
                os.path.join(source_dir, name) for name in artifact_names
            ]
            for source_file in source_files:
                with open(source_file, "w", encoding="utf8") as handle:
                    handle.write("new stub")

            # Pre-populate prompt_dir / prompt_id with same-named files that
            # hold the "old" content
            existing_dir = os.path.join(prompt_dir, prompt_id)
            os.mkdir(existing_dir)
            prompt_files = [
                os.path.join(existing_dir, name) for name in artifact_names
            ]
            for prompt_file in prompt_files:
                with open(prompt_file, "w", encoding="utf8") as handle:
                    handle.write("old stub")

            # Make the connection with the prompt dir
            config = {
                TGISConnection.HOSTNAME_KEY: "foo.bar:1234",
                TGISConnection.PROMPT_DIR_KEY: prompt_dir,
            }
            conn = TGISConnection.from_config("", config)

            # Ask for the artifacts to be loaded
            conn.load_prompt_artifacts(prompt_id, *source_files)

            # The pre-existing files must be left untouched
            for prompt_file in prompt_files:
                assert os.path.exists(prompt_file)
                with open(prompt_file, "r", encoding="utf8") as handle:
                    assert handle.read() == "old stub"


def tests_load_prompt_artifacts_exclude_swp_files():
    """Make sure that a swp file in the prompt dir causes the matching source
    file to be excluded. This assumes another process is busy copying that file.
    """
    with tempfile.TemporaryDirectory() as source_dir:
        with tempfile.TemporaryDirectory() as prompt_dir:
            prompt_id = "some-prompt-id"

            # Source artifacts hold the "new" content
            source_files = [
                os.path.join(source_dir, name) for name in ("foo.pt", "bar.pt")
            ]
            for source_file in source_files:
                with open(source_file, "w", encoding="utf8") as handle:
                    handle.write("new stub")

            # Simulate an in-progress copy of bar.pt by dropping a swap file
            # into prompt_dir / prompt_id
            target_dir = Path(prompt_dir) / prompt_id
            os.mkdir(target_dir)
            with open(target_dir / "bar.swp", "w", encoding="utf8") as handle:
                handle.write("in progress")

            # Make the connection with the prompt dir
            config = {
                TGISConnection.HOSTNAME_KEY: "foo.bar:1234",
                TGISConnection.PROMPT_DIR_KEY: prompt_dir,
            }
            conn = TGISConnection.from_config("", config)

            # Ask for the artifacts to be loaded
            conn.load_prompt_artifacts(prompt_id, *source_files)

            # Only foo.pt may have been copied; the swap file stays as it was
            expected_prompt_files = {"foo.pt", "bar.swp"}
            found_names = {entry.name for entry in target_dir.iterdir()}
            assert (
                found_names == expected_prompt_files
            ), "Incorrect files were copied"

            copied_text = (target_dir / "foo.pt").read_text(encoding="utf8")
            assert copied_text == "new stub", "File was not copied to prompt_dir"

            swap_text = (target_dir / "bar.swp").read_text(encoding="utf8")
            assert swap_text == "in progress", "Swap file should not be overwritten"


def test_unload_prompt_artifacts_ok():
"""Make sure that prompt artifacts can be unloaded cleanly"""
with tempfile.TemporaryDirectory() as source_dir:
Expand Down

0 comments on commit 43f8174

Please sign in to comment.