Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Implementation of batch ddl in dbapi #1092

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from enum import Enum
from typing import TYPE_CHECKING, List

from google.cloud.spanner_dbapi import parse_utils
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Expand All @@ -28,6 +30,48 @@

if TYPE_CHECKING:
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.connection import Connection


class BatchDdlExecutor:
    """Executor that is used when a DDL batch is started. These batches only
    accept DDL statements. All DDL statements are buffered locally and sent to
    Spanner when run_batch() is called.

    :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection`
    :param connection: The DB-API connection on whose database the buffered
        DDL statements will eventually be executed.
    """

    def __init__(self, connection: "Connection"):
        self._connection = connection
        # DDL statements buffered until run_batch() flushes them to Spanner.
        self._statements: List[str] = []

    def execute_statement(self, parsed_statement: ParsedStatement):
        """Executes the statement when ddl batch is active by buffering the
        statement in-memory.

        :type parsed_statement: ParsedStatement
        :param parsed_statement: parsed statement containing sql query

        :raises ProgrammingError: if the statement is not a DDL statement.
        """
        # Imported lazily to avoid a circular import at module load time.
        from google.cloud.spanner_dbapi import ProgrammingError

        if parsed_statement.statement_type != StatementType.DDL:
            raise ProgrammingError("Only DDL statements are allowed in batch DDL mode.")
        # A single sql string may contain several semicolon-separated DDL
        # statements; each one is buffered individually.
        self._statements.extend(
            parse_utils.parse_and_get_ddl_statements(parsed_statement.statement.sql)
        )

    def run_batch(self):
        """Executes all the buffered statements on the active ddl batch by
        making a call to Spanner.

        :raises ProgrammingError: if a client-side transaction is already
            active, since DDL cannot be executed inside a transaction.
        """
        from google.cloud.spanner_dbapi import ProgrammingError

        if self._connection._client_transaction_started:
            raise ProgrammingError(
                "Cannot execute DDL statement when transaction is already active."
            )
        # Blocks until the schema-update long-running operation completes.
        return self._connection.database.update_ddl(self._statements).result()


class BatchDmlExecutor:
Expand All @@ -52,6 +96,7 @@ def execute_statement(self, parsed_statement: ParsedStatement):
:param parsed_statement: parsed statement containing sql query and query
params
"""

from google.cloud.spanner_dbapi import ProgrammingError

if (
Expand All @@ -61,7 +106,7 @@ def execute_statement(self, parsed_statement: ParsedStatement):
raise ProgrammingError("Only DML statements are allowed in batch DML mode.")
self._statements.append(parsed_statement.statement)

def run_batch_dml(self):
def run_batch(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this a breaking change?

We can probably get away with it now, but this is something that we need to be more careful with, and we should try to mark whatever is not for public consumption as such as much as possible.

Either by adding a _ at the start of the method name, or by explicitly adding documentation that the method is not for public use, and can change at any moment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. Added the following documentation on the method This method is internal and not for public use as it can change anytime.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of renaming the method, we could mark it as deprecated and create a new one with the _ prefix which indicates that it is internal. Please avoid a major version bump if at all possible as it is disruptive for downstream users.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the reason because of which we are proposing a major version bump. Please check https://docs.google.com/document/d/1mSVC_AhPpdSA0xUoj4LUt2o2tVM_LDv6IdssiwWnbmw for details

"""Executes all the buffered statements on the active dml batch by
making a call to Spanner.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
TypeCode.TIMESTAMP,
column_values,
)
if statement_type == ClientSideStatementType.START_BATCH_DDL:
connection.start_batch_ddl()
return None
if statement_type == ClientSideStatementType.START_BATCH_DML:
connection.start_batch_dml(cursor)
return None
Expand Down
3 changes: 3 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
)
RE_START_BATCH_DDL = re.compile(r"^\s*(START)\s+(BATCH)\s+(DDL)", re.IGNORECASE)
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
Expand Down Expand Up @@ -62,6 +63,8 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
elif RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
elif RE_START_BATCH_DDL.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DDL
elif RE_START_BATCH_DML.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
elif RE_BEGIN.match(query):
Expand Down
67 changes: 61 additions & 6 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi import partition_helper
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.batch_executor import (
BatchMode,
BatchDmlExecutor,
BatchDdlExecutor,
)
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
from google.cloud.spanner_dbapi.parsed_statement import (
StatementType,
Expand Down Expand Up @@ -91,7 +95,9 @@ class Connection:
should end and that a new one should be started when the next statement is executed.
"""

def __init__(self, instance, database=None, read_only=False):
def __init__(
self, instance, database=None, read_only=False, buffer_ddl_statements=False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

buffer_ddl_statements=False means that we are changing the default behavior in a significant way. That requires a major version bump, and should be clearly called out in the release notes. The description of the PR will automatically be included in the release notes, so please add information about this change there.

Also, in order to force a major version bump, you can add a ! to the title like this: feat!: Support explicit DDL batching

(Also note that the title of the PR will be included in the release notes, and as such is intended for public consumption. We should therefore try to keep that title as descriptive as possible for external users so they understand what the change is.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added description and changed title to force a major version bump

):
self._instance = instance
self._database = database
self._ddl_statements = []
Expand All @@ -114,8 +120,10 @@ def __init__(self, instance, database=None, read_only=False):
# made at least one call to Spanner.
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_ddl_executor: BatchDdlExecutor = None
self._batch_dml_executor: BatchDmlExecutor = None
self._transaction_helper = TransactionRetryHelper(self)
self._buffer_ddl_statements = buffer_ddl_statements

@property
def autocommit(self):
Expand All @@ -126,6 +134,15 @@ def autocommit(self):
"""
return self._autocommit

@property
def buffer_ddl_statements(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a read-only property? As in; can it only be set when a connection is created, and never changed after that? If so, is that something that we can/should document? Or is it clear from the context?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes its a read only property and there are no setters for it but in python the private variable can also be overridden.

"""Whether to buffer ddl statements for this connection.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs more documentation. I (think I) understand what it means. To someone who is not familiar with Cloud Spanner, it is not clear what this means. It is also not clear when they should enable/disable this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


:rtype: bool
:returns: _buffer_ddl_statements flag value.
"""
return self._buffer_ddl_statements

@autocommit.setter
def autocommit(self, value):
"""Change this connection autocommit mode. Setting this value to True
Expand Down Expand Up @@ -365,7 +382,8 @@ def commit(self):
)
return

self.run_prior_DDL_statements()
if self.buffer_ddl_statements:
self.run_prior_DDL_statements()
try:
if self._spanner_transaction_started and not self._read_only:
self._transaction.commit()
Expand Down Expand Up @@ -463,6 +481,31 @@ def validate(self):
"Expected: [[1]]" % result
)

@check_not_closed
def start_batch_ddl(self):
    """Starts a DDL batch on this connection; subsequent DDL statements are
    buffered by a BatchDdlExecutor until the batch is run or aborted.

    :raises ProgrammingError: if another batch is already active, if the
        connection is read-only, or if the connection is configured to
        buffer DDL statements itself (buffer_ddl_statements is True).
    """
    if self._batch_mode is not BatchMode.NONE:
        raise ProgrammingError(
            "Cannot start a DDL batch when a batch is already active"
        )
    if self.read_only:
        raise ProgrammingError(
            "Cannot start a DDL batch when the connection is in read-only mode"
        )
    if self.buffer_ddl_statements:
        raise ProgrammingError(
            "Cannot start a DDL batch when _buffer_ddl_statements flag is True"
        )
    self._batch_mode = BatchMode.DDL
    self._batch_ddl_executor = BatchDdlExecutor(self)

@check_not_closed
def execute_batch_ddl_statement(self, parsed_statement: ParsedStatement):
    """Buffers the given DDL statement on the currently active DDL batch.

    This method is internal and not for public use as it can change anytime.

    :type parsed_statement: ParsedStatement
    :param parsed_statement: parsed DDL statement to buffer.

    :raises ProgrammingError: if no DDL batch is active on this connection.
    """
    if self._batch_mode is not BatchMode.DDL:
        raise ProgrammingError(
            "Cannot execute statement when the BatchMode is not DDL"
        )
    self._batch_ddl_executor.execute_statement(parsed_statement)

@check_not_closed
def start_batch_dml(self, cursor):
if self._batch_mode is not BatchMode.NONE:
Expand All @@ -486,22 +529,28 @@ def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):

@check_not_closed
def run_batch(self):
    """Runs the currently active batch (DML or DDL) against Spanner.

    The batch state is always cleared, even when execution raises.

    :returns: the result set of the DML batch, or None for a DDL batch.
    :raises ProgrammingError: if no batch is active on this connection.
    """
    result_set = None
    if self._batch_mode is BatchMode.NONE:
        raise ProgrammingError("Cannot run a batch when the BatchMode is not set")
    try:
        if self._batch_mode is BatchMode.DML:
            result_set = self._batch_dml_executor.run_batch()
        elif self._batch_mode is BatchMode.DDL:
            self._batch_ddl_executor.run_batch()
    finally:
        # Reset batch state regardless of success or failure.
        self._batch_mode = BatchMode.NONE
        self._batch_dml_executor = None
        self._batch_ddl_executor = None
    return result_set

@check_not_closed
def abort_batch(self):
    """Aborts the currently active batch, discarding any buffered statements.

    :raises ProgrammingError: if no batch is active on this connection.
    """
    mode = self._batch_mode
    if mode is BatchMode.NONE:
        raise ProgrammingError("Cannot abort a batch when the BatchMode is not set")
    # Drop whichever executor belongs to the active batch kind.
    if mode is BatchMode.DML:
        self._batch_dml_executor = None
    elif mode is BatchMode.DDL:
        self._batch_ddl_executor = None
    self._batch_mode = BatchMode.NONE

@check_not_closed
Expand Down Expand Up @@ -584,10 +633,14 @@ def connect(
pool=None,
user_agent=None,
client=None,
buffer_ddl_statements=False,
route_to_leader_enabled=True,
):
"""Creates a connection to a Google Cloud Spanner database.

:type buffer_ddl_statements: bool
:param buffer_ddl_statements:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a description of what it does, and when you should enable/disable it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added


:type instance_id: str
:param instance_id: The ID of the instance to connect to.

Expand Down Expand Up @@ -658,7 +711,9 @@ def connect(

instance = client.instance(instance_id)
conn = Connection(
instance, instance.database(database_id, pool=pool) if database_id else None
instance,
instance.database(database_id, pool=pool) if database_id else None,
buffer_ddl_statements=buffer_ddl_statements,
)
if pool is not None:
conn._own_pool = False
Expand Down
50 changes: 24 additions & 26 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
"""Database cursor for Google Cloud Spanner DB API."""
from collections import namedtuple

import sqlparse

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import AlreadyExists
from google.api_core.exceptions import FailedPrecondition
Expand All @@ -25,7 +23,7 @@
from google.api_core.exceptions import OutOfRange

from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
from google.cloud.spanner_dbapi.batch_executor import BatchMode
from google.cloud.spanner_dbapi.exceptions import IntegrityError
from google.cloud.spanner_dbapi.exceptions import InterfaceError
from google.cloud.spanner_dbapi.exceptions import OperationalError
Expand All @@ -34,7 +32,7 @@
from google.cloud.spanner_dbapi import (
_helpers,
client_side_statement_executor,
batch_dml_executor,
batch_executor,
)
from google.cloud.spanner_dbapi._helpers import ColumnInfo
from google.cloud.spanner_dbapi._helpers import CODE_TO_DISPLAY_SIZE
Expand Down Expand Up @@ -210,18 +208,8 @@ def _batch_DDLs(self, sql):
:raises: :class:`ValueError` in case not a DDL statement
present in the operation.
"""
statements = []
for ddl in sqlparse.split(sql):
if ddl:
ddl = ddl.rstrip(";")
if (
parse_utils.classify_statement(ddl).statement_type
!= StatementType.DDL
):
raise ValueError("Only DDL statements may be batched.")

statements.append(ddl)

statements = parse_utils.parse_and_get_ddl_statements(sql)
# Only queue DDL statements if they are all correctly classified.
self.connection._ddl_statements.extend(statements)

Expand Down Expand Up @@ -261,6 +249,8 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
self._itr = self._result_set
else:
self._itr = PeekIterator(self._result_set)
elif self.connection._batch_mode == BatchMode.DDL:
self.connection.execute_batch_ddl_statement(self._parsed_statement)
elif self.connection._batch_mode == BatchMode.DML:
self.connection.execute_batch_dml_statement(self._parsed_statement)
elif self.connection.read_only or (
Expand All @@ -269,9 +259,18 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
):
self._handle_DQL(sql, args or None)
elif self._parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if not self.connection._client_transaction_started:
self.connection.run_prior_DDL_statements()
if not self.connection.buffer_ddl_statements:
if not self.connection._client_transaction_started:
self._batch_DDLs(sql)
self.connection.run_prior_DDL_statements()
else:
raise ProgrammingError(
olavloite marked this conversation as resolved.
Show resolved Hide resolved
"Cannot execute DDL statement when transaction is already active"
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
)
else:
self._batch_DDLs(sql)
if not self.connection._client_transaction_started:
self.connection.run_prior_DDL_statements()
else:
self._execute_in_rw_transaction()

Expand All @@ -296,9 +295,8 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
self.connection._spanner_transaction_started = False

def _execute_in_rw_transaction(self):
# For every other operation, we've got to ensure that
# any prior DDL statements were run.
self.connection.run_prior_DDL_statements()
if self.connection.buffer_ddl_statements:
self.connection.run_prior_DDL_statements()
statement = self._parsed_statement.statement
if self.connection._client_transaction_started:
while True:
Expand Down Expand Up @@ -347,9 +345,8 @@ def executemany(self, operation, seq_of_params):
+ ", with executemany() method is not allowed."
)

# For every operation, we've got to ensure that any prior DDL
# statements were run.
self.connection.run_prior_DDL_statements()
if self.connection.buffer_ddl_statements:
self.connection.run_prior_DDL_statements()
if self._parsed_statement.statement_type in (
StatementType.INSERT,
StatementType.UPDATE,
Expand All @@ -360,7 +357,7 @@ def executemany(self, operation, seq_of_params):
operation, params
)
statements.append(Statement(sql, params, get_param_types(params)))
many_result_set = batch_dml_executor.run_batch_dml(self, statements)
many_result_set = batch_executor.run_batch_dml(self, statements)
else:
many_result_set = StreamedManyResultSets()
for params in seq_of_params:
Expand Down Expand Up @@ -523,7 +520,8 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None):
# hence this method exists to circumvent that limit.
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
self.connection.run_prior_DDL_statements()
if self.connection.buffer_ddl_statements:
self.connection.run_prior_DDL_statements()

with self.connection.database.snapshot() as snapshot:
return list(snapshot.execute_sql(sql, params, param_types))
Expand Down
11 changes: 11 additions & 0 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,17 @@ def classify_stmt(query):
return STMT_UPDATING


def parse_and_get_ddl_statements(sql):
    """Splits ``sql`` into its individual DDL statements.

    Each statement has any trailing semicolon stripped. If any non-empty
    statement is not classified as DDL, the whole input is rejected, so
    either all statements are returned or none are.

    :type sql: str
    :param sql: one or more semicolon-separated DDL statements.

    :rtype: list
    :returns: the individual DDL statements contained in ``sql``.

    :raises ValueError: if any statement is not a DDL statement.
    """
    statements = []
    for ddl in sqlparse.split(sql):
        if ddl:
            ddl = ddl.rstrip(";")
            if classify_statement(ddl).statement_type != StatementType.DDL:
                raise ValueError("Only DDL statements may be batched.")
            statements.append(ddl)
    return statements


def classify_statement(query, args=None):
"""Determine SQL query type.

Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ClientSideStatementType(Enum):
PARTITION_QUERY = 9
RUN_PARTITION = 10
RUN_PARTITIONED_QUERY = 11
START_BATCH_DDL = 12


@dataclass
Expand Down