Skip to content

Commit

Permalink
Merge pull request #48 from gabe-l-hart/PersistRemoteModelMapping
Browse files Browse the repository at this point in the history
Persist remote model mapping
  • Loading branch information
gabe-l-hart committed Nov 30, 2023
2 parents 662b8f5 + 6ee48c7 commit a91c3e9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 14 deletions.
18 changes: 7 additions & 11 deletions caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def __init__(self, config: Optional[dict] = None):
# TGIS instance or running a local copy
connection_cfg = self.config.get("connection") or {}
error.type_check("<TGB20235229E>", dict, connection=connection_cfg)
remote_models_cfg = self.config.get("remote_models") or {}
error.type_check("<TGB20235338E>", dict, connection=remote_models_cfg)
self._remote_models_cfg = self.config.get("remote_models") or {}
error.type_check("<TGB20235338E>", dict, connection=self._remote_models_cfg)
local_cfg = self.config.get("local") or {}
error.type_check("<TGB20235225E>", dict, local=local_cfg)

Expand All @@ -99,7 +99,7 @@ def __init__(self, config: Optional[dict] = None):
)

# Parse connection objects for all model-specific connections
for model_id, model_conn_cfg in remote_models_cfg.items():
for model_id, model_conn_cfg in self._remote_models_cfg.items():
model_conn = TGISConnection.from_config(model_id, model_conn_cfg)
error.value_check(
"<TGB90377847E>",
Expand All @@ -124,7 +124,7 @@ def __init__(self, config: Optional[dict] = None):

# We manage a local TGIS instance if there are no remote connections
# specified as either a valid base connection or remote_connections
self._local_tgis = not self._base_connection_cfg and not self._model_connections
self._local_tgis = not self._base_connection_cfg and not self._remote_models_cfg
log.info("Running %s TGIS backend", "LOCAL" if self._local_tgis else "REMOTE")

if self._local_tgis:
Expand Down Expand Up @@ -172,13 +172,9 @@ def get_connection(
) -> Optional[TGISConnection]:
"""Get the TGISConnection object for the given model"""
model_conn = self._model_connections.get(model_id)
if (
not model_conn
and create
and not self.local_tgis
and self._base_connection_cfg
):
model_conn = TGISConnection.from_config(model_id, self._base_connection_cfg)
conn_cfg = self._remote_models_cfg.get(model_id, self._base_connection_cfg)
if not model_conn and create and not self.local_tgis and conn_cfg:
model_conn = TGISConnection.from_config(model_id, conn_cfg)
if self._test_connections:
try:
model_conn.test_connection()
Expand Down
31 changes: 31 additions & 0 deletions tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# Third Party
import grpc
import pytest
import tls_test_tools

# First Party
import caikit
Expand Down Expand Up @@ -245,6 +246,36 @@ def test_stop():
assert not tgis_be.get_connection(model_id2, False)


def test_lazy_start_remote_model():
"""Make sure that an entry in the remote_models config can be respected even
if it is invalid at instantiation and test_connections is enabled
"""
# Set up a TGIS connection that will not be valid yet
port = tls_test_tools.open_port()
model_name = "some-model"
cfg = {
"remote_models": {model_name: {"hostname": f"localhost:{port}"}},
"test_connections": True,
}

# Initialize the backend and make sure getting the model's connection
# returns None
tgis_be = TGISBackend(cfg)
assert tgis_be.get_connection(model_name) is None

# Now boot the TGIS instance and try again
with TGISMock(grpc_port=port):
max_time = 5
start_time = time.time()
conn = None
while time.time() - start_time < max_time:
conn = tgis_be.get_connection(model_name)
if conn:
break
time.sleep(0.1)
assert conn


## Local Subprocess ############################################################


Expand Down
9 changes: 6 additions & 3 deletions tests/tgis_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,20 @@ def __init__(
prompt_responses: Optional[Dict[str, str]] = None,
health_delay: float = 0.0,
san_list: List[str] = None,
grpc_port: Optional[int] = None,
http_port: Optional[int] = None,
):
self.prompt_responses = prompt_responses
self.health_delay = health_delay

# find ports for the servers
self.grpc_port = tls_test_tools.open_port()
self.grpc_port = grpc_port or tls_test_tools.open_port()
self.hostname = f"localhost:{self.grpc_port}"

self.http_port = tls_test_tools.open_port()
self.http_port = http_port or tls_test_tools.open_port()

# generate TLS certificates
self.tls = tls
self.mtls = mtls
if tls or mtls:
self.ca_key = tls_test_tools.generate_key()[0]
self.ca_cert = tls_test_tools.generate_ca_cert(self.ca_key)
Expand Down

0 comments on commit a91c3e9

Please sign in to comment.