Skip to content

Commit

Permalink
Merge pull request #40 from dtrifiro/add-test-connection-timeout
Browse files Browse the repository at this point in the history
Add test connection timeout
  • Loading branch information
gabe-l-hart committed Nov 3, 2023
2 parents b8ecf47 + 7a9a346 commit 7b37304
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
36 changes: 35 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,37 @@
## Caikit Text Generation Inference Service (TGIS) Backend
# Caikit Text Generation Inference Service (TGIS) Backend

This project provides a Caikit module backend that manages models run in TGIS

## Configuration

Sample configuration using the `MULTI` finder and a remote `TGIS` backend):

```yaml
runtime:
library: caikit_nlp
local_models_dir: /path/to/models
lazy_load_local_models: true

model_management:
finders:
default:
type: MULTI
config:
finder_priority:
- local
- tgis-auto
initializers:
default:
type: LOCAL
config:
backend_priority:
- type: TGIS
config:
connection:
hostname: "localhost:8033"
test_connections: true
connect_timeout: 30

log:
formatter: pretty
```
3 changes: 2 additions & 1 deletion caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, config: Optional[dict] = None):
self._managed_tgis = None
self._model_connections = {}
self._test_connections = self.config.get("test_connections", False)
self._connect_timeout = self.config.get("connect_timeout", None)

# Parse the config to see if we're managing a connection to a remote
# TGIS instance or running a local copy
Expand Down Expand Up @@ -108,7 +109,7 @@ def __init__(self, config: Optional[dict] = None):
)
if self._test_connections:
try:
model_conn.test_connection()
model_conn.test_connection(timeout=self._connect_timeout)
except grpc.RpcError as err:
log.warning(
"<TGB95244222W>",
Expand Down
7 changes: 5 additions & 2 deletions caikit_tgis_backend/tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,15 @@ def get_client(self) -> generation_pb2_grpc.GenerationServiceStub:
self._client = load_balancer.client
return self._client

def test_connection(self):
def test_connection(self, timeout: Optional[float] = None):
"""Test whether the connection is valid. If not valid, an appropriate
grpc.RpcError will be raised
"""
client = self.get_client()
client.ModelInfo(generation_pb2.ModelInfoRequest(model_id=self.model_id))
client.ModelInfo(
generation_pb2.ModelInfoRequest(model_id=self.model_id),
timeout=timeout,
)

@staticmethod
def _load_tls_file(file_path: Optional[str]) -> Optional[bytes]:
Expand Down
1 change: 1 addition & 0 deletions tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,4 @@ def test_tgis_backend_conn_testing_enabled(tgis_mock_insecure):
assert tgis_be.is_started
conn = tgis_be.get_connection(model_id)
conn.test_connection()
conn.test_connection(timeout=1)

0 comments on commit 7b37304

Please sign in to comment.