Skip to content

Commit

Permalink
feat: pass custom Client object to dbapi (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
asthamohta committed Mar 28, 2023
1 parent 520d6d7 commit 52b1a0a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 13 deletions.
33 changes: 20 additions & 13 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def connect(
credentials=None,
pool=None,
user_agent=None,
client=None,
):
"""Creates a connection to a Google Cloud Spanner database.
Expand Down Expand Up @@ -529,25 +530,31 @@ def connect(
:param user_agent: (Optional) User agent to be used with this connection's
requests.
:type client: Concrete subclass of
:class:`~google.cloud.spanner_v1.Client`.
:param client: (Optional) Custom user provided Client Object
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.
"""

client_info = ClientInfo(
user_agent=user_agent or DEFAULT_USER_AGENT,
python_version=PY_VERSION,
client_library_version=spanner.__version__,
)

if isinstance(credentials, str):
client = spanner.Client.from_service_account_json(
credentials, project=project, client_info=client_info
if client is None:
client_info = ClientInfo(
user_agent=user_agent or DEFAULT_USER_AGENT,
python_version=PY_VERSION,
client_library_version=spanner.__version__,
)
if isinstance(credentials, str):
client = spanner.Client.from_service_account_json(
credentials, project=project, client_info=client_info
)
else:
client = spanner.Client(
project=project, credentials=credentials, client_info=client_info
)
else:
client = spanner.Client(
project=project, credentials=credentials, client_info=client_info
)
if project is not None and client.project != project:
raise ValueError("project in url does not match client object project")

instance = client.instance(instance_id)
conn = Connection(instance, instance.database(database_id, pool=pool))
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import mock
import unittest
import warnings
import pytest

PROJECT = "test-project"
INSTANCE = "test-instance"
Expand Down Expand Up @@ -915,7 +916,52 @@ def test_request_priority(self):
sql, params, param_types=param_types, request_options=None
)

@mock.patch("google.cloud.spanner_v1.Client")
def test_custom_client_connection(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = _Client()
connection = connect("test-instance", "test-database", client=client)
self.assertTrue(connection.instance._client == client)

@mock.patch("google.cloud.spanner_v1.Client")
def test_invalid_custom_client_connection(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = _Client()
with pytest.raises(ValueError):
connect(
"test-instance",
"test-database",
project="invalid_project",
client=client,
)


def exit_ctx_func(self, exc_type, exc_value, traceback):
"""Context __exit__ method mock."""
pass


class _Client(object):
def __init__(self, project="project_id"):
self.project = project
self.project_name = "projects/" + self.project

def instance(self, instance_id="instance_id"):
return _Instance(name=instance_id, client=self)


class _Instance(object):
def __init__(self, name="instance_id", client=None):
self.name = name
self._client = client

def database(self, database_id="database_id", pool=None):
return _Database(database_id, pool)


class _Database(object):
def __init__(self, database_id="database_id", pool=None):
self.name = database_id
self.pool = pool

0 comments on commit 52b1a0a

Please sign in to comment.