diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index 9669628619..bd1fd5b5ce 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -2608,6 +2608,15 @@ def _try_compile(
 
         _check_force_database_error(stmt_ctx, stmt)
 
+        # Initialize user_schema_version with the version this query is
+        # going to be compiled against; DDL may overwrite it later.
+        try:
+            schema_version = _get_schema_version(
+                stmt_ctx.state.current_tx().get_user_schema()
+            )
+        except errors.InvalidReferenceError:
+            schema_version = None
+
         comp, capabilities = _compile_dispatch_ql(
             stmt_ctx,
             stmt,
@@ -2623,6 +2632,7 @@ def _try_compile(
             capabilities=capabilities,
             output_format=stmt_ctx.output_format,
             cache_key=ctx.cache_key,
+            user_schema_version=schema_version,
         )
 
         if not comp.is_transactional:
diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py
index 720d45828e..b627337ad1 100644
--- a/edb/server/compiler/dbstate.py
+++ b/edb/server/compiler/dbstate.py
@@ -319,7 +319,10 @@ class QueryUnit:
     # If present, represents the future schema state after
     # the command is run. The schema is pickled.
     user_schema: Optional[bytes] = None
+    # Unlike user_schema, user_schema_version usually exists, pointing at
+    # the latest user schema: self.user_schema if that changed, or else
+    # the user schema this QueryUnit was compiled against.
-    user_schema_version: Optional[uuid.UUID] = None
+    user_schema_version: uuid.UUID | None = None
     cached_reflection: Optional[bytes] = None
     extensions: Optional[set[str]] = None
     ext_config_settings: Optional[list[config.Setting]] = None
diff --git a/edb/server/protocol/execute.pxd b/edb/server/protocol/execute.pxd
new file mode 100644
index 0000000000..e4453c8ba9
--- /dev/null
+++ b/edb/server/protocol/execute.pxd
@@ -0,0 +1,27 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from edb.server.pgproto.pgproto cimport WriteBuffer
+
+
+cdef class ExecutionGroup:
+    cdef:
+        object group
+        list bind_datas
+
+    cdef append(self, object query_unit, WriteBuffer bind_data=?)
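A minimal runnable sketch (not part of the patch) of the invariant the dbstate.py comment above describes: user_schema_version starts as the schema version the unit was compiled against, and a schema-changing unit overwrites it with the new version. QueryUnit here is a local stand-in for the real dataclass, and compile_statement is hypothetical.

import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class QueryUnit:
    sql: tuple[bytes, ...]
    # Pickled future schema; only set when the unit changes the schema.
    user_schema: Optional[bytes] = None
    # Latest user schema version: the post-DDL version if user_schema is
    # set, otherwise the version the unit was compiled against.
    user_schema_version: Optional[uuid.UUID] = None


def compile_statement(current_version: uuid.UUID, is_ddl: bool) -> QueryUnit:
    unit = QueryUnit(sql=(b'SELECT 1',), user_schema_version=current_version)
    if is_ddl:
        # A DDL produces a new schema, so the version is overwritten.
        unit.user_schema = b'<pickled new schema>'
        unit.user_schema_version = uuid.uuid4()
    return unit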
diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx
index aa70059e44..11df2505fd 100644
--- a/edb/server/protocol/execute.pyx
+++ b/edb/server/protocol/execute.pyx
@@ -47,7 +47,6 @@ from edb.server.compiler cimport rpc
 from edb.server.dbview cimport dbview
 from edb.server.protocol cimport args_ser
 from edb.server.protocol cimport frontend
-from edb.server.pgproto.pgproto cimport WriteBuffer
 from edb.server.pgcon cimport pgcon
 from edb.server.pgcon import errors as pgerror
 
@@ -55,13 +54,66 @@ from edb.server.pgcon import errors as pgerror
 cdef object logger = logging.getLogger('edb.server')
 cdef object FMT_NONE = compiler.OutputFormat.NONE
+cdef WriteBuffer NO_ARGS = args_ser.combine_raw_args()
 
 
-async def persist_cache(
-    be_conn: pgcon.PGConnection,
-    dbv: dbview.DatabaseConnectionView,
+cdef class ExecutionGroup:
+    def __cinit__(self):
+        self.group = compiler.QueryUnitGroup()
+        self.bind_datas = []
+
+    cdef append(self, object query_unit, WriteBuffer bind_data=NO_ARGS):
+        self.group.append(query_unit, serialize=False)
+        self.bind_datas.append(bind_data)
+
+    async def execute(
+        self,
+        pgcon.PGConnection be_conn,
+        dbview.DatabaseConnectionView dbv,
+        fe_conn: frontend.AbstractFrontendConnection = None,
+        bytes state = None,
+    ):
+        cdef int dbver
+
+        rv = None
+        async with be_conn.parse_execute_script_context():
+            dbver = dbv.dbver
+            parse_array = [False] * len(self.group)
+            be_conn.send_query_unit_group(
+                self.group,
+                True,  # sync
+                self.bind_datas,
+                state,
+                0,  # start
+                len(self.group),  # end
+                dbver,
+                parse_array,
+            )
+            if state is not None:
+                await be_conn.wait_for_state_resp(state, state_sync=0)
+            for i, unit in enumerate(self.group):
+                if unit.output_format == FMT_NONE:
+                    for sql in unit.sql:
+                        await be_conn.wait_for_command(
+                            unit, parse_array[i], dbver, ignore_data=True
+                        )
+                    rv = None
+                else:
+                    for sql in unit.sql:
+                        rv = await be_conn.wait_for_command(
+                            unit, parse_array[i], dbver,
+                            ignore_data=False,
+                            fe_conn=fe_conn,
+                        )
+        return rv
+
+
+cdef ExecutionGroup build_cache_persistence_units(
     pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
+    ExecutionGroup group = None,
 ):
+    if group is None:
+        group = ExecutionGroup()
     insert_sql = b'''
         INSERT INTO "edgedb"."_query_cache"
         ("key", "schema_version", "input", "output", "evict")
@@ -70,9 +122,6 @@ async def persist_cache(
         "schema_version"=$2, "input"=$3, "output"=$4, "evict"=$5
     '''
     sql_hash = hashlib.sha1(insert_sql).hexdigest().encode('latin1')
-    no_args = args_ser.combine_raw_args()
-    group = compiler.QueryUnitGroup()
-    bind_datas = []
     for request, units in pairs:
         # FIXME: this is temporary; drop this assertion when we support scripts
         assert len(units) == 1
@@ -85,48 +134,33 @@ async def persist_cache(
         assert serialized_result is not None
 
         if evict:
-            group.append(
-                compiler.QueryUnit(sql=(evict,), status=b''), serialize=False
-            )
-            bind_datas.append(no_args)
+            group.append(compiler.QueryUnit(sql=(evict,), status=b''))
         if persist:
-            group.append(
-                compiler.QueryUnit(sql=(persist,), status=b''), serialize=False
-            )
-            bind_datas.append(no_args)
+            group.append(compiler.QueryUnit(sql=(persist,), status=b''))
         group.append(
             compiler.QueryUnit(
                 sql=(insert_sql,),
                 sql_hash=sql_hash,
                 status=b'',
             ),
-            serialize=False,
-        )
-        bind_datas.append(
             args_ser.combine_raw_args((
                 query_unit.cache_key.bytes,
-                dbv.schema_version.bytes,
+                query_unit.user_schema_version.bytes,
                 request.serialize(),
                 serialized_result,
                 evict,
-            ))
+            )),
         )
+    return group
+
+
+async def persist_cache(
+    be_conn: pgcon.PGConnection,
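For orientation, a hedged plain-Python analogue of the contract ExecutionGroup introduces: every appended unit carries its own bind payload (defaulting to empty args, like NO_ARGS above), and execute() drains the whole batch and returns the last unit's result. The real Cython class pipelines all units in a single send_query_unit_group round trip; this sketch awaits them one by one for clarity, and conn.run is a hypothetical connection method.

class ExecutionGroupSketch:
    def __init__(self) -> None:
        self.units: list[bytes] = []        # SQL text per unit
        self.bind_datas: list[bytes] = []   # bind payload per unit

    def append(self, sql: bytes, bind_data: bytes = b'') -> None:
        # Mirrors ExecutionGroup.append, where NO_ARGS is the default.
        self.units.append(sql)
        self.bind_datas.append(bind_data)

    async def execute(self, conn) -> object:
        # Sequential stand-in for the pipelined send + per-unit waits.
        rv = None
        for sql, args in zip(self.units, self.bind_datas):
            rv = await conn.run(sql, args)  # conn.run is hypothetical
        return rv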
+    dbv: dbview.DatabaseConnectionView,
+    pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
+):
+    cdef group = build_cache_persistence_units(pairs)
     try:
-        async with be_conn.parse_execute_script_context():
-            parse_array = [False] * len(group)
-            be_conn.send_query_unit_group(
-                group,
-                True,  # sync
-                bind_datas,
-                None,  # state
-                0,  # start
-                len(group),  # end
-                dbv.dbver,
-                parse_array,
-            )
-            for i, unit in enumerate(group):
-                await be_conn.wait_for_command(
-                    unit, parse_array[i], dbv.dbver, ignore_data=True
-                )
+        await group.execute(be_conn, dbv)
     except Exception as ex:
         if (
             isinstance(ex, pgerror.BackendError)
@@ -160,6 +194,7 @@ async def execute(
     cdef:
         bytes state = None, orig_state = None
         WriteBuffer bound_args_buf
+        ExecutionGroup group
 
     query_unit = compiled.query_unit_group[0]
 
@@ -194,13 +229,6 @@ async def execute(
         else:
             config_ops = query_unit.config_ops
 
-            if compiled.request and query_unit.cache_sql:
-                await persist_cache(
-                    be_conn,
-                    dbv,
-                    [(compiled.request, compiled.query_unit_group)],
-                )
-
             if query_unit.sql:
                 if query_unit.user_schema:
                     ddl_ret = await be_conn.run_ddl(query_unit, state)
@@ -217,14 +245,28 @@ async def execute(
                     read_data = (
                         query_unit.needs_readback or query_unit.is_explain)
 
-                    data = await be_conn.parse_execute(
-                        query=query_unit,
-                        fe_conn=fe_conn if not read_data else None,
-                        bind_data=bound_args_buf,
-                        use_prep_stmt=use_prep_stmt,
-                        state=state,
-                        dbver=dbv.dbver,
-                    )
+                    if compiled.request and query_unit.cache_sql:
+                        group = build_cache_persistence_units(
+                            [(compiled.request, compiled.query_unit_group)]
+                        )
+                        if not use_prep_stmt:
+                            query_unit.sql_hash = b''
+                        group.append(query_unit, bound_args_buf)
+                        data = await group.execute(
+                            be_conn,
+                            dbv,
+                            fe_conn=fe_conn if not read_data else None,
+                            state=state,
+                        )
+                    else:
+                        data = await be_conn.parse_execute(
+                            query=query_unit,
+                            fe_conn=fe_conn if not read_data else None,
+                            bind_data=bound_args_buf,
+                            use_prep_stmt=use_prep_stmt,
+                            state=state,
+                            dbver=dbv.dbver,
+                        )
 
         if query_unit.needs_readback and data:
             config_ops = [
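And a hedged sketch of the control flow execute() gains here, reusing ExecutionGroupSketch from the previous sketch: when a compiled unit carries cache_sql, the cache-persistence units and the query itself are batched into one group, so the cache write shares the round trip with the query instead of taking a separate one. All names below are stand-ins for the real server types.

async def run_compiled(conn, unit, request, bound_args: bytes):
    if request is not None and unit.cache_sql:
        group = ExecutionGroupSketch()
        for sql in unit.cache_sql:              # evict / persist units
            group.append(sql)
        group.append(unit.sql[0], bound_args)   # the query rides along last
        return await group.execute(conn)
    # No cache to persist: plain single-query round trip.
    return await conn.run(unit.sql[0], bound_args)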