Persist cache in the same transaction (#6947)
* Include schema_version in each QueryUnit
* Extract ExecutionGroup
* Persist cache in the same transaction
fantix committed Mar 1, 2024
1 parent 44c9c6d commit 6cabfa9
Showing 4 changed files with 134 additions and 52 deletions.
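The gist of the change: previously, execute() awaited persist_cache() in its own round trip before running the query's SQL; now the cache upsert is queued into the same ExecutionGroup as the query itself, so both travel in one pipelined script and commit or roll back together. A minimal standalone sketch of that idea, with sqlite3 standing in for the Postgres backend (the table, names, and helper are illustrative, not EdgeDB's actual API):

```python
# Standalone sketch of "persist cache in the same transaction"; sqlite3
# is a stand-in backend, and all names here are illustrative only.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE query_cache (key TEXT PRIMARY KEY, output BLOB);
    CREATE TABLE users (name TEXT);
""")

def run_with_cache_persistence(conn, cache_key, sql, args):
    # One transaction: the cache upsert and the user's query commit or
    # roll back together, so the cache can never describe work that was
    # rolled back.
    with conn:  # sqlite3 connection as context manager == one transaction
        conn.execute(
            "INSERT INTO query_cache (key, output) VALUES (?, ?) "
            "ON CONFLICT (key) DO UPDATE SET output = excluded.output",
            (cache_key, sql),
        )
        return conn.execute(sql, args).fetchall()

run_with_cache_persistence(conn, "q1", "INSERT INTO users VALUES (?)", ("alice",))
print(conn.execute("SELECT * FROM query_cache").fetchall())
```

Note the `with conn:` block: if the user's query raises, the cache upsert rolls back with it.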
10 changes: 10 additions & 0 deletions edb/server/compiler/compiler.py
@@ -2608,6 +2608,15 @@ def _try_compile(
 
             _check_force_database_error(stmt_ctx, stmt)
 
+            # Initialize user_schema_version with the version this query is
+            # going to be compiled upon. This can be overwritten later by DDLs.
+            try:
+                schema_version = _get_schema_version(
+                    stmt_ctx.state.current_tx().get_user_schema()
+                )
+            except errors.InvalidReferenceError:
+                schema_version = None
+
             comp, capabilities = _compile_dispatch_ql(
                 stmt_ctx,
                 stmt,
@@ -2623,6 +2632,7 @@
                 capabilities=capabilities,
                 output_format=stmt_ctx.output_format,
                 cache_key=ctx.cache_key,
+                user_schema_version=schema_version,
             )
 
             if not comp.is_transactional:
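The hunks above capture the schema version before dispatching compilation, since a DDL later in the pipeline replaces the user schema (and hence the version), and fall back to None when the schema has no readable version. A self-contained sketch of that capture-then-maybe-overwrite pattern; every name here is hypothetical:

```python
# Hypothetical stand-ins for _get_schema_version and the compiler loop.
import uuid

class SchemaVersionNotFound(Exception):
    pass

def get_schema_version(schema):
    try:
        return schema["version"]
    except KeyError:
        raise SchemaVersionNotFound from None

def compile_stmt(schema, stmt):
    # Capture the version this statement is compiled against...
    try:
        schema_version = get_schema_version(schema)
    except SchemaVersionNotFound:
        schema_version = None
    # ...but a DDL yields a new schema, overwriting the captured version.
    if stmt.startswith("CREATE"):
        schema_version = uuid.uuid4()
    return {"stmt": stmt, "user_schema_version": schema_version}

print(compile_stmt({"version": uuid.uuid4()}, "SELECT 1"))
print(compile_stmt({}, "CREATE TYPE Foo"))
```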
5 changes: 4 additions & 1 deletion edb/server/compiler/dbstate.py
@@ -319,7 +319,10 @@ class QueryUnit:
     # If present, represents the future schema state after
     # the command is run. The schema is pickled.
     user_schema: Optional[bytes] = None
-    user_schema_version: Optional[uuid.UUID] = None
+    # Unlike user_schema, user_schema_version usually exists, pointing to the
+    # latest user schema, which is self.user_schema if changed, or the user
+    # schema this QueryUnit was compiled upon.
+    user_schema_version: uuid.UUID | None = None
     cached_reflection: Optional[bytes] = None
     extensions: Optional[set[str]] = None
     ext_config_settings: Optional[list[config.Setting]] = None
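As the new comment says, user_schema is populated only when a command changes the schema, while user_schema_version is expected on nearly every unit. A stand-in sketch of that contract (not the real dbstate.QueryUnit):

```python
# Hedged sketch of the QueryUnit field contract described above.
import uuid
from dataclasses import dataclass
from typing import Optional

@dataclass
class QueryUnit:
    sql: tuple
    user_schema: Optional[bytes] = None              # set only on schema change
    user_schema_version: Optional[uuid.UUID] = None  # (almost) always set

def latest_schema_version(units):
    # Callers need not inspect user_schema: the last unit already points
    # at the schema version the session ends up on.
    return units[-1].user_schema_version if units else None

v = uuid.uuid4()
units = [QueryUnit(sql=(b"SELECT 1",), user_schema_version=v)]
print(latest_schema_version(units) == v)  # True
```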
27 changes: 27 additions & 0 deletions edb/server/protocol/execute.pxd
@@ -0,0 +1,27 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from edb.server.pgproto.pgproto cimport WriteBuffer
+
+
+cdef class ExecutionGroup:
+    cdef:
+        object group
+        list bind_datas
+
+    cdef append(self, object query_unit, WriteBuffer bind_data=?)
144 changes: 93 additions & 51 deletions edb/server/protocol/execute.pyx
@@ -47,21 +47,73 @@ from edb.server.compiler cimport rpc
 from edb.server.dbview cimport dbview
 from edb.server.protocol cimport args_ser
 from edb.server.protocol cimport frontend
+from edb.server.pgproto.pgproto cimport WriteBuffer
 from edb.server.pgcon cimport pgcon
 from edb.server.pgcon import errors as pgerror
 
 
 cdef object logger = logging.getLogger('edb.server')
 
 cdef object FMT_NONE = compiler.OutputFormat.NONE
+cdef WriteBuffer NO_ARGS = args_ser.combine_raw_args()
 
 
-async def persist_cache(
-    be_conn: pgcon.PGConnection,
-    dbv: dbview.DatabaseConnectionView,
+cdef class ExecutionGroup:
+    def __cinit__(self):
+        self.group = compiler.QueryUnitGroup()
+        self.bind_datas = []
+
+    cdef append(self, object query_unit, WriteBuffer bind_data=NO_ARGS):
+        self.group.append(query_unit, serialize=False)
+        self.bind_datas.append(bind_data)
+
+    async def execute(
+        self,
+        pgcon.PGConnection be_conn,
+        dbview.DatabaseConnectionView dbv,
+        fe_conn: frontend.AbstractFrontendConnection = None,
+        bytes state = None,
+    ):
+        cdef int dbver
+
+        rv = None
+        async with be_conn.parse_execute_script_context():
+            dbver = dbv.dbver
+            parse_array = [False] * len(self.group)
+            be_conn.send_query_unit_group(
+                self.group,
+                True,  # sync
+                self.bind_datas,
+                state,
+                0,  # start
+                len(self.group),  # end
+                dbver,
+                parse_array,
+            )
+            if state is not None:
+                await be_conn.wait_for_state_resp(state, state_sync=0)
+            for i, unit in enumerate(self.group):
+                if unit.output_format == FMT_NONE:
+                    for sql in unit.sql:
+                        await be_conn.wait_for_command(
+                            unit, parse_array[i], dbver, ignore_data=True
+                        )
+                    rv = None
+                else:
+                    for sql in unit.sql:
+                        rv = await be_conn.wait_for_command(
+                            unit, parse_array[i], dbver,
+                            ignore_data=False,
+                            fe_conn=fe_conn,
+                        )
+        return rv
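ExecutionGroup pairs each queued QueryUnit with its pre-serialized bind data, ships the whole batch as one pipelined script, and returns the last data-producing result; FMT_NONE units are drained with ignore_data=True. A plain-Python analogue of the append/execute contract, with a mock callable in place of pgcon.PGConnection:

```python
# Plain-Python analogue of ExecutionGroup; the backend is a mock.
NO_ARGS = b""

class ExecutionGroupSketch:
    def __init__(self):
        self.units = []       # stands in for compiler.QueryUnitGroup
        self.bind_datas = []  # one pre-serialized args buffer per unit

    def append(self, query_unit, bind_data=NO_ARGS):
        self.units.append(query_unit)
        self.bind_datas.append(bind_data)

    def execute(self, backend):
        # The whole batch goes out together; only the last data-producing
        # unit's output survives, mirroring how the Cython version drains
        # FMT_NONE units and forwards the rest to fe_conn.
        rv = None
        for unit, bind_data in zip(self.units, self.bind_datas):
            data = backend(unit, bind_data)
            if data is not None:
                rv = data
        return rv

group = ExecutionGroupSketch()
group.append("evict-cache-entry")
group.append("upsert-cache-entry", b"\x00\x05")
group.append("user-query")
backend = lambda unit, args: [["row"]] if unit == "user-query" else None
print(group.execute(backend))  # [['row']]
```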


+cdef ExecutionGroup build_cache_persistence_units(
     pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
+    ExecutionGroup group = None,
 ):
+    if group is None:
+        group = ExecutionGroup()
     insert_sql = b'''
         INSERT INTO "edgedb"."_query_cache"
         ("key", "schema_version", "input", "output", "evict")
@@ -70,9 +122,6 @@
         "schema_version"=$2, "input"=$3, "output"=$4, "evict"=$5
     '''
     sql_hash = hashlib.sha1(insert_sql).hexdigest().encode('latin1')
-    no_args = args_ser.combine_raw_args()
-    group = compiler.QueryUnitGroup()
-    bind_datas = []
     for request, units in pairs:
         # FIXME: this is temporary; drop this assertion when we support scripts
         assert len(units) == 1
@@ -85,48 +134,33 @@
         assert serialized_result is not None
 
         if evict:
-            group.append(
-                compiler.QueryUnit(sql=(evict,), status=b''), serialize=False
-            )
-            bind_datas.append(no_args)
+            group.append(compiler.QueryUnit(sql=(evict,), status=b''))
         if persist:
-            group.append(
-                compiler.QueryUnit(sql=(persist,), status=b''), serialize=False
-            )
-            bind_datas.append(no_args)
+            group.append(compiler.QueryUnit(sql=(persist,), status=b''))
         group.append(
             compiler.QueryUnit(
                 sql=(insert_sql,), sql_hash=sql_hash, status=b'',
             ),
-            serialize=False,
-        )
-        bind_datas.append(
             args_ser.combine_raw_args((
                 query_unit.cache_key.bytes,
-                dbv.schema_version.bytes,
+                query_unit.user_schema_version.bytes,
                 request.serialize(),
                 serialized_result,
                 evict,
-            ))
-        )
+            )),
+        )
+    return group
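Per cache entry, the builder queues up to three statements: the eviction SQL and persistence SQL (apparently unpacked from query_unit.cache_sql in the folded context above), then the parameterized upsert binding the cache key, the unit's user_schema_version, the serialized request, and the serialized result. A hedged sketch of just the statement-list assembly, with stand-in field names:

```python
# Stand-in for build_cache_persistence_units: returns (sql, args) pairs
# instead of QueryUnits; field names are illustrative only.
def cache_persistence_statements(entries):
    upsert = (
        'INSERT INTO "edgedb"."_query_cache" '
        '("key", "schema_version", "input", "output", "evict") '
        "VALUES ($1, $2, $3, $4, $5) "
        "ON CONFLICT ... DO UPDATE SET ..."  # abbreviated
    )
    stmts = []
    for e in entries:
        if e["evict_sql"]:       # drop the stale entry first, if any
            stmts.append((e["evict_sql"], ()))
        if e["persist_sql"]:     # then any per-entry persistence SQL
            stmts.append((e["persist_sql"], ()))
        stmts.append((upsert, (  # finally the parameterized upsert
            e["key"], e["schema_version"], e["input"], e["output"],
            e["evict_sql"],
        )))
    return stmts

stmts = cache_persistence_statements([{
    "key": b"k", "schema_version": b"v", "input": b"req", "output": b"res",
    "evict_sql": b"DELETE ...", "persist_sql": b"",
}])
print([sql for sql, _ in stmts])  # eviction SQL, then the upsert
```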


+async def persist_cache(
+    be_conn: pgcon.PGConnection,
+    dbv: dbview.DatabaseConnectionView,
+    pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]],
+):
+    cdef group = build_cache_persistence_units(pairs)
 
     try:
-        async with be_conn.parse_execute_script_context():
-            parse_array = [False] * len(group)
-            be_conn.send_query_unit_group(
-                group,
-                True,  # sync
-                bind_datas,
-                None,  # state
-                0,  # start
-                len(group),  # end
-                dbv.dbver,
-                parse_array,
-            )
-            for i, unit in enumerate(group):
-                await be_conn.wait_for_command(
-                    unit, parse_array[i], dbv.dbver, ignore_data=True
-                )
+        await group.execute(be_conn, dbv)
     except Exception as ex:
         if (
             isinstance(ex, pgerror.BackendError)
@@ -160,6 +194,7 @@ async def execute(
     cdef:
         bytes state = None, orig_state = None
         WriteBuffer bound_args_buf
+        ExecutionGroup group
 
     query_unit = compiled.query_unit_group[0]
 
@@ -194,13 +229,6 @@
         else:
             config_ops = query_unit.config_ops
 
-        if compiled.request and query_unit.cache_sql:
-            await persist_cache(
-                be_conn,
-                dbv,
-                [(compiled.request, compiled.query_unit_group)],
-            )
-
         if query_unit.sql:
             if query_unit.user_schema:
                 ddl_ret = await be_conn.run_ddl(query_unit, state)
@@ -217,14 +245,28 @@
                 read_data = (
                     query_unit.needs_readback or query_unit.is_explain)
 
-                data = await be_conn.parse_execute(
-                    query=query_unit,
-                    fe_conn=fe_conn if not read_data else None,
-                    bind_data=bound_args_buf,
-                    use_prep_stmt=use_prep_stmt,
-                    state=state,
-                    dbver=dbv.dbver,
-                )
+                if compiled.request and query_unit.cache_sql:
+                    group = build_cache_persistence_units(
+                        [(compiled.request, compiled.query_unit_group)]
+                    )
+                    if not use_prep_stmt:
+                        query_unit.sql_hash = b''
+                    group.append(query_unit, bound_args_buf)
+                    data = await group.execute(
+                        be_conn,
+                        dbv,
+                        fe_conn=fe_conn if not read_data else None,
+                        state=state,
+                    )
+                else:
+                    data = await be_conn.parse_execute(
+                        query=query_unit,
+                        fe_conn=fe_conn if not read_data else None,
+                        bind_data=bound_args_buf,
+                        use_prep_stmt=use_prep_stmt,
+                        state=state,
+                        dbver=dbv.dbver,
+                    )
 
             if query_unit.needs_readback and data:
                 config_ops = [
