Skip to content

Commit

Permalink
Fix all fixable stubtest_allowlist entries in SQLAlchemy (python#9596)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
  • Loading branch information
3 people committed Apr 14, 2023
1 parent 08e6e4c commit b0ed50e
Show file tree
Hide file tree
Showing 23 changed files with 379 additions and 1,268 deletions.
1,173 changes: 22 additions & 1,151 deletions stubs/SQLAlchemy/@tests/stubtest_allowlist.txt

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions stubs/SQLAlchemy/@tests/test_cases/check_loader_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from typing_extensions import assert_type

from sqlalchemy.orm.strategy_options import (
Load,
contains_eager,
defaultload,
defer,
immediateload,
joinedload,
lazyload,
load_only,
loader_option,
noload,
raiseload,
selectin_polymorphic,
selectinload,
subqueryload,
undefer,
undefer_group,
with_expression,
)


def fn(loadopt: Load, *args: object) -> loader_option:
return loader_option()


# Testing that the function and return type of function are actually all instances of "loader_option"
assert_type(contains_eager, loader_option)
assert_type(contains_eager(fn), loader_option)
assert_type(load_only, loader_option)
assert_type(load_only(fn), loader_option)
assert_type(joinedload, loader_option)
assert_type(joinedload(fn), loader_option)
assert_type(subqueryload, loader_option)
assert_type(subqueryload(fn), loader_option)
assert_type(selectinload, loader_option)
assert_type(selectinload(fn), loader_option)
assert_type(lazyload, loader_option)
assert_type(lazyload(fn), loader_option)
assert_type(immediateload, loader_option)
assert_type(immediateload(fn), loader_option)
assert_type(noload, loader_option)
assert_type(noload(fn), loader_option)
assert_type(raiseload, loader_option)
assert_type(raiseload(fn), loader_option)
assert_type(defaultload, loader_option)
assert_type(defaultload(fn), loader_option)
assert_type(defer, loader_option)
assert_type(defer(fn), loader_option)
assert_type(undefer, loader_option)
assert_type(undefer(fn), loader_option)
assert_type(undefer_group, loader_option)
assert_type(undefer_group(fn), loader_option)
assert_type(with_expression, loader_option)
assert_type(with_expression(fn), loader_option)
assert_type(selectin_polymorphic, loader_option)
assert_type(selectin_polymorphic(fn), loader_option)
67 changes: 67 additions & 0 deletions stubs/SQLAlchemy/@tests/test_cases/check_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from _typeshed.dbapi import DBAPIConnection
from typing import cast

from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.url import URL
from sqlalchemy.pool.base import Pool
from sqlalchemy.testing import config as ConfigModule
from sqlalchemy.testing.provision import (
configure_follower,
create_db,
drop_all_schema_objects_post_tables,
drop_all_schema_objects_pre_tables,
drop_db,
follower_url_from_main,
generate_driver_url,
get_temp_table_name,
post_configure_engine,
prepare_for_drop_tables,
register,
run_reap_dbs,
set_default_schema_on_connection,
stop_test_class_outside_fixtures,
temp_table_keyword_args,
update_db_opts,
)
from sqlalchemy.util import immutabledict

url = URL("", "", "", "", 0, "", immutabledict())
engine = Engine(Pool(lambda: cast(DBAPIConnection, object())), DefaultDialect(), "")
config = cast(ConfigModule.Config, object())
unused = None


class Foo:
pass


# Test that the decorator changes the first parameter to "cfg: str | URL | _ConfigProtocol"
@register.init
def no_args(__foo: Foo) -> None:
pass


no_args(cfg="")
no_args(cfg=url)
no_args(cfg=config)

# Test pre-decorated functions
generate_driver_url(url, "", "")
drop_all_schema_objects_pre_tables(url, unused)
drop_all_schema_objects_post_tables(url, unused)
create_db(url, engine, unused)
drop_db(url, engine, unused)
update_db_opts(url, unused)
post_configure_engine(url, unused, unused)
follower_url_from_main(url, "")
configure_follower(url, unused)
run_reap_dbs(url, unused)
temp_table_keyword_args(url, engine)
prepare_for_drop_tables(url, unused)
stop_test_class_outside_fixtures(url, unused, type)
get_temp_table_name(url, unused, "")
set_default_schema_on_connection(ConfigModule, unused, unused)
set_default_schema_on_connection(config, unused, unused)
1 change: 0 additions & 1 deletion stubs/SQLAlchemy/sqlalchemy/dialects/mssql/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ class MSExecutionContext(default.DefaultExecutionContext):
@property
def rowcount(self): ...
def handle_dbapi_exception(self, e) -> None: ...
def get_result_cursor_strategy(self, result): ...
def fire_sequence(self, seq, type_): ...
def get_insert_default(self, column): ...

Expand Down
1 change: 1 addition & 0 deletions stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class PGCompiler(compiler.SQLCompiler):
class PGDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs): ...
def visit_check_constraint(self, constraint): ...
def visit_foreign_key_constraint(self, constraint) -> str: ... # type: ignore[override] # Different params
def visit_drop_table_comment(self, drop): ...
def visit_create_enum_type(self, create): ...
def visit_drop_enum_type(self, drop): ...
Expand Down
1 change: 0 additions & 1 deletion stubs/SQLAlchemy/sqlalchemy/engine/interfaces.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class ExecutionContext:
def pre_exec(self) -> None: ...
def get_out_parameter_values(self, out_param_names) -> None: ...
def post_exec(self) -> None: ...
def get_result_cursor_strategy(self, result) -> None: ...
def handle_dbapi_exception(self, e) -> None: ...
def should_autocommit_text(self, statement) -> None: ...
def lastrow_has_defaults(self) -> None: ...
Expand Down
1 change: 1 addition & 0 deletions stubs/SQLAlchemy/sqlalchemy/engine/url.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class _URLTuple(NamedTuple):
_Query: TypeAlias = Mapping[str, str | Sequence[str]] | Sequence[tuple[str, str | Sequence[str]]]

class URL(_URLTuple):
def __new__(self, *arg, **kw) -> Self | URL: ...
@classmethod
def create(
cls,
Expand Down
2 changes: 1 addition & 1 deletion stubs/SQLAlchemy/sqlalchemy/event/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class _Dispatch:
class _EventMeta(type):
def __init__(cls, classname, bases, dict_) -> None: ...

class Events:
class Events(metaclass=_EventMeta):
dispatch: Any

class _JoinedDispatcher:
Expand Down
54 changes: 46 additions & 8 deletions stubs/SQLAlchemy/sqlalchemy/orm/collections.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from _typeshed import Incomplete
from typing import Any
from _typeshed import Incomplete, SupportsKeysAndGetItem
from collections.abc import Iterable
from typing import Any, TypeVar, overload
from typing_extensions import Literal, SupportsIndex

from ..orm.attributes import Event
from ..util.langhelpers import _symbol, symbol

_T = TypeVar("_T")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")

class _PlainColumnGetter:
cols: Any
Expand Down Expand Up @@ -81,12 +90,41 @@ class CollectionAdapter:
def fire_remove_event(self, item, initiator: Incomplete | None = None) -> None: ...
def fire_pre_remove_event(self, initiator: Incomplete | None = None) -> None: ...

class InstrumentedList(list[Any]): ...
class InstrumentedSet(set[Any]): ...
class InstrumentedDict(dict[Any, Any]): ...
class InstrumentedList(list[_T]):
def append(self, item, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
def clear(self, index: SupportsIndex = -1) -> None: ...
def extend(self, iterable: Iterable[_T]) -> None: ...
def insert(self, index: SupportsIndex, value: _T) -> None: ...
def pop(self, index: SupportsIndex = -1) -> _T: ...
def remove(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...

class InstrumentedSet(set[_T]):
def add(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
def difference_update(self, value: Iterable[_T]) -> None: ... # type: ignore[override]
def discard(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
def intersection_update(self, other: Iterable[_T]) -> None: ... # type: ignore[override]
def remove(self, value: _T, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
def symmetric_difference_update(self, other: Iterable[_T]) -> None: ...
def update(self, value: Iterable[_T]) -> None: ... # type: ignore[override]

class InstrumentedDict(dict[_KT, _VT]): ...

class MappedCollection(dict[Any, Any]):
class MappedCollection(dict[_KT, _VT]):
keyfunc: Any
def __init__(self, keyfunc) -> None: ...
def set(self, value, _sa_initiator: Incomplete | None = None) -> None: ...
def remove(self, value, _sa_initiator: Incomplete | None = None) -> None: ...
def set(self, value: _VT, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
def remove(self, value: _VT, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
def __delitem__(self, key: _KT, _sa_initiatorEvent: Event | Literal[False] | None = None) -> None: ...
def __setitem__(self, key: _KT, value: _VT, _sa_initiator: Event | Literal[False] | None = None) -> None: ...
@overload
def pop(self, key: _KT) -> _VT: ...
@overload
def pop(self, key: _KT, default: _VT | _T | _symbol | symbol = ...) -> _VT | _T: ...
@overload # type: ignore[override]
def setdefault(self, key: _KT, default: _T) -> _VT | _T: ...
@overload
def setdefault(self, key: _KT, default: None = None) -> _VT | None: ...
@overload
def update(self, __other: SupportsKeysAndGetItem[_KT, _VT] = ..., **kwargs: _VT) -> None: ...
@overload
def update(self, __other: Iterable[tuple[_KT, _VT]] = ..., **kwargs: _VT) -> None: ...
2 changes: 1 addition & 1 deletion stubs/SQLAlchemy/sqlalchemy/orm/decl_api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ _DeclT = TypeVar("_DeclT", bound=type[_DeclarativeBase])

# Dynamic class as created by registry.generate_base() via DeclarativeMeta
# or another metaclass. This class does not exist at runtime.
class _DeclarativeBase(Any): # super classes are dynamic
class _DeclarativeBase(Any): # type: ignore[misc] # super classes are dynamic
registry: ClassVar[registry]
metadata: ClassVar[MetaData]
__abstract__: ClassVar[bool]
Expand Down
97 changes: 58 additions & 39 deletions stubs/SQLAlchemy/sqlalchemy/orm/strategy_options.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from _typeshed import Incomplete
from collections.abc import Callable
from typing import Any
from typing_extensions import Self

from ..sql.base import Generative
from .interfaces import LoaderOption
Expand All @@ -17,50 +19,67 @@ class Load(Generative, LoaderOption):
propagate_to_loaders: bool
def process_compile_state_replaced_entities(self, compile_state, mapper_entities) -> None: ...
def process_compile_state(self, compile_state) -> None: ...
def options(self, *opts) -> None: ...
def set_relationship_strategy(self, attr, strategy, propagate_to_loaders: bool = True) -> None: ...
def set_column_strategy(self, attrs, strategy, opts: Incomplete | None = None, opts_only: bool = False) -> None: ...
def set_generic_strategy(self, attrs, strategy) -> None: ...
def set_class_strategy(self, strategy, opts) -> None: ...
# added dynamically at runtime
def contains_eager(self, attr, alias: Incomplete | None = None): ...
def load_only(self, *attrs): ...
def joinedload(self, attr, innerjoin: Incomplete | None = None): ...
def subqueryload(self, attr): ...
def selectinload(self, attr): ...
def lazyload(self, attr): ...
def immediateload(self, attr): ...
def noload(self, attr): ...
def raiseload(self, attr, sql_only: bool = False): ...
def defaultload(self, attr): ...
def defer(self, key, raiseload: bool = False): ...
def undefer(self, key): ...
def undefer_group(self, name): ...
def with_expression(self, key, expression): ...
def selectin_polymorphic(self, classes): ...
def options(self, *opts) -> Self: ...
def set_relationship_strategy(self, attr, strategy, propagate_to_loaders: bool = True) -> Self: ...
def set_column_strategy(self, attrs, strategy, opts: Incomplete | None = None, opts_only: bool = False) -> Self: ...
def set_generic_strategy(self, attrs, strategy) -> Self: ...
def set_class_strategy(self, strategy, opts) -> Self: ...
# Added dynamically at runtime
def contains_eager(loadopt, attr, alias: Incomplete | None = None) -> Self: ...
def load_only(loadopt, *attrs) -> Self: ...
def joinedload(loadopt, attr, innerjoin: Incomplete | None = None) -> Self: ...
def subqueryload(loadopt, attr) -> Self: ...
def selectinload(loadopt, attr) -> Self: ...
def lazyload(loadopt, attr) -> Self: ...
def immediateload(loadopt, attr) -> Self: ...
def noload(loadopt, attr) -> Self: ...
def raiseload(loadopt, attr, sql_only: bool = False) -> Self: ...
def defaultload(loadopt, attr) -> Self: ...
def defer(loadopt, key, raiseload: bool = False) -> Self: ...
def undefer(loadopt, key) -> Self: ...
def undefer_group(loadopt, name) -> Self: ...
def with_expression(loadopt, key, expression) -> Self: ...
def selectin_polymorphic(loadopt, classes) -> Self: ...

class _UnboundLoad(Load):
path: Any
local_opts: Any
def __init__(self) -> None: ...

class loader_option:
name: Any
fn: Any
def __call__(self, fn): ...
name: str
# The first parameter of this Callable should always be `loadopt: Load`
fn: Callable[..., loader_option]
def __call__(self, fn: Callable[..., loader_option]) -> Self: ...

def contains_eager(loadopt, attr, alias: Incomplete | None = ...): ...
def load_only(loadopt, *attrs): ...
def joinedload(loadopt, attr, innerjoin: Incomplete | None = ...): ...
def subqueryload(loadopt, attr): ...
def selectinload(loadopt, attr): ...
def lazyload(loadopt, attr): ...
def immediateload(loadopt, attr): ...
def noload(loadopt, attr): ...
def raiseload(loadopt, attr, sql_only: bool = ...): ...
def defaultload(loadopt, attr): ...
def defer(loadopt, key, raiseload: bool = ...): ...
def undefer(loadopt, key): ...
def undefer_group(loadopt, name): ...
def with_expression(loadopt, key, expression): ...
def selectin_polymorphic(loadopt, classes): ...
# loader_option instances that can be used to dynamically add methods to Load at runtime
@loader_option()
def contains_eager(loadopt: Load, attr, alias: Incomplete | None = ...) -> loader_option: ...
@loader_option()
def load_only(loadopt: Load, *attrs) -> loader_option: ...
@loader_option()
def joinedload(loadopt, attr, innerjoin=None): ...
@loader_option()
def subqueryload(loadopt: Load, attr) -> loader_option: ...
@loader_option()
def selectinload(loadopt: Load, attr) -> loader_option: ...
@loader_option()
def lazyload(loadopt: Load, attr) -> loader_option: ...
@loader_option()
def immediateload(loadopt: Load, attr) -> loader_option: ...
@loader_option()
def noload(loadopt: Load, attr) -> loader_option: ...
@loader_option()
def raiseload(loadopt: Load, attr, sql_only: bool = ...) -> loader_option: ...
@loader_option()
def defaultload(loadopt: Load, attr) -> loader_option: ...
@loader_option()
def defer(loadopt: Load, key, raiseload: bool = ...) -> loader_option: ...
@loader_option()
def undefer(loadopt: Load, key) -> loader_option: ...
@loader_option()
def undefer_group(loadopt: Load, name) -> loader_option: ...
@loader_option()
def with_expression(loadopt: Load, key) -> loader_option: ...
@loader_option()
def selectin_polymorphic(loadopt: Load, classes) -> loader_option: ...
4 changes: 3 additions & 1 deletion stubs/SQLAlchemy/sqlalchemy/pool/base.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from _typeshed import Incomplete
from _typeshed.dbapi import DBAPIConnection
from collections.abc import Callable
from typing import Any

from .. import log
Expand All @@ -24,7 +26,7 @@ class Pool(log.Identified):
echo: Any
def __init__(
self,
creator,
creator: Callable[[], DBAPIConnection],
recycle: int = -1,
echo: Incomplete | None = None,
logging_name: Incomplete | None = None,
Expand Down
2 changes: 1 addition & 1 deletion stubs/SQLAlchemy/sqlalchemy/sql/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class _MetaOptions(type):
def __init__(cls, classname, bases, dict_) -> None: ...
def __add__(self, other): ...

class Options:
class Options(metaclass=_MetaOptions):
def __init__(self, **kw) -> None: ...
def __add__(self, other): ...
def __eq__(self, other): ...
Expand Down

0 comments on commit b0ed50e

Please sign in to comment.