Update 2025-04-24_11:44:19

Author: oib
Date: 2025-04-24 11:44:23 +02:00
Commit: e748c737f4
3408 changed files with 717481 additions and 0 deletions


@@ -0,0 +1,11 @@
# ext/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .. import util as _sa_util
_sa_util.preloaded.import_prefix("sqlalchemy.ext")

File diff suppressed because it is too large


@@ -0,0 +1,25 @@
# ext/asyncio/__init__.py
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .engine import async_engine_from_config as async_engine_from_config
from .engine import AsyncConnection as AsyncConnection
from .engine import AsyncEngine as AsyncEngine
from .engine import AsyncTransaction as AsyncTransaction
from .engine import create_async_engine as create_async_engine
from .engine import create_async_pool_from_url as create_async_pool_from_url
from .result import AsyncMappingResult as AsyncMappingResult
from .result import AsyncResult as AsyncResult
from .result import AsyncScalarResult as AsyncScalarResult
from .result import AsyncTupleResult as AsyncTupleResult
from .scoping import async_scoped_session as async_scoped_session
from .session import async_object_session as async_object_session
from .session import async_session as async_session
from .session import async_sessionmaker as async_sessionmaker
from .session import AsyncAttrs as AsyncAttrs
from .session import AsyncSession as AsyncSession
from .session import AsyncSessionTransaction as AsyncSessionTransaction
from .session import close_all_sessions as close_all_sessions
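
The imports above form the re-exported public API of the asyncio extension. As a rough, hedged usage sketch (not part of this commit; it assumes the aiosqlite driver is installed and uses a throwaway in-memory database), the engine side of that API is typically driven like this:

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine


async def main() -> None:
    # assumption: aiosqlite is available; any async driver URL would do
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.connect() as conn:
        result = await conn.execute(text("select 1"))
        print(result.scalar_one())  # -> 1
    await engine.dispose()


asyncio.run(main())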


@@ -0,0 +1,281 @@
# ext/asyncio/base.py
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import abc
import functools
from typing import Any
from typing import AsyncGenerator
from typing import AsyncIterator
from typing import Awaitable
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import Generator
from typing import Generic
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Tuple
from typing import TypeVar
import weakref
from . import exc as async_exc
from ... import util
from ...util.typing import Literal
from ...util.typing import Self
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_PT = TypeVar("_PT", bound=Any)
class ReversibleProxy(Generic[_PT]):
_proxy_objects: ClassVar[
Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
] = {}
__slots__ = ("__weakref__",)
@overload
def _assign_proxied(self, target: _PT) -> _PT: ...
@overload
def _assign_proxied(self, target: None) -> None: ...
def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
if target is not None:
target_ref: weakref.ref[_PT] = weakref.ref(
target, ReversibleProxy._target_gced
)
proxy_ref = weakref.ref(
self,
functools.partial(ReversibleProxy._target_gced, target_ref),
)
ReversibleProxy._proxy_objects[target_ref] = proxy_ref
return target
@classmethod
def _target_gced(
cls,
ref: weakref.ref[_PT],
proxy_ref: Optional[weakref.ref[Self]] = None, # noqa: U100
) -> None:
cls._proxy_objects.pop(ref, None)
@classmethod
def _regenerate_proxy_for_target(
cls, target: _PT, **additional_kw: Any
) -> Self:
raise NotImplementedError()
@overload
@classmethod
def _retrieve_proxy_for_target(
cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
) -> Self: ...
@overload
@classmethod
def _retrieve_proxy_for_target(
cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]: ...
@classmethod
def _retrieve_proxy_for_target(
cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]:
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
except KeyError:
pass
else:
proxy = proxy_ref()
if proxy is not None:
return proxy # type: ignore
if regenerate:
return cls._regenerate_proxy_for_target(target, **additional_kw)
else:
return None
class StartableContext(Awaitable[_T_co], abc.ABC):
__slots__ = ()
@abc.abstractmethod
async def start(self, is_ctxmanager: bool = False) -> _T_co:
raise NotImplementedError()
def __await__(self) -> Generator[Any, Any, _T_co]:
return self.start().__await__()
async def __aenter__(self) -> _T_co:
return await self.start(is_ctxmanager=True)
@abc.abstractmethod
async def __aexit__(
self, type_: Any, value: Any, traceback: Any
) -> Optional[bool]:
pass
def _raise_for_not_started(self) -> NoReturn:
raise async_exc.AsyncContextNotStarted(
"%s context has not been started and object has not been awaited."
% (self.__class__.__name__)
)
class GeneratorStartableContext(StartableContext[_T_co]):
__slots__ = ("gen",)
gen: AsyncGenerator[_T_co, Any]
def __init__(
self,
func: Callable[..., AsyncIterator[_T_co]],
args: Tuple[Any, ...],
kwds: Dict[str, Any],
):
self.gen = func(*args, **kwds) # type: ignore
async def start(self, is_ctxmanager: bool = False) -> _T_co:
try:
start_value = await util.anext_(self.gen)
except StopAsyncIteration:
raise RuntimeError("generator didn't yield") from None
# if not a context manager, then interrupt the generator, don't
# let it complete. this step is technically not needed, as the
# generator will close in any case at gc time. not clear if having
# this here is a good idea or not (though it helps for clarity IMO)
if not is_ctxmanager:
await self.gen.aclose()
return start_value
async def __aexit__(
self, typ: Any, value: Any, traceback: Any
) -> Optional[bool]:
# vendored from contextlib.py
if typ is None:
try:
await util.anext_(self.gen)
except StopAsyncIteration:
return False
else:
raise RuntimeError("generator didn't stop")
else:
if value is None:
# Need to force instantiation so we can reliably
# tell if we get the same exception back
value = typ()
try:
await self.gen.athrow(value)
except StopAsyncIteration as exc:
# Suppress StopIteration *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration
# raised inside the "with" statement from being suppressed.
return exc is not value
except RuntimeError as exc:
# Don't re-raise the passed in exception. (issue27122)
if exc is value:
return False
# Avoid suppressing if a Stop(Async)Iteration exception
# was passed to athrow() and later wrapped into a RuntimeError
# (see PEP 479 for sync generators; async generators also
# have this behavior). But do this only if the exception wrapped
# by the RuntimeError is actually Stop(Async)Iteration (see
# issue29692).
if (
isinstance(value, (StopIteration, StopAsyncIteration))
and exc.__cause__ is value
):
return False
raise
except BaseException as exc:
# only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise
# an exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so this
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
if exc is not value:
raise
return False
raise RuntimeError("generator didn't stop after athrow()")
def asyncstartablecontext(
func: Callable[..., AsyncIterator[_T_co]]
) -> Callable[..., GeneratorStartableContext[_T_co]]:
"""@asyncstartablecontext decorator.
the decorated function can be called either as ``async with fn()``, **or**
``await fn()``. This is decidedly different from what
``@contextlib.asynccontextmanager`` supports, and the usage pattern
is different as well.
Typical usage:
.. sourcecode:: text
@asyncstartablecontext
async def some_async_generator(<arguments>):
<setup>
try:
yield <value>
except GeneratorExit:
# return value was awaited, no context manager is present
# and caller will .close() the resource explicitly
pass
else:
<context manager cleanup>
Above, ``GeneratorExit`` is caught if the function were used as an
``await``. In this case, it's essential that the cleanup does **not**
occur, so there should not be a ``finally`` block.
If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__``
and we were invoked as a context manager, and cleanup should proceed.
"""
@functools.wraps(func)
def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]:
return GeneratorStartableContext(func, args, kwds)
return helper
class ProxyComparable(ReversibleProxy[_PT]):
__slots__ = ()
@util.ro_non_memoized_property
def _proxied(self) -> _PT:
raise NotImplementedError()
def __hash__(self) -> int:
return id(self)
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and self._proxied == other._proxied
)
def __ne__(self, other: Any) -> bool:
return (
not isinstance(other, self.__class__)
or self._proxied != other._proxied
)
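
To make the dual calling convention of ``GeneratorStartableContext`` concrete, here is a small, hedged sketch; ``asyncstartablecontext`` is an internal helper, and the resource dictionary below is purely hypothetical:

import asyncio

from sqlalchemy.ext.asyncio.base import asyncstartablecontext


@asyncstartablecontext
async def acquire_resource():
    resource = {"open": True}  # hypothetical resource
    try:
        yield resource
    except GeneratorExit:
        # plain "await acquire_resource()": caller owns any cleanup
        pass
    else:
        # "async with acquire_resource():" path: clean up here
        resource["open"] = False


async def main() -> None:
    res = await acquire_resource()  # awaited form; cleanup is skipped
    print(res["open"])  # -> True
    async with acquire_resource() as res:  # context-manager form
        print(res["open"])  # -> True; cleanup runs on exit


asyncio.run(main())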

File diff suppressed because it is too large


@@ -0,0 +1,21 @@
# ext/asyncio/exc.py
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from ... import exc
class AsyncMethodRequired(exc.InvalidRequestError):
"""an API can't be used because its result would not be
compatible with async"""
class AsyncContextNotStarted(exc.InvalidRequestError):
"""a startable context manager has not been started."""
class AsyncContextAlreadyStarted(exc.InvalidRequestError):
"""a startable context manager is already started."""


@@ -0,0 +1,962 @@
# ext/asyncio/result.py
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import operator
from typing import Any
from typing import AsyncIterator
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from . import exc as async_exc
from ... import util
from ...engine import Result
from ...engine.result import _NO_ROW
from ...engine.result import _R
from ...engine.result import _WithKeys
from ...engine.result import FilterResult
from ...engine.result import FrozenResult
from ...engine.result import ResultMetaData
from ...engine.row import Row
from ...engine.row import RowMapping
from ...sql.base import _generative
from ...util.concurrency import greenlet_spawn
from ...util.typing import Literal
from ...util.typing import Self
if TYPE_CHECKING:
from ...engine import CursorResult
from ...engine.result import _KeyIndexType
from ...engine.result import _UniqueFilterType
_T = TypeVar("_T", bound=Any)
_TP = TypeVar("_TP", bound=Tuple[Any, ...])
class AsyncCommon(FilterResult[_R]):
__slots__ = ()
_real_result: Result[Any]
_metadata: ResultMetaData
async def close(self) -> None: # type: ignore[override]
"""Close this result."""
await greenlet_spawn(self._real_result.close)
@property
def closed(self) -> bool:
"""proxies the .closed attribute of the underlying result object,
if any, else raises ``AttributeError``.
.. versionadded:: 2.0.0b3
"""
return self._real_result.closed
class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
use a server-side cursor. It is returned only from the
:meth:`_asyncio.AsyncConnection.stream` and
:meth:`_asyncio.AsyncSession.stream` methods.
.. note:: As is the case with :class:`_engine.Result`, this object is
used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`,
which can yield instances of ORM mapped objects either individually or
within tuple-like rows. Note that these result objects do not
deduplicate instances or rows automatically as is the case with the
legacy :class:`_orm.Query` object. For in-Python de-duplication of
instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
method.
.. versionadded:: 1.4
"""
__slots__ = ()
_real_result: Result[_TP]
def __init__(self, real_result: Result[_TP]):
self._real_result = real_result
self._metadata = real_result._metadata
self._unique_filter_state = real_result._unique_filter_state
self._source_supports_scalars = real_result._source_supports_scalars
self._post_creational_filter = None
# BaseCursorResult pre-generates the "_row_getter". Use that
# if available rather than building a second one
if "_row_getter" in real_result.__dict__:
self._set_memoized_attribute(
"_row_getter", real_result.__dict__["_row_getter"]
)
@property
def t(self) -> AsyncTupleResult[_TP]:
"""Apply a "typed tuple" typing filter to returned rows.
The :attr:`_asyncio.AsyncResult.t` attribute is a synonym for
calling the :meth:`_asyncio.AsyncResult.tuples` method.
.. versionadded:: 2.0
"""
return self # type: ignore
def tuples(self) -> AsyncTupleResult[_TP]:
"""Apply a "typed tuple" typing filter to returned rows.
This method returns the same :class:`_asyncio.AsyncResult` object
at runtime,
however annotates as returning a :class:`_asyncio.AsyncTupleResult`
object that will indicate to :pep:`484` typing tools that plain typed
``Tuple`` instances are returned rather than rows. This allows
tuple unpacking and ``__getitem__`` access of :class:`_engine.Row`
objects to be typed, for those cases where the statement invoked
itself included typing information.
.. versionadded:: 2.0
:return: the :class:`_result.AsyncTupleResult` type at typing time.
.. seealso::
:attr:`_asyncio.AsyncResult.t` - shorter synonym
:attr:`_engine.Row.t` - :class:`_engine.Row` version
"""
return self # type: ignore
@_generative
def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncResult`.
Refer to :meth:`_engine.Result.unique` in the synchronous
SQLAlchemy API for a complete behavioral description.
"""
self._unique_filter_state = (set(), strategy)
return self
def columns(self, *col_expressions: _KeyIndexType) -> Self:
r"""Establish the columns that should be returned in each row.
Refer to :meth:`_engine.Result.columns` in the synchronous
SQLAlchemy API for a complete behavioral description.
"""
return self._column_slices(col_expressions)
async def partitions(
self, size: Optional[int] = None
) -> AsyncIterator[Sequence[Row[_TP]]]:
"""Iterate through sub-lists of rows of the size given.
An async iterator is returned::
async def scroll_results(connection):
result = await connection.stream(select(users_table))
async for partition in result.partitions(100):
print("list of rows: %s" % partition)
Refer to :meth:`_engine.Result.partitions` in the synchronous
SQLAlchemy API for a complete behavioral description.
"""
getter = self._manyrow_getter
while True:
partition = await greenlet_spawn(getter, self, size)
if partition:
yield partition
else:
break
async def fetchall(self) -> Sequence[Row[_TP]]:
"""A synonym for the :meth:`_asyncio.AsyncResult.all` method.
.. versionadded:: 2.0
"""
return await greenlet_spawn(self._allrows)
async def fetchone(self) -> Optional[Row[_TP]]:
"""Fetch one row.
When all rows are exhausted, returns None.
This method is provided for backwards compatibility with
SQLAlchemy 1.x.x.
To fetch the first row of a result only, use the
:meth:`_asyncio.AsyncResult.first` method. To iterate through all
rows, iterate the :class:`_asyncio.AsyncResult` object directly.
:return: a :class:`_engine.Row` object if no filters are applied,
or ``None`` if no rows remain.
"""
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
return None
else:
return row
async def fetchmany(
self, size: Optional[int] = None
) -> Sequence[Row[_TP]]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
This method is provided for backwards compatibility with
SQLAlchemy 1.x.x.
To fetch rows in groups, use the
:meth:`._asyncio.AsyncResult.partitions` method.
:return: a list of :class:`_engine.Row` objects.
.. seealso::
:meth:`_asyncio.AsyncResult.partitions`
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
async def all(self) -> Sequence[Row[_TP]]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
will return an empty list.
:return: a list of :class:`_engine.Row` objects.
"""
return await greenlet_spawn(self._allrows)
def __aiter__(self) -> AsyncResult[_TP]:
return self
async def __anext__(self) -> Row[_TP]:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
async def first(self) -> Optional[Row[_TP]]:
"""Fetch the first row or ``None`` if no row is present.
Closes the result set and discards remaining rows.
.. note:: This method returns one **row**, e.g. tuple, by default.
To return exactly one single scalar value, that is, the first
column of the first row, use the
:meth:`_asyncio.AsyncResult.scalar` method,
or combine :meth:`_asyncio.AsyncResult.scalars` and
:meth:`_asyncio.AsyncResult.first`.
Additionally, in contrast to the behavior of the legacy ORM
:meth:`_orm.Query.first` method, **no limit is applied** to the
SQL query which was invoked to produce this
:class:`_asyncio.AsyncResult`;
for a DBAPI driver that buffers results in memory before yielding
rows, all rows will be sent to the Python process and all but
the first row will be discarded.
.. seealso::
:ref:`migration_20_unify_select`
:return: a :class:`_engine.Row` object, or None
if no rows remain.
.. seealso::
:meth:`_asyncio.AsyncResult.scalar`
:meth:`_asyncio.AsyncResult.one`
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
async def one_or_none(self) -> Optional[Row[_TP]]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
Raises :class:`.MultipleResultsFound`
if multiple rows are returned.
.. versionadded:: 1.4
:return: The first :class:`_engine.Row` or ``None`` if no row
is available.
:raises: :class:`.MultipleResultsFound`
.. seealso::
:meth:`_asyncio.AsyncResult.first`
:meth:`_asyncio.AsyncResult.one`
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
@overload
async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ...
@overload
async def scalar_one(self) -> Any: ...
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncScalarResult.one`.
.. seealso::
:meth:`_asyncio.AsyncScalarResult.one`
:meth:`_asyncio.AsyncResult.scalars`
"""
return await greenlet_spawn(self._only_one_row, True, True, True)
@overload
async def scalar_one_or_none(
self: AsyncResult[Tuple[_T]],
) -> Optional[_T]: ...
@overload
async def scalar_one_or_none(self) -> Optional[Any]: ...
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one scalar result or ``None``.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncScalarResult.one_or_none`.
.. seealso::
:meth:`_asyncio.AsyncScalarResult.one_or_none`
:meth:`_asyncio.AsyncResult.scalars`
"""
return await greenlet_spawn(self._only_one_row, True, False, True)
async def one(self) -> Row[_TP]:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
rows, or :class:`.MultipleResultsFound` if multiple rows
would be returned.
.. note:: This method returns one **row**, e.g. tuple, by default.
To return exactly one single scalar value, that is, the first
column of the first row, use the
:meth:`_asyncio.AsyncResult.scalar_one` method, or combine
:meth:`_asyncio.AsyncResult.scalars` and
:meth:`_asyncio.AsyncResult.one`.
.. versionadded:: 1.4
:return: The first :class:`_engine.Row`.
:raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
.. seealso::
:meth:`_asyncio.AsyncResult.first`
:meth:`_asyncio.AsyncResult.one_or_none`
:meth:`_asyncio.AsyncResult.scalar_one`
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
@overload
async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ...
@overload
async def scalar(self) -> Any: ...
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
Returns ``None`` if there are no rows to fetch.
No validation is performed to test if additional rows remain.
After calling this method, the object is fully closed,
e.g. the :meth:`_engine.CursorResult.close`
method will have been called.
:return: a Python scalar value, or ``None`` if no rows remain.
"""
return await greenlet_spawn(self._only_one_row, False, False, True)
async def freeze(self) -> FrozenResult[_TP]:
"""Return a callable object that will produce copies of this
:class:`_asyncio.AsyncResult` when invoked.
The callable object returned is an instance of
:class:`_engine.FrozenResult`.
This is used for result set caching. The method must be called
on the result when it has been unconsumed, and calling the method
will consume the result fully. When the :class:`_engine.FrozenResult`
is retrieved from a cache, it can be called any number of times where
it will produce a new :class:`_engine.Result` object each time
against its stored set of rows.
.. seealso::
:ref:`do_orm_execute_re_executing` - example usage within the
ORM to implement a result-set cache.
"""
return await greenlet_spawn(FrozenResult, self)
@overload
def scalars(
self: AsyncResult[Tuple[_T]], index: Literal[0]
) -> AsyncScalarResult[_T]: ...
@overload
def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ...
@overload
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ...
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
Refer to :meth:`_result.Result.scalars` in the synchronous
SQLAlchemy API for a complete behavioral description.
:param index: integer or row key indicating the column to be fetched
from each row, defaults to ``0`` indicating the first column.
:return: a new :class:`_asyncio.AsyncScalarResult` filtering object
referring to this :class:`_asyncio.AsyncResult` object.
"""
return AsyncScalarResult(self._real_result, index)
def mappings(self) -> AsyncMappingResult:
"""Apply a mappings filter to returned rows, returning an instance of
:class:`_asyncio.AsyncMappingResult`.
When this filter is applied, fetching rows will return
:class:`_engine.RowMapping` objects instead of :class:`_engine.Row`
objects.
:return: a new :class:`_asyncio.AsyncMappingResult` filtering object
referring to the underlying :class:`_result.Result` object.
"""
return AsyncMappingResult(self._real_result)
class AsyncScalarResult(AsyncCommon[_R]):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
rather than :class:`_row.Row` values.
The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
:meth:`_asyncio.AsyncResult.scalars` method.
Refer to the :class:`_result.ScalarResult` object in the synchronous
SQLAlchemy API for a complete behavioral description.
.. versionadded:: 1.4
"""
__slots__ = ()
_generate_rows = False
def __init__(self, real_result: Result[Any], index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
self._metadata = real_result._metadata
self._post_creational_filter = None
else:
self._metadata = real_result._metadata._reduce([index])
self._post_creational_filter = operator.itemgetter(0)
self._unique_filter_state = real_result._unique_filter_state
def unique(
self,
strategy: Optional[_UniqueFilterType] = None,
) -> Self:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncScalarResult`.
See :meth:`_asyncio.AsyncResult.unique` for usage details.
"""
self._unique_filter_state = (set(), strategy)
return self
async def partitions(
self, size: Optional[int] = None
) -> AsyncIterator[Sequence[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
scalar values, rather than :class:`_engine.Row` objects,
are returned.
"""
getter = self._manyrow_getter
while True:
partition = await greenlet_spawn(getter, self, size)
if partition:
yield partition
else:
break
async def fetchall(self) -> Sequence[_R]:
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
return await greenlet_spawn(self._allrows)
async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
"""Fetch many objects.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
scalar values, rather than :class:`_engine.Row` objects,
are returned.
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
async def all(self) -> Sequence[_R]:
"""Return all scalar values in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
scalar values, rather than :class:`_engine.Row` objects,
are returned.
"""
return await greenlet_spawn(self._allrows)
def __aiter__(self) -> AsyncScalarResult[_R]:
return self
async def __anext__(self) -> _R:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
async def first(self) -> Optional[_R]:
"""Fetch the first object or ``None`` if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
scalar values, rather than :class:`_engine.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
async def one_or_none(self) -> Optional[_R]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
scalar values, rather than :class:`_engine.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
async def one(self) -> _R:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
scalar values, rather than :class:`_engine.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
class AsyncMappingResult(_WithKeys, AsyncCommon[RowMapping]):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary
values rather than :class:`_engine.Row` values.
The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
:meth:`_asyncio.AsyncResult.mappings` method.
Refer to the :class:`_result.MappingResult` object in the synchronous
SQLAlchemy API for a complete behavioral description.
.. versionadded:: 1.4
"""
__slots__ = ()
_generate_rows = True
_post_creational_filter = operator.attrgetter("_mapping")
def __init__(self, result: Result[Any]):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
if result._source_supports_scalars:
self._metadata = self._metadata._reduce([0])
def unique(
self,
strategy: Optional[_UniqueFilterType] = None,
) -> Self:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncMappingResult`.
See :meth:`_asyncio.AsyncResult.unique` for usage details.
"""
self._unique_filter_state = (set(), strategy)
return self
def columns(self, *col_expressions: _KeyIndexType) -> Self:
r"""Establish the columns that should be returned in each row."""
return self._column_slices(col_expressions)
async def partitions(
self, size: Optional[int] = None
) -> AsyncIterator[Sequence[RowMapping]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
objects, are returned.
"""
getter = self._manyrow_getter
while True:
partition = await greenlet_spawn(getter, self, size)
if partition:
yield partition
else:
break
async def fetchall(self) -> Sequence[RowMapping]:
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
return await greenlet_spawn(self._allrows)
async def fetchone(self) -> Optional[RowMapping]:
"""Fetch one object.
Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
objects, are returned.
"""
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
return None
else:
return row
async def fetchmany(
self, size: Optional[int] = None
) -> Sequence[RowMapping]:
"""Fetch many rows.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
objects, are returned.
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
async def all(self) -> Sequence[RowMapping]:
"""Return all rows in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
objects, are returned.
"""
return await greenlet_spawn(self._allrows)
def __aiter__(self) -> AsyncMappingResult:
return self
async def __anext__(self) -> RowMapping:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
async def first(self) -> Optional[RowMapping]:
"""Fetch the first object or ``None`` if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
objects, are returned.
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
async def one_or_none(self) -> Optional[RowMapping]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
objects, are returned.
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
async def one(self) -> RowMapping:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
objects, are returned.
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
"""A :class:`_asyncio.AsyncResult` that's typed as returning plain
Python tuples instead of rows.
Since :class:`_engine.Row` acts like a tuple in every way already,
this class is a typing only class, regular :class:`_asyncio.AsyncResult` is
still used at runtime.
"""
__slots__ = ()
if TYPE_CHECKING:
async def partitions(
self, size: Optional[int] = None
) -> AsyncIterator[Sequence[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_result.Result.partitions` except that
tuple values, rather than :class:`_engine.Row` objects,
are returned.
"""
...
async def fetchone(self) -> Optional[_R]:
"""Fetch one tuple.
Equivalent to :meth:`_result.Result.fetchone` except that
tuple values, rather than :class:`_engine.Row`
objects, are returned.
"""
...
async def fetchall(self) -> Sequence[_R]:
"""A synonym for the :meth:`_engine.ScalarResult.all` method."""
...
async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
"""Fetch many objects.
Equivalent to :meth:`_result.Result.fetchmany` except that
tuple values, rather than :class:`_engine.Row` objects,
are returned.
"""
...
async def all(self) -> Sequence[_R]: # noqa: A001
"""Return all scalar values in a list.
Equivalent to :meth:`_result.Result.all` except that
tuple values, rather than :class:`_engine.Row` objects,
are returned.
"""
...
async def __aiter__(self) -> AsyncIterator[_R]: ...
async def __anext__(self) -> _R: ...
async def first(self) -> Optional[_R]:
"""Fetch the first object or ``None`` if no object is present.
Equivalent to :meth:`_result.Result.first` except that
tuple values, rather than :class:`_engine.Row` objects,
are returned.
"""
...
async def one_or_none(self) -> Optional[_R]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_result.Result.one_or_none` except that
tuple values, rather than :class:`_engine.Row` objects,
are returned.
"""
...
async def one(self) -> _R:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_result.Result.one` except that
tuple values, rather than :class:`_engine.Row` objects,
are returned.
"""
...
@overload
async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ...
@overload
async def scalar_one(self) -> Any: ...
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.AsyncScalarResult.one`.
.. seealso::
:meth:`_engine.AsyncScalarResult.one`
:meth:`_engine.Result.scalars`
"""
...
@overload
async def scalar_one_or_none(
self: AsyncTupleResult[Tuple[_T]],
) -> Optional[_T]: ...
@overload
async def scalar_one_or_none(self) -> Optional[Any]: ...
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.AsyncScalarResult.one_or_none`.
.. seealso::
:meth:`_engine.AsyncScalarResult.one_or_none`
:meth:`_engine.Result.scalars`
"""
...
@overload
async def scalar(
self: AsyncTupleResult[Tuple[_T]],
) -> Optional[_T]: ...
@overload
async def scalar(self) -> Any: ...
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result
set.
Returns ``None`` if there are no rows to fetch.
No validation is performed to test if additional rows remain.
After calling this method, the object is fully closed,
e.g. the :meth:`_engine.CursorResult.close`
method will have been called.
:return: a Python scalar value, or ``None`` if no rows remain.
"""
...
_RT = TypeVar("_RT", bound="Result[Any]")
async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
cursor_result: CursorResult[Any]
try:
is_cursor = result._is_cursor
except AttributeError:
# legacy execute(DefaultGenerator) case
return result
if not is_cursor:
cursor_result = getattr(result, "raw", None) # type: ignore
else:
cursor_result = result # type: ignore
if cursor_result and cursor_result.context._is_server_side:
await greenlet_spawn(cursor_result.close)
raise async_exc.AsyncMethodRequired(
"Can't use the %s.%s() method with a "
"server-side cursor. "
"Use the %s.stream() method for an async "
"streaming result set."
% (
calling_method.__self__.__class__.__name__,
calling_method.__name__,
calling_method.__self__.__class__.__name__,
)
)
return result
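
As a hedged illustration of the streaming wrappers above (the connection URL is a placeholder and the sketch assumes an asyncpg-backed PostgreSQL database, since ``stream()`` relies on server-side cursor support in the driver):

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine


async def main() -> None:
    # placeholder URL; replace with a real asyncpg DSN
    engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/demo")
    async with engine.connect() as conn:
        result = await conn.stream(
            text("select n from generate_series(1, 10) as t(n)")
        )
        # AsyncResult.partitions() yields lists of Row objects
        async for partition in result.partitions(3):
            print([row.n for row in partition])
    await engine.dispose()


asyncio.run(main())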

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,570 @@
# ext/baked.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Baked query extension.
Provides a creational pattern for the :class:`.query.Query` object which
allows the fully constructed object, Core select statement, and string
compiled result to be fully cached.
"""
import collections.abc as collections_abc
import logging
from .. import exc as sa_exc
from .. import util
from ..orm import exc as orm_exc
from ..orm.query import Query
from ..orm.session import Session
from ..sql import func
from ..sql import literal_column
from ..sql import util as sql_util
log = logging.getLogger(__name__)
class Bakery:
"""Callable which returns a :class:`.BakedQuery`.
This object is returned by the class method
:meth:`.BakedQuery.bakery`. It exists as an object
so that the "cache" can be easily inspected.
.. versionadded:: 1.2
"""
__slots__ = "cls", "cache"
def __init__(self, cls_, cache):
self.cls = cls_
self.cache = cache
def __call__(self, initial_fn, *args):
return self.cls(self.cache, initial_fn, args)
class BakedQuery:
"""A builder object for :class:`.query.Query` objects."""
__slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
def __init__(self, bakery, initial_fn, args=()):
self._cache_key = ()
self._update_cache_key(initial_fn, args)
self.steps = [initial_fn]
self._spoiled = False
self._bakery = bakery
@classmethod
def bakery(cls, size=200, _size_alert=None):
"""Construct a new bakery.
:return: an instance of :class:`.Bakery`
"""
return Bakery(cls, util.LRUCache(size, size_alert=_size_alert))
def _clone(self):
b1 = BakedQuery.__new__(BakedQuery)
b1._cache_key = self._cache_key
b1.steps = list(self.steps)
b1._bakery = self._bakery
b1._spoiled = self._spoiled
return b1
def _update_cache_key(self, fn, args=()):
self._cache_key += (fn.__code__,) + args
def __iadd__(self, other):
if isinstance(other, tuple):
self.add_criteria(*other)
else:
self.add_criteria(other)
return self
def __add__(self, other):
if isinstance(other, tuple):
return self.with_criteria(*other)
else:
return self.with_criteria(other)
def add_criteria(self, fn, *args):
"""Add a criteria function to this :class:`.BakedQuery`.
This is equivalent to using the ``+=`` operator to
modify a :class:`.BakedQuery` in-place.
"""
self._update_cache_key(fn, args)
self.steps.append(fn)
return self
def with_criteria(self, fn, *args):
"""Add a criteria function to a :class:`.BakedQuery` cloned from this
one.
This is equivalent to using the ``+`` operator to
produce a new :class:`.BakedQuery` with modifications.
"""
return self._clone().add_criteria(fn, *args)
def for_session(self, session):
"""Return a :class:`_baked.Result` object for this
:class:`.BakedQuery`.
This is equivalent to calling the :class:`.BakedQuery` as a
Python callable, e.g. ``result = my_baked_query(session)``.
"""
return Result(self, session)
def __call__(self, session):
return self.for_session(session)
def spoil(self, full=False):
"""Cancel any query caching that will occur on this BakedQuery object.
The BakedQuery can continue to be used normally, however additional
creational functions will not be cached; they will be called
on every invocation.
This is to support the case where a particular step in constructing
a baked query disqualifies the query from being cacheable, such
as a variant that relies upon some uncacheable value.
:param full: if False, only functions added to this
:class:`.BakedQuery` object subsequent to the spoil step will be
non-cached; the state of the :class:`.BakedQuery` up until
this point will be pulled from the cache. If True, then the
entire :class:`_query.Query` object is built from scratch each
time, with all creational functions being called on each
invocation.
"""
if not full and not self._spoiled:
_spoil_point = self._clone()
_spoil_point._cache_key += ("_query_only",)
self.steps = [_spoil_point._retrieve_baked_query]
self._spoiled = True
return self
def _effective_key(self, session):
"""Return the key that actually goes into the cache dictionary for
this :class:`.BakedQuery`, taking into account the given
:class:`.Session`.
This basically means we also will include the session's query_class,
as the actual :class:`_query.Query` object is part of what's cached
and needs to match the type of :class:`_query.Query` that a later
session will want to use.
"""
return self._cache_key + (session._query_cls,)
def _with_lazyload_options(self, options, effective_path, cache_path=None):
"""Cloning version of _add_lazyload_options."""
q = self._clone()
q._add_lazyload_options(options, effective_path, cache_path=cache_path)
return q
def _add_lazyload_options(self, options, effective_path, cache_path=None):
"""Used by per-state lazy loaders to add options to the
"lazy load" query from a parent query.
Creates a cache key based on given load path and query options;
if a repeatable cache key cannot be generated, the query is
"spoiled" so that it won't use caching.
"""
key = ()
if not cache_path:
cache_path = effective_path
for opt in options:
if opt._is_legacy_option or opt._is_compile_state:
ck = opt._generate_cache_key()
if ck is None:
self.spoil(full=True)
else:
assert not ck[1], (
"loader options with variable bound parameters "
"not supported with baked queries. Please "
"use new-style select() statements for cached "
"ORM queries."
)
key += ck[0]
self.add_criteria(
lambda q: q._with_current_path(effective_path).options(*options),
cache_path.path,
key,
)
def _retrieve_baked_query(self, session):
query = self._bakery.get(self._effective_key(session), None)
if query is None:
query = self._as_query(session)
self._bakery[self._effective_key(session)] = query.with_session(
None
)
return query.with_session(session)
def _bake(self, session):
query = self._as_query(session)
query.session = None
# in 1.4, this is where before_compile() event is
# invoked
statement = query._statement_20()
# if the query is not safe to cache, we still do everything as though
# we did cache it, since the receiver of _bake() assumes subqueryload
# context was set up, etc.
#
# note also we want to cache the statement itself because this
# allows the statement itself to hold onto its cache key that is
# used by the Connection, which in itself is more expensive to
# generate than what BakedQuery was able to provide in 1.3 and prior
if statement._compile_options._bake_ok:
self._bakery[self._effective_key(session)] = (
query,
statement,
)
return query, statement
def to_query(self, query_or_session):
"""Return the :class:`_query.Query` object for use as a subquery.
This method should be used within the lambda callable being used
to generate a step of an enclosing :class:`.BakedQuery`. The
parameter should normally be the :class:`_query.Query` object that
is passed to the lambda::
sub_bq = self.bakery(lambda s: s.query(User.name))
sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address)
main_bq = self.bakery(lambda s: s.query(Address))
main_bq += lambda q: q.filter(sub_bq.to_query(q).exists())
In the case where the subquery is used in the first callable against
a :class:`.Session`, the :class:`.Session` is also accepted::
sub_bq = self.bakery(lambda s: s.query(User.name))
sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address)
main_bq = self.bakery(
lambda s: s.query(Address.id, sub_bq.to_query(s).scalar_subquery())
)
:param query_or_session: a :class:`_query.Query` object or a
:class:`.Session` object, assumed to be within the context
of an enclosing :class:`.BakedQuery` callable.
.. versionadded:: 1.3
""" # noqa: E501
if isinstance(query_or_session, Session):
session = query_or_session
elif isinstance(query_or_session, Query):
session = query_or_session.session
if session is None:
raise sa_exc.ArgumentError(
"Given Query needs to be associated with a Session"
)
else:
raise TypeError(
"Query or Session object expected, got %r."
% type(query_or_session)
)
return self._as_query(session)
def _as_query(self, session):
query = self.steps[0](session)
for step in self.steps[1:]:
query = step(query)
return query
class Result:
"""Invokes a :class:`.BakedQuery` against a :class:`.Session`.
The :class:`_baked.Result` object is where the actual :class:`.query.Query`
object gets created, or retrieved from the cache,
against a target :class:`.Session`, and is then invoked for results.
"""
__slots__ = "bq", "session", "_params", "_post_criteria"
def __init__(self, bq, session):
self.bq = bq
self.session = session
self._params = {}
self._post_criteria = []
def params(self, *args, **kw):
"""Specify parameters to be replaced into the string SQL statement."""
if len(args) == 1:
kw.update(args[0])
elif len(args) > 0:
raise sa_exc.ArgumentError(
"params() takes zero or one positional argument, "
"which is a dictionary."
)
self._params.update(kw)
return self
def _using_post_criteria(self, fns):
if fns:
self._post_criteria.extend(fns)
return self
def with_post_criteria(self, fn):
"""Add a criteria function that will be applied post-cache.
This adds a function that will be run against the
:class:`_query.Query` object after it is retrieved from the
cache. This currently includes **only** the
:meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
methods.
.. warning:: :meth:`_baked.Result.with_post_criteria`
functions are applied
to the :class:`_query.Query`
object **after** the query's SQL statement
object has been retrieved from the cache. Only
:meth:`_query.Query.params` and
:meth:`_query.Query.execution_options`
methods should be used.
.. versionadded:: 1.2
"""
return self._using_post_criteria([fn])
def _as_query(self):
q = self.bq._as_query(self.session).params(self._params)
for fn in self._post_criteria:
q = fn(q)
return q
def __str__(self):
return str(self._as_query())
def __iter__(self):
return self._iter().__iter__()
def _iter(self):
bq = self.bq
if not self.session.enable_baked_queries or bq._spoiled:
return self._as_query()._iter()
query, statement = bq._bakery.get(
bq._effective_key(self.session), (None, None)
)
if query is None:
query, statement = bq._bake(self.session)
if self._params:
q = query.params(self._params)
else:
q = query
for fn in self._post_criteria:
q = fn(q)
params = q._params
execution_options = dict(q._execution_options)
execution_options.update(
{
"_sa_orm_load_options": q.load_options,
"compiled_cache": bq._bakery,
}
)
result = self.session.execute(
statement, params, execution_options=execution_options
)
if result._attributes.get("is_single_entity", False):
result = result.scalars()
if result._attributes.get("filtered", False):
result = result.unique()
return result
def count(self):
"""return the 'count'.
Equivalent to :meth:`_query.Query.count`.
Note this uses a subquery to ensure an accurate count regardless
of the structure of the original statement.
"""
col = func.count(literal_column("*"))
bq = self.bq.with_criteria(lambda q: q._legacy_from_self(col))
return bq.for_session(self.session).params(self._params).scalar()
def scalar(self):
"""Return the first element of the first result or None
if no rows present. If multiple rows are returned,
raises MultipleResultsFound.
Equivalent to :meth:`_query.Query.scalar`.
"""
try:
ret = self.one()
if not isinstance(ret, collections_abc.Sequence):
return ret
return ret[0]
except orm_exc.NoResultFound:
return None
def first(self):
"""Return the first row.
Equivalent to :meth:`_query.Query.first`.
"""
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
return (
bq.for_session(self.session)
.params(self._params)
._using_post_criteria(self._post_criteria)
._iter()
.first()
)
def one(self):
"""Return exactly one result or raise an exception.
Equivalent to :meth:`_query.Query.one`.
"""
return self._iter().one()
def one_or_none(self):
"""Return one or zero results, or raise an exception for multiple
rows.
Equivalent to :meth:`_query.Query.one_or_none`.
"""
return self._iter().one_or_none()
def all(self):
"""Return all rows.
Equivalent to :meth:`_query.Query.all`.
"""
return self._iter().all()
def get(self, ident):
"""Retrieve an object based on identity.
Equivalent to :meth:`_query.Query.get`.
"""
query = self.bq.steps[0](self.session)
return query._get_impl(ident, self._load_on_pk_identity)
def _load_on_pk_identity(self, session, query, primary_key_identity, **kw):
"""Load the given primary key identity from the database."""
mapper = query._raw_columns[0]._annotations["parententity"]
_get_clause, _get_params = mapper._get_clause
def setup(query):
_lcl_get_clause = _get_clause
q = query._clone()
q._get_condition()
q._order_by = None
# None present in ident - turn those comparisons
# into "IS NULL"
if None in primary_key_identity:
nones = {
_get_params[col].key
for col, value in zip(
mapper.primary_key, primary_key_identity
)
if value is None
}
_lcl_get_clause = sql_util.adapt_criterion_to_null(
_lcl_get_clause, nones
)
# TODO: can mapper._get_clause be pre-adapted?
q._where_criteria = (
sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}),
)
for fn in self._post_criteria:
q = fn(q)
return q
# cache the query against a key that includes
# which positions in the primary key are NULL
# (remember, we can map to an OUTER JOIN)
bq = self.bq
# add the clause we got from mapper._get_clause to the cache
# key so that if a race causes multiple calls to _get_clause,
# we've cached on ours
bq = bq._clone()
bq._cache_key += (_get_clause,)
bq = bq.with_criteria(
setup, tuple(elem is None for elem in primary_key_identity)
)
params = {
_get_params[primary_key].key: id_val
for id_val, primary_key in zip(
primary_key_identity, mapper.primary_key
)
}
result = list(bq.for_session(self.session).params(**params))
l = len(result)
if l > 1:
raise orm_exc.MultipleResultsFound()
elif l:
return result[0]
else:
return None
bakery = BakedQuery.bakery
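
As a hedged usage sketch of the bakery pattern defined above (the ``User`` mapped class and the ``Session`` instance passed in are hypothetical placeholders):

from sqlalchemy import bindparam
from sqlalchemy.ext import baked

bakery = baked.bakery()


def lookup_user_id(session, username):
    # the lambdas are keyed by their code objects, so the Query is
    # constructed (and its statement cached) only once per bakery entry
    bq = bakery(lambda s: s.query(User.id))  # User is hypothetical
    bq += lambda q: q.filter(User.name == bindparam("username"))
    return bq(session).params(username=username).scalar()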


@@ -0,0 +1,600 @@
# ext/compiler.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""Provides an API for creation of custom ClauseElements and compilers.
Synopsis
========
Usage involves the creation of one or more
:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
more callables defining its compilation::
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause
class MyColumn(ColumnClause):
inherit_cache = True
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
the base expression element for named column objects. The ``compiles``
decorator registers itself with the ``MyColumn`` class so that it is invoked
when the object is compiled to a string::
from sqlalchemy import select
s = select(MyColumn("x"), MyColumn("y"))
print(str(s))
Produces:
.. sourcecode:: sql
SELECT [x], [y]
Dialect-specific compilation rules
==================================
Compilers can also be made dialect-specific. The appropriate compiler will be
invoked for the dialect in use::
from sqlalchemy.schema import DDLElement
class AlterColumn(DDLElement):
inherit_cache = False
def __init__(self, column, cmd):
self.column = column
self.cmd = cmd
@compiles(AlterColumn)
def visit_alter_column(element, compiler, **kw):
return "ALTER COLUMN %s ..." % element.column.name
@compiles(AlterColumn, "postgresql")
def visit_alter_column(element, compiler, **kw):
return "ALTER TABLE %s ALTER COLUMN %s ..." % (
element.table.name,
element.column.name,
)
The second ``visit_alter_column`` will be invoked when any ``postgresql``
dialect is used.
.. _compilerext_compiling_subelements:
Compiling sub-elements of a custom expression construct
=======================================================
The ``compiler`` argument is the
:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
can be inspected for any information about the in-progress compilation,
including ``compiler.dialect``, ``compiler.statement`` etc. The
:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
method which can be used for compilation of embedded attributes::
from sqlalchemy.sql.expression import Executable, ClauseElement
class InsertFromSelect(Executable, ClauseElement):
inherit_cache = False
def __init__(self, table, select):
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True, **kw),
compiler.process(element.select, **kw),
)
insert = InsertFromSelect(t1, select(t1).where(t1.c.x > 5))
print(insert)
Produces (formatted for readability):
.. sourcecode:: sql
INSERT INTO mytable (
SELECT mytable.x, mytable.y, mytable.z
FROM mytable
WHERE mytable.x > :x_1
)
.. note::
The above ``InsertFromSelect`` construct is only an example, this actual
functionality is already available using the
:meth:`_expression.Insert.from_select` method.
Cross Compiling between SQL and DDL compilers
---------------------------------------------
SQL and DDL constructs are each compiled using different base compilers -
``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
compilation rules of SQL expressions from within a DDL expression. The
``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
below where we generate a CHECK constraint that embeds a SQL expression::
@compiles(MyConstraint)
def compile_my_constraint(constraint, ddlcompiler, **kw):
kw["literal_binds"] = True
return "CONSTRAINT %s CHECK (%s)" % (
constraint.name,
ddlcompiler.sql_compiler.process(constraint.expression, **kw),
)
Above, we add an additional flag to the process step as called by
:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
indicates that any SQL expression which refers to a :class:`.BindParameter`
object or other "literal" object such as those which refer to strings or
integers should be rendered **in-place**, rather than being referred to as
a bound parameter; when emitting DDL, bound parameters are typically not
supported.
Changing the default compilation of existing constructs
=======================================================
The compiler extension applies just as well to the existing constructs. When
overriding the compilation of a built in SQL construct, the @compiles
decorator is invoked upon the appropriate class (be sure to use the class,
i.e. ``Insert`` or ``Select``, instead of the creation function such
as ``insert()`` or ``select()``).
Within the new compilation function, to get at the "original" compilation
routine, use the appropriate visit_XXX method - this is
because compiler.process() will call upon the overriding routine and cause
an endless loop. For example, to add "prefix" to all insert statements::
from sqlalchemy.sql.expression import Insert
@compiles(Insert)
def prefix_inserts(insert, compiler, **kw):
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
The above compiler will prefix all INSERT statements with "some prefix" when
compiled.
.. _type_compilation_extension:
Changing Compilation of Types
=============================
``compiler`` works for types, too, such as below where we implement the
MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
@compiles(String, "mssql")
@compiles(VARCHAR, "mssql")
def compile_varchar(element, compiler, **kw):
if element.length == "max":
return "VARCHAR('max')"
else:
return compiler.visit_VARCHAR(element, **kw)
foo = Table("foo", metadata, Column("data", VARCHAR("max")))
Subclassing Guidelines
======================
A big part of using the compiler extension is subclassing SQLAlchemy
expression constructs. To make this easier, the expression and
schema packages feature a set of "bases" intended for common tasks.
A synopsis is as follows:
* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
expression class. Any SQL expression can be derived from this base, and is
probably the best choice for longer constructs such as specialized INSERT
statements.
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
"column-like" elements. Anything that you'd place in the "columns" clause of
a SELECT statement (as well as order by and group by) can derive from this -
the object will automatically have Python "comparison" behavior.
:class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
``type`` member which is the expression's return type. This can be established
at the instance level in the constructor, or at the class level if it's
generally constant::
class timestamp(ColumnElement):
type = TIMESTAMP()
inherit_cache = True
* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
``ColumnElement`` and a "from clause" like object, and represents a SQL
function or stored procedure type of call. Since most databases support
statements along the line of "SELECT FROM <some function>"
``FunctionElement`` adds in the ability to be used in the FROM clause of a
``select()`` construct::
from sqlalchemy.sql.expression import FunctionElement
class coalesce(FunctionElement):
name = "coalesce"
inherit_cache = True
@compiles(coalesce)
def compile(element, compiler, **kw):
return "coalesce(%s)" % compiler.process(element.clauses, **kw)
@compiles(coalesce, "oracle")
def compile(element, compiler, **kw):
if len(element.clauses) > 2:
raise TypeError(
"coalesce only supports two arguments on " "Oracle Database"
)
return "nvl(%s)" % compiler.process(element.clauses, **kw)
* :class:`.ExecutableDDLElement` - The root of all DDL expressions,
like CREATE TABLE, ALTER TABLE, etc. Compilation of
:class:`.ExecutableDDLElement` subclasses is issued by a
:class:`.DDLCompiler` instead of a :class:`.SQLCompiler`.
:class:`.ExecutableDDLElement` can also be used as an event hook in
conjunction with event hooks like :meth:`.DDLEvents.before_create` and
:meth:`.DDLEvents.after_create`, allowing the construct to be invoked
automatically during CREATE TABLE and DROP TABLE sequences.
.. seealso::
:ref:`metadata_ddl_toplevel` - contains examples of associating
:class:`.DDL` objects (which are themselves :class:`.ExecutableDDLElement`
instances) with :class:`.DDLEvents` event hooks.
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
should be used with any expression class that represents a "standalone"
SQL statement that can be passed directly to an ``execute()`` method. It
is already implicit within ``DDLElement`` and ``FunctionElement``.
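As referenced in the :class:`~sqlalchemy.sql.expression.ColumnElement` bullet
above, a small compilation handler can be paired with the ``timestamp`` class;
the following is only an illustrative sketch, rendering the element as the
generic ``CURRENT_TIMESTAMP`` keyword::

    from sqlalchemy import select
    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.sql.expression import ColumnElement
    from sqlalchemy.types import TIMESTAMP


    class timestamp(ColumnElement):
        type = TIMESTAMP()
        inherit_cache = True


    @compiles(timestamp)
    def visit_timestamp(element, compiler, **kw):
        # render the construct using the standard SQL keyword
        return "CURRENT_TIMESTAMP"


    # renders roughly as: SELECT CURRENT_TIMESTAMP AS anon_1
    print(select(timestamp()))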
Most of the above constructs also respond to SQL statement caching. A
subclassed construct will want to define the caching behavior for the object,
which usually means setting the flag ``inherit_cache`` to the value of
``False`` or ``True``. See the next section :ref:`compilerext_caching`
for background.
.. _compilerext_caching:
Enabling Caching Support for Custom Constructs
==============================================
SQLAlchemy as of version 1.4 includes a
:ref:`SQL compilation caching facility <sql_caching>` which will allow
equivalent SQL constructs to cache their stringified form, along with other
structural information used to fetch results from the statement.
For reasons discussed at :ref:`caching_caveats`, the implementation of this
caching system takes a conservative approach towards including custom SQL
constructs and/or subclasses within the caching system. This means that
any user-defined SQL constructs, including all the examples for this
extension, will not participate in caching by default unless they positively
assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache`
attribute when set to ``True`` at the class level of a specific subclass
will indicate that instances of this class may be safely cached, using the
cache key generation scheme of the immediate superclass. This applies
for example to the "synopsis" example indicated previously::
class MyColumn(ColumnClause):
inherit_cache = True
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
Above, the ``MyColumn`` class does not include any new state that
affects its SQL compilation; the cache key of ``MyColumn`` instances will
make use of that of the ``ColumnClause`` superclass, meaning it will take
into account the class of the object (``MyColumn``), the string name and
datatype of the object::
>>> MyColumn("some_name", String())._generate_cache_key()
CacheKey(
key=('0', <class '__main__.MyColumn'>,
'name', 'some_name',
'type', (<class 'sqlalchemy.sql.sqltypes.String'>,
('length', None), ('collation', None))
), bindparams=[])
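Because the cache key accounts for the column name and datatype, two
equivalent statements built from ``MyColumn`` produce matching keys, allowing
the second to reuse the cached compilation. A brief sketch, assuming the
``MyColumn`` class above (``_generate_cache_key()`` is the same internal
accessor used in the preceding example)::

    from sqlalchemy import String, select

    k1 = select(MyColumn("some_name", String()))._generate_cache_key()
    k2 = select(MyColumn("some_name", String()))._generate_cache_key()
    assert k1 == k2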
For objects that are likely to be **used liberally as components within many
larger statements**, such as :class:`_schema.Column` subclasses and custom SQL
datatypes, it's important that **caching be enabled as much as possible**, as
this may otherwise negatively affect performance.
An example of an object that **does** contain state which affects its SQL
compilation is the one illustrated at :ref:`compilerext_compiling_subelements`;
this is an "INSERT FROM SELECT" construct that combines together a
:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of
which independently affect the SQL string generation of the construct. For
this class, the example illustrates that it simply does not participate in
caching::
class InsertFromSelect(Executable, ClauseElement):
inherit_cache = False
def __init__(self, table, select):
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True, **kw),
compiler.process(element.select, **kw),
)
While it is also possible that the above ``InsertFromSelect`` could be made to
produce a cache key that is composed of that of the :class:`_schema.Table` and
:class:`_sql.Select` components together, the API for this is not at the moment
fully public. However, for an "INSERT FROM SELECT" construct, which is only
used by itself for specific operations, caching is not as critical as in the
previous example.
For objects that are **used in relative isolation and are generally
standalone**, such as custom :term:`DML` constructs like an "INSERT FROM
SELECT", **caching is generally less critical** as the lack of caching for such
a construct will have only localized implications for that specific operation.
Further Examples
================
"UTC timestamp" function
-------------------------
A function that works like "CURRENT_TIMESTAMP" except that it applies the
appropriate conversions so that the time is in UTC time. Timestamps are best
stored in relational databases as UTC, without time zones. UTC so that your
database doesn't think time has gone backwards in the hour when daylight
savings ends, without timezones because timezones are like character
encodings - they're best applied only at the endpoints of an application
(i.e. convert to UTC upon user input, re-apply desired timezone upon display).
For PostgreSQL and Microsoft SQL Server::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import DateTime
class utcnow(expression.FunctionElement):
type = DateTime()
inherit_cache = True
@compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw):
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
@compiles(utcnow, "mssql")
def ms_utcnow(element, compiler, **kw):
return "GETUTCDATE()"
Example usage::
from sqlalchemy import Table, Column, Integer, String, DateTime, MetaData
metadata = MetaData()
event = Table(
"event",
metadata,
Column("id", Integer, primary_key=True),
Column("description", String(50), nullable=False),
Column("timestamp", DateTime, server_default=utcnow()),
)
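The construct can equally be used inline within queries; a brief sketch,
assuming the ``event`` table above::

    from sqlalchemy import select

    # compares against the server-side UTC timestamp, using whichever
    # dialect-specific rendering was registered above
    stmt = select(event).where(event.c.timestamp < utcnow())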
"GREATEST" function
-------------------
The "GREATEST" function is given any number of arguments and returns the one
that is of the highest value - it's equivalent to Python's ``max``
function. A SQL standard version is shown alongside a CASE-based version
which only accommodates two arguments::
from sqlalchemy.sql import expression, case
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import Numeric
class greatest(expression.FunctionElement):
type = Numeric()
name = "greatest"
inherit_cache = True
@compiles(greatest)
def default_greatest(element, compiler, **kw):
return compiler.visit_function(element)
@compiles(greatest, "sqlite")
@compiles(greatest, "mssql")
@compiles(greatest, "oracle")
def case_greatest(element, compiler, **kw):
arg1, arg2 = list(element.clauses)
return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw)
Example usage::
Session.query(Account).filter(
greatest(Account.checking_balance, Account.savings_balance) > 10000
)
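The same filter in 2.0 style, as a sketch assuming an ``Account`` mapped class
with the two balance columns::

    from sqlalchemy import select

    stmt = select(Account).where(
        greatest(Account.checking_balance, Account.savings_balance) > 10000
    )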
"false" expression
------------------
Render a "false" constant expression, rendering as "0" on platforms that
don't have a "false" constant::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
class sql_false(expression.ColumnElement):
inherit_cache = True
@compiles(sql_false)
def default_false(element, compiler, **kw):
return "false"
@compiles(sql_false, "mssql")
@compiles(sql_false, "mysql")
@compiles(sql_false, "oracle")
def int_false(element, compiler, **kw):
return "0"
Example usage::
from sqlalchemy import select, union_all
exp = union_all(
select(users.c.name, sql_false().label("enrolled")),
select(customers.c.name, customers.c.enrolled),
)
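The construct may also be used wherever a boolean-valued SQL expression is
accepted, for example as a WHERE criterion; a sketch using the hypothetical
``users`` table from above::

    from sqlalchemy import select

    # renders WHERE false, or WHERE 0 on the dialects overridden above
    stmt = select(users.c.name).where(sql_false())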
"""
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Dict
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from .. import exc
from ..sql import sqltypes
if TYPE_CHECKING:
from ..sql.compiler import SQLCompiler
_F = TypeVar("_F", bound=Callable[..., Any])
def compiles(class_: Type[Any], *specs: str) -> Callable[[_F], _F]:
"""Register a function as a compiler for a
given :class:`_expression.ClauseElement` type."""
def decorate(fn: _F) -> _F:
# get an existing @compiles handler
existing = class_.__dict__.get("_compiler_dispatcher", None)
# get the original handler. All ClauseElement classes have one
# of these, but some TypeEngine classes will not.
existing_dispatch = getattr(class_, "_compiler_dispatch", None)
if not existing:
existing = _dispatcher()
if existing_dispatch:
def _wrap_existing_dispatch(
element: Any, compiler: SQLCompiler, **kw: Any
) -> Any:
try:
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError as uce:
raise exc.UnsupportedCompilationError(
compiler,
type(element),
message="%s construct has no default "
"compilation handler." % type(element),
) from uce
existing.specs["default"] = _wrap_existing_dispatch
# TODO: why is the lambda needed ?
setattr(
class_,
"_compiler_dispatch",
lambda *arg, **kw: existing(*arg, **kw),
)
setattr(class_, "_compiler_dispatcher", existing)
if specs:
for s in specs:
existing.specs[s] = fn
else:
existing.specs["default"] = fn
return fn
return decorate
def deregister(class_: Type[Any]) -> None:
"""Remove all custom compilers associated with a given
:class:`_expression.ClauseElement` type.
"""
if hasattr(class_, "_compiler_dispatcher"):
class_._compiler_dispatch = class_._original_compiler_dispatch
del class_._compiler_dispatcher
class _dispatcher:
def __init__(self) -> None:
self.specs: Dict[str, Callable[..., Any]] = {}
def __call__(self, element: Any, compiler: SQLCompiler, **kw: Any) -> Any:
# TODO: yes, this could also switch off of DBAPI in use.
fn = self.specs.get(compiler.dialect.name, None)
if not fn:
try:
fn = self.specs["default"]
except KeyError as ke:
raise exc.UnsupportedCompilationError(
compiler,
type(element),
message="%s construct has no default "
"compilation handler." % type(element),
) from ke
# if compilation includes add_to_result_map, collect add_to_result_map
# arguments from the user-defined callable, which are probably none
# because this is not public API. if it wasn't called, then call it
# ourselves.
arm = kw.get("add_to_result_map", None)
if arm:
arm_collection = []
kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
expr = fn(element, compiler, **kw)
if arm:
if not arm_collection:
arm_collection.append(
(None, None, (element,), sqltypes.NULLTYPE)
)
for tup in arm_collection:
arm(*tup)
return expr

View File

@ -0,0 +1,65 @@
# ext/declarative/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .extensions import AbstractConcreteBase
from .extensions import ConcreteBase
from .extensions import DeferredReflection
from ... import util
from ...orm.decl_api import as_declarative as _as_declarative
from ...orm.decl_api import declarative_base as _declarative_base
from ...orm.decl_api import DeclarativeMeta
from ...orm.decl_api import declared_attr
from ...orm.decl_api import has_inherited_table as _has_inherited_table
from ...orm.decl_api import synonym_for as _synonym_for
@util.moved_20(
"The ``declarative_base()`` function is now available as "
":func:`sqlalchemy.orm.declarative_base`."
)
def declarative_base(*arg, **kw):
return _declarative_base(*arg, **kw)
@util.moved_20(
"The ``as_declarative()`` function is now available as "
":func:`sqlalchemy.orm.as_declarative`"
)
def as_declarative(*arg, **kw):
return _as_declarative(*arg, **kw)
@util.moved_20(
"The ``has_inherited_table()`` function is now available as "
":func:`sqlalchemy.orm.has_inherited_table`."
)
def has_inherited_table(*arg, **kw):
return _has_inherited_table(*arg, **kw)
@util.moved_20(
"The ``synonym_for()`` function is now available as "
":func:`sqlalchemy.orm.synonym_for`"
)
def synonym_for(*arg, **kw):
return _synonym_for(*arg, **kw)
__all__ = [
"declarative_base",
"synonym_for",
"has_inherited_table",
"instrument_declarative",
"declared_attr",
"as_declarative",
"ConcreteBase",
"AbstractConcreteBase",
"DeclarativeMeta",
"DeferredReflection",
]

View File

@ -0,0 +1,564 @@
# ext/declarative/extensions.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Public API functions and helpers for declarative."""
from __future__ import annotations
import collections
import contextlib
from typing import Any
from typing import Callable
from typing import TYPE_CHECKING
from typing import Union
from ... import exc as sa_exc
from ...engine import Connection
from ...engine import Engine
from ...orm import exc as orm_exc
from ...orm import relationships
from ...orm.base import _mapper_or_none
from ...orm.clsregistry import _resolver
from ...orm.decl_base import _DeferredMapperConfig
from ...orm.util import polymorphic_union
from ...schema import Table
from ...util import OrderedDict
if TYPE_CHECKING:
from ...sql.schema import MetaData
class ConcreteBase:
"""A helper class for 'concrete' declarative mappings.
:class:`.ConcreteBase` will use the :func:`.polymorphic_union`
function automatically, against all tables mapped as a subclass
to this class. The function is called via the
``__declare_last__()`` function, which is essentially
a hook for the :meth:`.after_configured` event.
:class:`.ConcreteBase` produces a mapped
table for the class itself. Compare to :class:`.AbstractConcreteBase`,
which does not.
Example::
from sqlalchemy.ext.declarative import ConcreteBase
class Employee(ConcreteBase, Base):
__tablename__ = "employee"
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
__mapper_args__ = {
"polymorphic_identity": "employee",
"concrete": True,
}
class Manager(Employee):
__tablename__ = "manager"
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
"polymorphic_identity": "manager",
"concrete": True,
}
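Once mappers are configured, a query against the base class will use the
generated ``pjoin`` union to load rows from all concrete tables at once;
a brief sketch, assuming a :class:`_orm.Session` named ``session``::

    from sqlalchemy import select

    # returns Employee and Manager instances from their respective tables
    employees = session.scalars(select(Employee)).all()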
The name of the discriminator column used by :func:`.polymorphic_union`
defaults to the name ``type``. To suit the use case of a mapping where an
actual column in a mapped table is already named ``type``, the
discriminator name can be configured by setting the
``_concrete_discriminator_name`` attribute::
class Employee(ConcreteBase, Base):
_concrete_discriminator_name = "_concrete_discriminator"
.. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
attribute to :class:`_declarative.ConcreteBase` so that the
virtual discriminator column name can be customized.
.. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute
need only be placed on the basemost class to take correct effect for
all subclasses. An explicit error message is now raised if the
mapped column names conflict with the discriminator name, whereas
in the 1.3.x series there would be some warnings and then a non-useful
query would be generated.
.. seealso::
:class:`.AbstractConcreteBase`
:ref:`concrete_inheritance`
"""
@classmethod
def _create_polymorphic_union(cls, mappers, discriminator_name):
return polymorphic_union(
OrderedDict(
(mp.polymorphic_identity, mp.local_table) for mp in mappers
),
discriminator_name,
"pjoin",
)
@classmethod
def __declare_first__(cls):
m = cls.__mapper__
if m.with_polymorphic:
return
discriminator_name = (
getattr(cls, "_concrete_discriminator_name", None) or "type"
)
mappers = list(m.self_and_descendants)
pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
m._set_with_polymorphic(("*", pjoin))
m._set_polymorphic_on(pjoin.c[discriminator_name])
class AbstractConcreteBase(ConcreteBase):
"""A helper class for 'concrete' declarative mappings.
:class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
function automatically, against all tables mapped as a subclass
to this class. The function is called via the
``__declare_first__()`` function, which is essentially
a hook for the :meth:`.before_configured` event.
:class:`.AbstractConcreteBase` applies :class:`_orm.Mapper` for its
immediately inheriting class, as would occur for any other
declarative mapped class. However, the :class:`_orm.Mapper` is not
mapped to any particular :class:`.Table` object. Instead, it's
mapped directly to the "polymorphic" selectable produced by
:func:`.polymorphic_union`, and performs no persistence operations on its
own. Compare to :class:`.ConcreteBase`, which maps its
immediately inheriting class to an actual
:class:`.Table` that stores rows directly.
.. note::
The :class:`.AbstractConcreteBase` delays the mapper creation of the
base class until all the subclasses have been defined,
as it needs to create a mapping against a selectable that will include
all subclass tables. In order to achieve this, it waits for the
**mapper configuration event** to occur, at which point it scans
through all the configured subclasses and sets up a mapping that will
query against all subclasses at once.
While this event is normally invoked automatically, in the case of
:class:`.AbstractConcreteBase`, it may be necessary to invoke it
explicitly after **all** subclass mappings are defined, if the first
operation is to be a query against this base class. To do so, once all
the desired classes have been configured, the
:meth:`_orm.registry.configure` method on the :class:`_orm.registry`
in use can be invoked, which is available in relation to a particular
declarative base class::
Base.registry.configure()
Example::
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.declarative import AbstractConcreteBase
class Base(DeclarativeBase):
pass
class Employee(AbstractConcreteBase, Base):
pass
class Manager(Employee):
__tablename__ = "manager"
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
"polymorphic_identity": "manager",
"concrete": True,
}
Base.registry.configure()
The abstract base class is handled by declarative in a special way;
at class configuration time, it behaves like a declarative mixin
or an ``__abstract__`` base class. Once classes are configured
and mappings are produced, it then gets mapped itself, but
after all of its descendants. This is a unique system of mapping
not found in any other SQLAlchemy API feature.
Using this approach, we can specify columns and properties
that will take place on mapped subclasses, in the way that
we normally do as in :ref:`declarative_mixins`::
from sqlalchemy.ext.declarative import AbstractConcreteBase
class Company(Base):
__tablename__ = "company"
id = Column(Integer, primary_key=True)
class Employee(AbstractConcreteBase, Base):
strict_attrs = True
employee_id = Column(Integer, primary_key=True)
@declared_attr
def company_id(cls):
return Column(ForeignKey("company.id"))
@declared_attr
def company(cls):
return relationship("Company")
class Manager(Employee):
__tablename__ = "manager"
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
"polymorphic_identity": "manager",
"concrete": True,
}
Base.registry.configure()
When we make use of our mappings however, both ``Manager`` and
``Employee`` will have an independently usable ``.company`` attribute::
session.execute(select(Employee).filter(Employee.company.has(id=5)))
:param strict_attrs: when specified on the base class, "strict" attribute
mode is enabled which attempts to limit ORM mapped attributes on the
base class to only those that are immediately present, while still
preserving "polymorphic" loading behavior.
.. versionadded:: 2.0
.. seealso::
:class:`.ConcreteBase`
:ref:`concrete_inheritance`
:ref:`abstract_concrete_base`
"""
__no_table__ = True
@classmethod
def __declare_first__(cls):
cls._sa_decl_prepare_nocascade()
@classmethod
def _sa_decl_prepare_nocascade(cls):
if getattr(cls, "__mapper__", None):
return
to_map = _DeferredMapperConfig.config_for_cls(cls)
# can't rely on 'self_and_descendants' here
# since technically an immediate subclass
# might not be mapped, but a subclass
# may be.
mappers = []
stack = list(cls.__subclasses__())
while stack:
klass = stack.pop()
stack.extend(klass.__subclasses__())
mn = _mapper_or_none(klass)
if mn is not None:
mappers.append(mn)
discriminator_name = (
getattr(cls, "_concrete_discriminator_name", None) or "type"
)
pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
# For columns that were declared on the class, these
# are normally ignored with the "__no_table__" mapping,
# unless they have a different attribute key vs. col name
# and are in the properties argument.
# In that case, ensure we update the properties entry
# to the correct column from the pjoin target table.
declared_cols = set(to_map.declared_columns)
declared_col_keys = {c.key for c in declared_cols}
for k, v in list(to_map.properties.items()):
if v in declared_cols:
to_map.properties[k] = pjoin.c[v.key]
declared_col_keys.remove(v.key)
to_map.local_table = pjoin
strict_attrs = cls.__dict__.get("strict_attrs", False)
m_args = to_map.mapper_args_fn or dict
def mapper_args():
args = m_args()
args["polymorphic_on"] = pjoin.c[discriminator_name]
args["polymorphic_abstract"] = True
if strict_attrs:
args["include_properties"] = (
set(pjoin.primary_key)
| declared_col_keys
| {discriminator_name}
)
args["with_polymorphic"] = ("*", pjoin)
return args
to_map.mapper_args_fn = mapper_args
to_map.map()
stack = [cls]
while stack:
scls = stack.pop(0)
stack.extend(scls.__subclasses__())
sm = _mapper_or_none(scls)
if sm and sm.concrete and sm.inherits is None:
for sup_ in scls.__mro__[1:]:
sup_sm = _mapper_or_none(sup_)
if sup_sm:
sm._set_concrete_base(sup_sm)
break
@classmethod
def _sa_raise_deferred_config(cls):
raise orm_exc.UnmappedClassError(
cls,
msg="Class %s is a subclass of AbstractConcreteBase and "
"has a mapping pending until all subclasses are defined. "
"Call the sqlalchemy.orm.configure_mappers() function after "
"all subclasses have been defined to "
"complete the mapping of this class."
% orm_exc._safe_cls_name(cls),
)
class DeferredReflection:
"""A helper class for construction of mappings based on
a deferred reflection step.
Normally, declarative can be used with reflection by
setting a :class:`_schema.Table` object using autoload_with=engine
as the ``__table__`` attribute on a declarative class.
The caveat is that the :class:`_schema.Table` must be fully
reflected, or at the very least have a primary key column,
at the point at which a normal declarative mapping is
constructed, meaning the :class:`_engine.Engine` must be available
at class declaration time.
The :class:`.DeferredReflection` mixin moves the construction
of mappers to be at a later point, after a specific
method is called which first reflects all :class:`_schema.Table`
objects created so far. Classes can define it as such::
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeferredReflection
Base = declarative_base()
class MyClass(DeferredReflection, Base):
__tablename__ = "mytable"
Above, ``MyClass`` is not yet mapped. After a series of
classes have been defined in the above fashion, all tables
can be reflected and mappings created using
:meth:`.prepare`::
engine = create_engine("someengine://...")
DeferredReflection.prepare(engine)
The :class:`.DeferredReflection` mixin can be applied to individual
classes, used as the base for the declarative base itself,
or used in a custom abstract class. Using an abstract base
allows only a subset of classes to be prepared for a
particular prepare step, which is necessary for applications
that use more than one engine. For example, if an application
has two engines, you might use two bases, and prepare each
separately, e.g.::
class ReflectedOne(DeferredReflection, Base):
__abstract__ = True
class ReflectedTwo(DeferredReflection, Base):
__abstract__ = True
class MyClass(ReflectedOne):
__tablename__ = "mytable"
class MyOtherClass(ReflectedOne):
__tablename__ = "myothertable"
class YetAnotherClass(ReflectedTwo):
__tablename__ = "yetanothertable"
# ... etc.
Above, the class hierarchies for ``ReflectedOne`` and
``ReflectedTwo`` can be configured separately::
ReflectedOne.prepare(engine_one)
ReflectedTwo.prepare(engine_two)
.. seealso::
:ref:`orm_declarative_reflected_deferred_reflection` - in the
:ref:`orm_declarative_table_config_toplevel` section.
"""
@classmethod
def prepare(
cls, bind: Union[Engine, Connection], **reflect_kw: Any
) -> None:
r"""Reflect all :class:`_schema.Table` objects for all current
:class:`.DeferredReflection` subclasses
:param bind: :class:`_engine.Engine` or :class:`_engine.Connection`
instance
.. versionchanged:: 2.0.16 a :class:`_engine.Connection` is also
accepted.
:param \**reflect_kw: additional keyword arguments passed to
:meth:`_schema.MetaData.reflect`, such as
:paramref:`_schema.MetaData.reflect.views`.
.. versionadded:: 2.0.16
"""
to_map = _DeferredMapperConfig.classes_for_base(cls)
metadata_to_table = collections.defaultdict(set)
# first collect the primary __table__ for each class into a
# collection of metadata/schemaname -> table names
for thingy in to_map:
if thingy.local_table is not None:
metadata_to_table[
(thingy.local_table.metadata, thingy.local_table.schema)
].add(thingy.local_table.name)
# then reflect all those tables into their metadatas
if isinstance(bind, Connection):
conn = bind
ctx = contextlib.nullcontext(enter_result=conn)
elif isinstance(bind, Engine):
ctx = bind.connect()
else:
raise sa_exc.ArgumentError(
f"Expected Engine or Connection, got {bind!r}"
)
with ctx as conn:
for (metadata, schema), table_names in metadata_to_table.items():
metadata.reflect(
conn,
only=table_names,
schema=schema,
extend_existing=True,
autoload_replace=False,
**reflect_kw,
)
metadata_to_table.clear()
# .map() each class, then go through relationships and look
# for secondary
for thingy in to_map:
thingy.map()
mapper = thingy.cls.__mapper__
metadata = mapper.class_.metadata
for rel in mapper._props.values():
if (
isinstance(rel, relationships.RelationshipProperty)
and rel._init_args.secondary._is_populated()
):
secondary_arg = rel._init_args.secondary
if isinstance(secondary_arg.argument, Table):
secondary_table = secondary_arg.argument
metadata_to_table[
(
secondary_table.metadata,
secondary_table.schema,
)
].add(secondary_table.name)
elif isinstance(secondary_arg.argument, str):
_, resolve_arg = _resolver(rel.parent.class_, rel)
resolver = resolve_arg(
secondary_arg.argument, True
)
metadata_to_table[
(metadata, thingy.local_table.schema)
].add(secondary_arg.argument)
resolver._resolvers += (
cls._sa_deferred_table_resolver(metadata),
)
secondary_arg.argument = resolver()
for (metadata, schema), table_names in metadata_to_table.items():
metadata.reflect(
conn,
only=table_names,
schema=schema,
extend_existing=True,
autoload_replace=False,
)
@classmethod
def _sa_deferred_table_resolver(
cls, metadata: MetaData
) -> Callable[[str], Table]:
def _resolve(key: str) -> Table:
# reflection has already occurred so this Table would have
# its contents already
return Table(key, metadata)
return _resolve
_sa_decl_prepare = True
@classmethod
def _sa_raise_deferred_config(cls):
raise orm_exc.UnmappedClassError(
cls,
msg="Class %s is a subclass of DeferredReflection. "
"Mappings are not produced until the .prepare() "
"method is called on the class hierarchy."
% orm_exc._safe_cls_name(cls),
)

View File

@ -0,0 +1,478 @@
# ext/horizontal_shard.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Horizontal sharding support.
Defines a rudimental 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.
For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.
.. deepalchemy:: The horizontal sharding extension is an advanced feature,
involving a complex statement -> database interaction as well as
use of semi-public APIs for non-trivial cases. Simpler approaches to
referring to multiple database "shards", most commonly using a distinct
:class:`_orm.Session` per "shard", should always be considered first
before using this more complex and less-production-tested system.
"""
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .. import event
from .. import exc
from .. import inspect
from .. import util
from ..orm import PassiveFlag
from ..orm._typing import OrmExecuteOptionsParameter
from ..orm.interfaces import ORMOption
from ..orm.mapper import Mapper
from ..orm.query import Query
from ..orm.session import _BindArguments
from ..orm.session import _PKIdentityArgument
from ..orm.session import Session
from ..util.typing import Protocol
from ..util.typing import Self
if TYPE_CHECKING:
from ..engine.base import Connection
from ..engine.base import Engine
from ..engine.base import OptionEngine
from ..engine.result import IteratorResult
from ..engine.result import Result
from ..orm import LoaderCallableStatus
from ..orm._typing import _O
from ..orm.bulk_persistence import BulkUDCompileState
from ..orm.context import QueryContext
from ..orm.session import _EntityBindKey
from ..orm.session import _SessionBind
from ..orm.session import ORMExecuteState
from ..orm.state import InstanceState
from ..sql import Executable
from ..sql._typing import _TP
from ..sql.elements import ClauseElement
__all__ = ["ShardedSession", "ShardedQuery"]
_T = TypeVar("_T", bound=Any)
ShardIdentifier = str
class ShardChooser(Protocol):
def __call__(
self,
mapper: Optional[Mapper[_T]],
instance: Any,
clause: Optional[ClauseElement],
) -> Any: ...
class IdentityChooser(Protocol):
def __call__(
self,
mapper: Mapper[_T],
primary_key: _PKIdentityArgument,
*,
lazy_loaded_from: Optional[InstanceState[Any]],
execution_options: OrmExecuteOptionsParameter,
bind_arguments: _BindArguments,
**kw: Any,
) -> Any: ...
class ShardedQuery(Query[_T]):
"""Query class used with :class:`.ShardedSession`.
.. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy
:class:`.Query` class. The :class:`.ShardedSession` now supports
2.0 style execution via the :meth:`.ShardedSession.execute` method.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
assert isinstance(self.session, ShardedSession)
self.identity_chooser = self.session.identity_chooser
self.execute_chooser = self.session.execute_chooser
self._shard_id = None
def set_shard(self, shard_id: ShardIdentifier) -> Self:
"""Return a new query, limited to a single shard ID.
All subsequent operations with the returned query will
be against the single shard regardless of other state.
The shard_id can be passed for a 2.0 style execution to the
bind_arguments dictionary of :meth:`.Session.execute`::
results = session.execute(stmt, bind_arguments={"shard_id": "my_shard"})
""" # noqa: E501
return self.execution_options(_sa_shard_id=shard_id)
class ShardedSession(Session):
shard_chooser: ShardChooser
identity_chooser: IdentityChooser
execute_chooser: Callable[[ORMExecuteState], Iterable[Any]]
def __init__(
self,
shard_chooser: ShardChooser,
identity_chooser: Optional[IdentityChooser] = None,
execute_chooser: Optional[
Callable[[ORMExecuteState], Iterable[Any]]
] = None,
shards: Optional[Dict[str, Any]] = None,
query_cls: Type[Query[_T]] = ShardedQuery,
*,
id_chooser: Optional[
Callable[[Query[_T], Iterable[_T]], Iterable[Any]]
] = None,
query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
**kwargs: Any,
) -> None:
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
instance, and possibly a SQL clause, returns a shard ID. This id
may be based off of the attributes present within the object, or on
some round-robin scheme. If the scheme is based on a selection, it
should set whatever state is needed on the instance to mark it in the future as
participating in that shard.
:param identity_chooser: A callable, passed a Mapper and primary key
argument, which should return a list of shard ids where this
primary key might reside.
.. versionchanged:: 2.0 The ``identity_chooser`` parameter
supersedes the ``id_chooser`` parameter.
:param execute_chooser: For a given :class:`.ORMExecuteState`,
returns the list of shard_ids
where the query should be issued. Results from all shards returned
will be combined together into a single listing.
.. versionchanged:: 1.4 The ``execute_chooser`` parameter
supersedes the ``query_chooser`` parameter.
:param shards: A dictionary of string shard names
to :class:`~sqlalchemy.engine.Engine` objects.
"""
super().__init__(query_cls=query_cls, **kwargs)
event.listen(
self, "do_orm_execute", execute_and_instances, retval=True
)
self.shard_chooser = shard_chooser
if id_chooser:
_id_chooser = id_chooser
util.warn_deprecated(
"The ``id_chooser`` parameter is deprecated; "
"please use ``identity_chooser``.",
"2.0",
)
def _legacy_identity_chooser(
mapper: Mapper[_T],
primary_key: _PKIdentityArgument,
*,
lazy_loaded_from: Optional[InstanceState[Any]],
execution_options: OrmExecuteOptionsParameter,
bind_arguments: _BindArguments,
**kw: Any,
) -> Any:
q = self.query(mapper)
if lazy_loaded_from:
q = q._set_lazyload_from(lazy_loaded_from)
return _id_chooser(q, primary_key)
self.identity_chooser = _legacy_identity_chooser
elif identity_chooser:
self.identity_chooser = identity_chooser
else:
raise exc.ArgumentError(
"identity_chooser or id_chooser is required"
)
if query_chooser:
_query_chooser = query_chooser
util.warn_deprecated(
"The ``query_chooser`` parameter is deprecated; "
"please use ``execute_chooser``.",
"1.4",
)
if execute_chooser:
raise exc.ArgumentError(
"Can't pass query_chooser and execute_chooser "
"at the same time."
)
def _default_execute_chooser(
orm_context: ORMExecuteState,
) -> Iterable[Any]:
return _query_chooser(orm_context.statement)
if execute_chooser is None:
execute_chooser = _default_execute_chooser
if execute_chooser is None:
raise exc.ArgumentError(
"execute_chooser or query_chooser is required"
)
self.execute_chooser = execute_chooser
self.__shards: Dict[ShardIdentifier, _SessionBind] = {}
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def _identity_lookup(
self,
mapper: Mapper[_O],
primary_key_identity: Union[Any, Tuple[Any, ...]],
identity_token: Optional[Any] = None,
passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
lazy_loaded_from: Optional[InstanceState[Any]] = None,
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Union[Optional[_O], LoaderCallableStatus]:
"""override the default :meth:`.Session._identity_lookup` method so
that we search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
.. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
the :class:`_query.Query` object to the :class:`.Session`.
"""
if identity_token is not None:
obj = super()._identity_lookup(
mapper,
primary_key_identity,
identity_token=identity_token,
**kw,
)
return obj
else:
for shard_id in self.identity_chooser(
mapper,
primary_key_identity,
lazy_loaded_from=lazy_loaded_from,
execution_options=execution_options,
bind_arguments=dict(bind_arguments) if bind_arguments else {},
):
obj2 = super()._identity_lookup(
mapper,
primary_key_identity,
identity_token=shard_id,
lazy_loaded_from=lazy_loaded_from,
**kw,
)
if obj2 is not None:
return obj2
return None
def _choose_shard_and_assign(
self,
mapper: Optional[_EntityBindKey[_O]],
instance: Any,
**kw: Any,
) -> Any:
if instance is not None:
state = inspect(instance)
if state.key:
token = state.key[2]
assert token is not None
return token
elif state.identity_token:
return state.identity_token
assert isinstance(mapper, Mapper)
shard_id = self.shard_chooser(mapper, instance, **kw)
if instance is not None:
state.identity_token = shard_id
return shard_id
def connection_callable(
self,
mapper: Optional[Mapper[_T]] = None,
instance: Optional[Any] = None,
shard_id: Optional[ShardIdentifier] = None,
**kw: Any,
) -> Connection:
"""Provide a :class:`_engine.Connection` to use in the unit of work
flush process.
"""
if shard_id is None:
shard_id = self._choose_shard_and_assign(mapper, instance)
if self.in_transaction():
trans = self.get_transaction()
assert trans is not None
return trans.connection(mapper, shard_id=shard_id)
else:
bind = self.get_bind(
mapper=mapper, shard_id=shard_id, instance=instance
)
if isinstance(bind, Engine):
return bind.connect(**kw)
else:
assert isinstance(bind, Connection)
return bind
def get_bind(
self,
mapper: Optional[_EntityBindKey[_O]] = None,
*,
shard_id: Optional[ShardIdentifier] = None,
instance: Optional[Any] = None,
clause: Optional[ClauseElement] = None,
**kw: Any,
) -> _SessionBind:
if shard_id is None:
shard_id = self._choose_shard_and_assign(
mapper, instance=instance, clause=clause
)
assert shard_id is not None
return self.__shards[shard_id]
def bind_shard(
self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine]
) -> None:
self.__shards[shard_id] = bind
class set_shard_id(ORMOption):
"""a loader option for statements to apply a specific shard id to the
primary query as well as for additional relationship and column
loaders.
The :class:`_horizontal.set_shard_id` option may be applied using
the :meth:`_sql.Executable.options` method of any executable statement::
stmt = (
select(MyObject)
.where(MyObject.name == "some name")
.options(set_shard_id("shard1"))
)
Above, the statement when invoked will limit to the "shard1" shard
identifier for the primary query as well as for all relationship and
column loading strategies, including eager loaders such as
:func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`,
and the lazy relationship loader :func:`_orm.lazyload`.
In this way, the :class:`_horizontal.set_shard_id` option has much wider
scope than using the "shard_id" argument within the
:paramref:`_orm.Session.execute.bind_arguments` dictionary.
.. versionadded:: 2.0.0
"""
__slots__ = ("shard_id", "propagate_to_loaders")
def __init__(
self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True
):
"""Construct a :class:`_horizontal.set_shard_id` option.
:param shard_id: shard identifier
:param propagate_to_loaders: if left at its default of ``True``, the
shard option will take place for lazy loaders such as
:func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option
will not be propagated to loaded objects. Note that :func:`_orm.defer`
always limits to the shard_id of the parent row in any case, so the
parameter only has a net effect on the behavior of the
:func:`_orm.lazyload` strategy.
"""
self.shard_id = shard_id
self.propagate_to_loaders = propagate_to_loaders
def execute_and_instances(
orm_context: ORMExecuteState,
) -> Union[Result[_T], IteratorResult[_TP]]:
active_options: Union[
None,
QueryContext.default_load_options,
Type[QueryContext.default_load_options],
BulkUDCompileState.default_update_options,
Type[BulkUDCompileState.default_update_options],
]
if orm_context.is_select:
active_options = orm_context.load_options
elif orm_context.is_update or orm_context.is_delete:
active_options = orm_context.update_delete_options
else:
active_options = None
session = orm_context.session
assert isinstance(session, ShardedSession)
def iter_for_shard(
shard_id: ShardIdentifier,
) -> Union[Result[_T], IteratorResult[_TP]]:
bind_arguments = dict(orm_context.bind_arguments)
bind_arguments["shard_id"] = shard_id
orm_context.update_execution_options(identity_token=shard_id)
return orm_context.invoke_statement(bind_arguments=bind_arguments)
for orm_opt in orm_context._non_compile_orm_options:
# TODO: if we had an ORMOption that gets applied at ORM statement
# execution time, that would allow this to be more generalized.
# for now just iterate and look for our options
if isinstance(orm_opt, set_shard_id):
shard_id = orm_opt.shard_id
break
else:
if active_options and active_options._identity_token is not None:
shard_id = active_options._identity_token
elif "_sa_shard_id" in orm_context.execution_options:
shard_id = orm_context.execution_options["_sa_shard_id"]
elif "shard_id" in orm_context.bind_arguments:
shard_id = orm_context.bind_arguments["shard_id"]
else:
shard_id = None
if shard_id is not None:
return iter_for_shard(shard_id)
else:
partial = []
for shard_id in session.execute_chooser(orm_context):
result_ = iter_for_shard(shard_id)
partial.append(result_)
return partial[0].merge(*partial[1:])

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,345 @@
# ext/indexable.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Define attributes on ORM-mapped classes that have "index" attributes for
columns with :class:`_types.Indexable` types.
"index" means the attribute is associated with an element of an
:class:`_types.Indexable` column with the predefined index to access it.
The :class:`_types.Indexable` types include types such as
:class:`_types.ARRAY`, :class:`_types.JSON` and
:class:`_postgresql.HSTORE`.
The :mod:`~sqlalchemy.ext.indexable` extension provides
a :class:`_schema.Column`-like interface for any element of an
:class:`_types.Indexable` typed column. In simple cases, it can be
treated as a :class:`_schema.Column` - mapped attribute.
Synopsis
========
Consider a ``Person`` model with a primary key and a JSON data field.
While this field may have any number of elements encoded within it,
we would like to refer to the element called ``name`` individually
as a dedicated attribute which behaves like a standalone column::
from sqlalchemy import Column, JSON, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.indexable import index_property
Base = declarative_base()
class Person(Base):
__tablename__ = "person"
id = Column(Integer, primary_key=True)
data = Column(JSON)
name = index_property("data", "name")
Above, the ``name`` attribute now behaves like a mapped column. We
can compose a new ``Person`` and set the value of ``name``::
>>> person = Person(name="Alchemist")
The value is now accessible::
>>> person.name
'Alchemist'
Behind the scenes, the JSON field was initialized to a new blank dictionary
and the field was set::
>>> person.data
{'name': 'Alchemist'}
The field is mutable in place::
>>> person.name = "Renamed"
>>> person.name
'Renamed'
>>> person.data
{'name': 'Renamed'}
When using :class:`.index_property`, the change that we make to the indexable
structure is also automatically tracked as history; we no longer need
to use :class:`~.mutable.MutableDict` in order to track this change
for the unit of work.
Deletions work normally as well::
>>> del person.name
>>> person.data
{}
Above, deletion of ``person.name`` deletes the value from the dictionary,
but not the dictionary itself.
A missing key will produce ``AttributeError``::
>>> person = Person()
>>> person.name
AttributeError: 'name'
Unless you set a default value::
>>> class Person(Base):
... __tablename__ = "person"
...
... id = Column(Integer, primary_key=True)
... data = Column(JSON)
...
... name = index_property("data", "name", default=None) # See default
>>> person = Person()
>>> print(person.name)
None
The attributes are also accessible at the class level.
Below, we illustrate ``Person.name`` used to generate
an indexed SQL criteria::
>>> from sqlalchemy.orm import Session
>>> session = Session()
>>> query = session.query(Person).filter(Person.name == "Alchemist")
The above query is equivalent to::
>>> query = session.query(Person).filter(Person.data["name"] == "Alchemist")
Multiple :class:`.index_property` objects can be chained to produce
multiple levels of indexing::
from sqlalchemy import Column, JSON, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.indexable import index_property
Base = declarative_base()
class Person(Base):
__tablename__ = "person"
id = Column(Integer, primary_key=True)
data = Column(JSON)
birthday = index_property("data", "birthday")
year = index_property("birthday", "year")
month = index_property("birthday", "month")
day = index_property("birthday", "day")
Above, a query such as::
q = session.query(Person).filter(Person.year == "1980")
On a PostgreSQL backend, the above query will render as:
.. sourcecode:: sql
SELECT person.id, person.data
FROM person
WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
Default Values
==============
:class:`.index_property` includes special behaviors for when the indexed
data structure does not exist, and a set operation is called:
* For an :class:`.index_property` that is given an integer index value,
the default data structure will be a Python list of ``None`` values,
at least as long as the index value; the value is then set at its
place in the list. This means for an index value of zero, the list
will be initialized to ``[None]`` before setting the given value,
and for an index value of five, the list will be initialized to
``[None, None, None, None, None]`` before setting the fifth element
to the given value. Note that an existing list is **not** extended
in place to receive a value.
* for an :class:`.index_property` that is given any other kind of index
value (e.g. strings usually), a Python dictionary is used as the
default data structure.
* The default data structure can be set to any Python callable using the
:paramref:`.index_property.datatype` parameter, overriding the previous
rules.
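To illustrate the integer-index rule above, a sketch assuming the declarative
``Base`` from the synopsis (the ``PersonList`` class and its table name are
hypothetical)::

    from sqlalchemy import JSON, Column, Integer
    from sqlalchemy.ext.indexable import index_property


    class PersonList(Base):
        __tablename__ = "person_list"

        id = Column(Integer, primary_key=True)
        data = Column(JSON)

        # integer index: on first assignment, "data" is initialized to
        # [None, None, None] and element 2 receives the value
        third = index_property("data", 2)


    p = PersonList()
    p.third = "some value"
    # p.data is now [None, None, 'some value']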
Subclassing
===========
:class:`.index_property` can be subclassed, in particular for the common
use case of providing coercion of values or SQL expressions as they are
accessed. Below is a common recipe for use with a PostgreSQL JSON type,
where we want to also include automatic casting plus ``astext()``::
class pg_json_property(index_property):
def __init__(self, attr_name, index, cast_type):
super(pg_json_property, self).__init__(attr_name, index)
self.cast_type = cast_type
def expr(self, model):
expr = super(pg_json_property, self).expr(model)
return expr.astext.cast(self.cast_type)
The above subclass can be used with the PostgreSQL-specific
version of :class:`_postgresql.JSON`::
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.postgresql import JSON
Base = declarative_base()
class Person(Base):
__tablename__ = "person"
id = Column(Integer, primary_key=True)
data = Column(JSON)
age = pg_json_property("data", "age", Integer)
The ``age`` attribute at the instance level works as before; however
when rendering SQL, PostgreSQL's ``->>`` operator will be used
for indexed access, instead of the usual index operator of ``->``::
>>> query = session.query(Person).filter(Person.age < 20)
The above query will render:
.. sourcecode:: sql
SELECT person.id, person.data
FROM person
WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
""" # noqa
from .. import inspect
from ..ext.hybrid import hybrid_property
from ..orm.attributes import flag_modified
__all__ = ["index_property"]
class index_property(hybrid_property): # noqa
"""A property generator. The generated property describes an object
attribute that corresponds to an :class:`_types.Indexable`
column.
.. seealso::
:mod:`sqlalchemy.ext.indexable`
"""
_NO_DEFAULT_ARGUMENT = object()
def __init__(
self,
attr_name,
index,
default=_NO_DEFAULT_ARGUMENT,
datatype=None,
mutable=True,
onebased=True,
):
"""Create a new :class:`.index_property`.
:param attr_name:
An attribute name of an `Indexable` typed column, or other
attribute that returns an indexable structure.
:param index:
The index to be used for getting and setting this value. This
should be the Python-side index value for integers.
:param default:
A value which will be returned instead of `AttributeError`
when there is no value at the given index.
:param datatype: default datatype to use when the field is empty.
By default, this is derived from the type of index used; a
Python list for an integer index, or a Python dictionary for
any other style of index. For a list, the list will be
initialized to a list of None values that is at least
``index`` elements long.
:param mutable: if False, writes and deletes to the attribute will
be disallowed.
:param onebased: assume the SQL representation of this value is
one-based; that is, the first index in SQL is 1, not zero.
"""
if mutable:
super().__init__(self.fget, self.fset, self.fdel, self.expr)
else:
super().__init__(self.fget, None, None, self.expr)
self.attr_name = attr_name
self.index = index
self.default = default
is_numeric = isinstance(index, int)
onebased = is_numeric and onebased
if datatype is not None:
self.datatype = datatype
else:
if is_numeric:
self.datatype = lambda: [None for x in range(index + 1)]
else:
self.datatype = dict
self.onebased = onebased
def _fget_default(self, err=None):
if self.default == self._NO_DEFAULT_ARGUMENT:
raise AttributeError(self.attr_name) from err
else:
return self.default
def fget(self, instance):
attr_name = self.attr_name
column_value = getattr(instance, attr_name)
if column_value is None:
return self._fget_default()
try:
value = column_value[self.index]
except (KeyError, IndexError) as err:
return self._fget_default(err)
else:
return value
def fset(self, instance, value):
attr_name = self.attr_name
column_value = getattr(instance, attr_name, None)
if column_value is None:
column_value = self.datatype()
setattr(instance, attr_name, column_value)
column_value[self.index] = value
setattr(instance, attr_name, column_value)
if attr_name in inspect(instance).mapper.attrs:
flag_modified(instance, attr_name)
def fdel(self, instance):
attr_name = self.attr_name
column_value = getattr(instance, attr_name)
if column_value is None:
raise AttributeError(self.attr_name)
try:
del column_value[self.index]
except KeyError as err:
raise AttributeError(self.attr_name) from err
else:
setattr(instance, attr_name, column_value)
flag_modified(instance, attr_name)
def expr(self, model):
column = getattr(model, self.attr_name)
index = self.index
if self.onebased:
index += 1
return column[index]

View File

@ -0,0 +1,450 @@
# ext/instrumentation.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Extensible class instrumentation.
The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
systems of class instrumentation within the ORM. Class instrumentation
refers to how the ORM places attributes on the class which maintain
data and track changes to that data, as well as event hooks installed
on the class.
.. note::
The extension package is provided for the benefit of integration
with other object management packages, which already perform
their own instrumentation. It is not intended for general use.
For examples of how the instrumentation extension is used,
see the example :ref:`examples_instrumentation`.
"""
import weakref
from .. import util
from ..orm import attributes
from ..orm import base as orm_base
from ..orm import collections
from ..orm import exc as orm_exc
from ..orm import instrumentation as orm_instrumentation
from ..orm import util as orm_util
from ..orm.instrumentation import _default_dict_getter
from ..orm.instrumentation import _default_manager_getter
from ..orm.instrumentation import _default_opt_manager_getter
from ..orm.instrumentation import _default_state_getter
from ..orm.instrumentation import ClassManager
from ..orm.instrumentation import InstrumentationFactory
INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
"""Attribute, elects custom instrumentation when present on a mapped class.
Allows a class to specify a slightly or wildly different technique for
tracking changes made to mapped attributes and collections.
Only one instrumentation implementation is allowed in a given object
inheritance hierarchy.
The value of this attribute must be a callable and will be passed a class
object. The callable must return one of:
- An instance of an :class:`.InstrumentationManager` or subclass
- An object implementing all or some of InstrumentationManager (TODO)
- A dictionary of callables, implementing all or some of the above (TODO)
- An instance of a :class:`.ClassManager` or subclass
This attribute is consulted by SQLAlchemy instrumentation
resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
has been imported. If custom finders are installed in the global
instrumentation_finders list, they may or may not choose to honor this
attribute.
"""
def find_native_user_instrumentation_hook(cls):
"""Find user-specified instrumentation management for a class."""
return getattr(cls, INSTRUMENTATION_MANAGER, None)
instrumentation_finders = [find_native_user_instrumentation_hook]
"""An extensible sequence of callables which return instrumentation
implementations
When a class is registered, each callable will be passed a class object.
If None is returned, the
next finder in the sequence is consulted. Otherwise the return must be an
instrumentation factory that follows the same guidelines as
sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.
By default, the only finder is find_native_user_instrumentation_hook, which
searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
ClassManager instrumentation is used.
"""
class ExtendedInstrumentationRegistry(InstrumentationFactory):
"""Extends :class:`.InstrumentationFactory` with additional
bookkeeping, to accommodate multiple types of
class managers.
"""
_manager_finders = weakref.WeakKeyDictionary()
_state_finders = weakref.WeakKeyDictionary()
_dict_finders = weakref.WeakKeyDictionary()
_extended = False
def _locate_extended_factory(self, class_):
for finder in instrumentation_finders:
factory = finder(class_)
if factory is not None:
manager = self._extended_class_manager(class_, factory)
return manager, factory
else:
return None, None
def _check_conflicts(self, class_, factory):
existing_factories = self._collect_management_factories_for(
class_
).difference([factory])
if existing_factories:
raise TypeError(
"multiple instrumentation implementations specified "
"in %s inheritance hierarchy: %r"
% (class_.__name__, list(existing_factories))
)
def _extended_class_manager(self, class_, factory):
manager = factory(class_)
if not isinstance(manager, ClassManager):
manager = _ClassInstrumentationAdapter(class_, manager)
if factory != ClassManager and not self._extended:
# somebody invoked a custom ClassManager.
# reinstall global "getter" functions with the more
# expensive ones.
self._extended = True
_install_instrumented_lookups()
self._manager_finders[class_] = manager.manager_getter()
self._state_finders[class_] = manager.state_getter()
self._dict_finders[class_] = manager.dict_getter()
return manager
def _collect_management_factories_for(self, cls):
"""Return a collection of factories in play or specified for a
hierarchy.
Traverses the entire inheritance graph of a cls and returns a
collection of instrumentation factories for those classes. Factories
are extracted from active ClassManagers, if available, otherwise
instrumentation_finders is consulted.
"""
hierarchy = util.class_hierarchy(cls)
factories = set()
for member in hierarchy:
manager = self.opt_manager_of_class(member)
if manager is not None:
factories.add(manager.factory)
else:
for finder in instrumentation_finders:
factory = finder(member)
if factory is not None:
break
else:
factory = None
factories.add(factory)
factories.discard(None)
return factories
def unregister(self, class_):
super().unregister(class_)
if class_ in self._manager_finders:
del self._manager_finders[class_]
del self._state_finders[class_]
del self._dict_finders[class_]
def opt_manager_of_class(self, cls):
try:
finder = self._manager_finders.get(
cls, _default_opt_manager_getter
)
except TypeError:
# due to weakref lookup on invalid object
return None
else:
return finder(cls)
def manager_of_class(self, cls):
try:
finder = self._manager_finders.get(cls, _default_manager_getter)
except TypeError:
# due to weakref lookup on invalid object
raise orm_exc.UnmappedClassError(
cls, f"Can't locate an instrumentation manager for class {cls}"
)
else:
manager = finder(cls)
if manager is None:
raise orm_exc.UnmappedClassError(
cls,
f"Can't locate an instrumentation manager for class {cls}",
)
return manager
def state_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._state_finders.get(
instance.__class__, _default_state_getter
)(instance)
def dict_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._dict_finders.get(
instance.__class__, _default_dict_getter
)(instance)
orm_instrumentation._instrumentation_factory = _instrumentation_factory = (
ExtendedInstrumentationRegistry()
)
orm_instrumentation.instrumentation_finders = instrumentation_finders
class InstrumentationManager:
"""User-defined class instrumentation extension.
:class:`.InstrumentationManager` can be subclassed in order
to change
how class instrumentation proceeds. This class exists for
the purposes of integration with other object management
frameworks which would like to entirely modify the
instrumentation methodology of the ORM, and is not intended
for regular usage. For interception of class instrumentation
events, see :class:`.InstrumentationEvents`.
The API for this class should be considered as semi-stable,
and may change slightly with new releases.
"""
# r4361 added a mandatory (cls) constructor to this interface.
# given that, perhaps class_ should be dropped from all of these
# signatures.
def __init__(self, class_):
pass
def manage(self, class_, manager):
setattr(class_, "_default_class_manager", manager)
def unregister(self, class_, manager):
delattr(class_, "_default_class_manager")
def manager_getter(self, class_):
def get(cls):
return cls._default_class_manager
return get
def instrument_attribute(self, class_, key, inst):
pass
def post_configure_attribute(self, class_, key, inst):
pass
def install_descriptor(self, class_, key, inst):
setattr(class_, key, inst)
def uninstall_descriptor(self, class_, key):
delattr(class_, key)
def install_member(self, class_, key, implementation):
setattr(class_, key, implementation)
def uninstall_member(self, class_, key):
delattr(class_, key)
def instrument_collection_class(self, class_, key, collection_class):
return collections.prepare_instrumentation(collection_class)
def get_instance_dict(self, class_, instance):
return instance.__dict__
def initialize_instance_dict(self, class_, instance):
pass
def install_state(self, class_, instance, state):
setattr(instance, "_default_state", state)
def remove_state(self, class_, instance):
delattr(instance, "_default_state")
def state_getter(self, class_):
return lambda instance: getattr(instance, "_default_state")
def dict_getter(self, class_):
return lambda inst: self.get_instance_dict(class_, inst)
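# A minimal sketch (not part of the original module) of opting a class into a
# custom manager through the attribute named by INSTRUMENTATION_MANAGER
# (``__sa_instrumentation_manager__``); ``MyClassManager``, ``MyEntity`` and
# the ``_my_state`` slot are hypothetical:
#
#     class MyClassManager(InstrumentationManager):
#         def install_state(self, class_, instance, state):
#             instance.__dict__["_my_state"] = state
#
#         def remove_state(self, class_, instance):
#             del instance.__dict__["_my_state"]
#
#         def state_getter(self, class_):
#             return lambda instance: instance.__dict__["_my_state"]
#
#     class MyEntity:
#         __sa_instrumentation_manager__ = MyClassManager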
class _ClassInstrumentationAdapter(ClassManager):
"""Adapts a user-defined InstrumentationManager to a ClassManager."""
def __init__(self, class_, override):
self._adapted = override
self._get_state = self._adapted.state_getter(class_)
self._get_dict = self._adapted.dict_getter(class_)
ClassManager.__init__(self, class_)
def manage(self):
self._adapted.manage(self.class_, self)
def unregister(self):
self._adapted.unregister(self.class_, self)
def manager_getter(self):
return self._adapted.manager_getter(self.class_)
def instrument_attribute(self, key, inst, propagated=False):
ClassManager.instrument_attribute(self, key, inst, propagated)
if not propagated:
self._adapted.instrument_attribute(self.class_, key, inst)
def post_configure_attribute(self, key):
super().post_configure_attribute(key)
self._adapted.post_configure_attribute(self.class_, key, self[key])
def install_descriptor(self, key, inst):
self._adapted.install_descriptor(self.class_, key, inst)
def uninstall_descriptor(self, key):
self._adapted.uninstall_descriptor(self.class_, key)
def install_member(self, key, implementation):
self._adapted.install_member(self.class_, key, implementation)
def uninstall_member(self, key):
self._adapted.uninstall_member(self.class_, key)
def instrument_collection_class(self, key, collection_class):
return self._adapted.instrument_collection_class(
self.class_, key, collection_class
)
def initialize_collection(self, key, state, factory):
delegate = getattr(self._adapted, "initialize_collection", None)
if delegate:
return delegate(key, state, factory)
else:
return ClassManager.initialize_collection(
self, key, state, factory
)
def new_instance(self, state=None):
instance = self.class_.__new__(self.class_)
self.setup_instance(instance, state)
return instance
def _new_state_if_none(self, instance):
"""Install a default InstanceState if none is present.
A private convenience method used by the __init__ decorator.
"""
if self.has_state(instance):
return False
else:
return self.setup_instance(instance)
def setup_instance(self, instance, state=None):
self._adapted.initialize_instance_dict(self.class_, instance)
if state is None:
state = self._state_constructor(instance, self)
# the given instance is assumed to have no state
self._adapted.install_state(self.class_, instance, state)
return state
def teardown_instance(self, instance):
self._adapted.remove_state(self.class_, instance)
def has_state(self, instance):
try:
self._get_state(instance)
except orm_exc.NO_STATE:
return False
else:
return True
def state_getter(self):
return self._get_state
def dict_getter(self):
return self._get_dict
def _install_instrumented_lookups():
"""Replace global class/object management functions
with ExtendedInstrumentationRegistry implementations, which
allow multiple types of class managers to be present,
at the cost of performance.
This function is called only by ExtendedInstrumentationRegistry
and unit tests specific to this behavior.
The _reinstall_default_lookups() function can be called
after this one to re-establish the default functions.
"""
_install_lookups(
dict(
instance_state=_instrumentation_factory.state_of,
instance_dict=_instrumentation_factory.dict_of,
manager_of_class=_instrumentation_factory.manager_of_class,
opt_manager_of_class=_instrumentation_factory.opt_manager_of_class,
)
)
def _reinstall_default_lookups():
"""Restore simplified lookups."""
_install_lookups(
dict(
instance_state=_default_state_getter,
instance_dict=_default_dict_getter,
manager_of_class=_default_manager_getter,
opt_manager_of_class=_default_opt_manager_getter,
)
)
_instrumentation_factory._extended = False
def _install_lookups(lookups):
global instance_state, instance_dict
global manager_of_class, opt_manager_of_class
instance_state = lookups["instance_state"]
instance_dict = lookups["instance_dict"]
manager_of_class = lookups["manager_of_class"]
opt_manager_of_class = lookups["opt_manager_of_class"]
orm_base.instance_state = attributes.instance_state = (
orm_instrumentation.instance_state
) = instance_state
orm_base.instance_dict = attributes.instance_dict = (
orm_instrumentation.instance_dict
) = instance_dict
orm_base.manager_of_class = attributes.manager_of_class = (
orm_instrumentation.manager_of_class
) = manager_of_class
orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = (
attributes.opt_manager_of_class
) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,6 @@
# ext/mypy/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

View File

@ -0,0 +1,324 @@
# ext/mypy/apply.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from mypy.nodes import ARG_NAMED_OPT
from mypy.nodes import Argument
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import MDEF
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneTyp
from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
from . import infer
from . import util
from .names import expr_to_mapped_constructor
from .names import NAMED_TYPE_SQLA_MAPPED
def apply_mypy_mapped_attr(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
item: Union[NameExpr, StrExpr],
attributes: List[util.SQLAlchemyAttribute],
) -> None:
if isinstance(item, NameExpr):
name = item.name
elif isinstance(item, StrExpr):
name = item.value
else:
return None
for stmt in cls.defs.body:
if (
isinstance(stmt, AssignmentStmt)
and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name == name
):
break
else:
util.fail(api, f"Can't find mapped attribute {name}", cls)
return None
if stmt.type is None:
util.fail(
api,
"Statement linked from _mypy_mapped_attrs has no "
"typing information",
stmt,
)
return None
left_hand_explicit_type = get_proper_type(stmt.type)
assert isinstance(
left_hand_explicit_type, (Instance, UnionType, UnboundType)
)
attributes.append(
util.SQLAlchemyAttribute(
name=name,
line=item.line,
column=item.column,
typ=left_hand_explicit_type,
info=cls.info,
)
)
apply_type_to_mapped_statement(
api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
)
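# For context, a docs-style sketch (not part of the original module) of the
# user code that reaches apply_mypy_mapped_attr(): attributes typed by another
# system can be pointed out to the plugin via a ``_mypy_mapped_attrs`` list,
# and each listed statement must carry an annotation per the check above.  The
# dataclass fields here are illustrative:
#
#     @dataclass
#     class HasTimestamps:
#         created_at: Optional[datetime] = None
#         updated_at: Optional[datetime] = None
#
#         _mypy_mapped_attrs = [created_at, updated_at]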
def re_apply_declarative_assignments(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""For multiple class passes, re-apply our left-hand side types as mypy
seems to reset them in place.
"""
mapped_attr_lookup = {attr.name: attr for attr in attributes}
update_cls_metadata = False
for stmt in cls.defs.body:
# for a re-apply, all of our statements are AssignmentStmt;
# @declared_attr calls will have been converted and this
# currently seems to be preserved by mypy (but who knows if this
# will change).
if (
isinstance(stmt, AssignmentStmt)
and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name in mapped_attr_lookup
and isinstance(stmt.lvalues[0].node, Var)
):
left_node = stmt.lvalues[0].node
python_type_for_type = mapped_attr_lookup[
stmt.lvalues[0].name
].type
left_node_proper_type = get_proper_type(left_node.type)
# if we have scanned an UnboundType and now there's a more
# specific type than UnboundType, call the re-scan so we
# can get that set up correctly
if (
isinstance(python_type_for_type, UnboundType)
and not isinstance(left_node_proper_type, UnboundType)
and (
isinstance(stmt.rvalue, CallExpr)
and isinstance(stmt.rvalue.callee, MemberExpr)
and isinstance(stmt.rvalue.callee.expr, NameExpr)
and stmt.rvalue.callee.expr.node is not None
and stmt.rvalue.callee.expr.node.fullname
== NAMED_TYPE_SQLA_MAPPED
and stmt.rvalue.callee.name == "_empty_constructor"
and isinstance(stmt.rvalue.args[0], CallExpr)
and isinstance(stmt.rvalue.args[0].callee, RefExpr)
)
):
new_python_type_for_type = (
infer.infer_type_from_right_hand_nameexpr(
api,
stmt,
left_node,
left_node_proper_type,
stmt.rvalue.args[0].callee,
)
)
if new_python_type_for_type is not None and not isinstance(
new_python_type_for_type, UnboundType
):
python_type_for_type = new_python_type_for_type
# update the SQLAlchemyAttribute with the better
# information
mapped_attr_lookup[stmt.lvalues[0].name].type = (
python_type_for_type
)
update_cls_metadata = True
if (
not isinstance(left_node.type, Instance)
or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED
):
assert python_type_for_type is not None
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
)
if update_cls_metadata:
util.set_mapped_attributes(cls.info, attributes)
def apply_type_to_mapped_statement(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
lvalue: NameExpr,
left_hand_explicit_type: Optional[ProperType],
python_type_for_type: Optional[ProperType],
) -> None:
"""Apply the Mapped[<type>] annotation and right hand object to a
declarative assignment statement.
This converts a Python declarative class statement such as::
class User(Base):
# ...
attrname = Column(Integer)
To one that describes the final Python behavior to Mypy::
.. format: off
class User(Base):
# ...
attrname : Mapped[Optional[int]] = <meaningless temp node>
.. format: on
"""
left_node = lvalue.node
assert isinstance(left_node, Var)
# to be completely honest I have no idea what the difference between
# left_node.type and stmt.type is, what it means if these are different
# vs. the same, why in order to get tests to pass I have to assign
# to stmt.type for the second case and not the first. this is complete
# trying every combination until it works stuff.
if left_hand_explicit_type is not None:
lvalue.is_inferred_def = False
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
)
else:
lvalue.is_inferred_def = False
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED,
(
[AnyType(TypeOfAny.special_form)]
if python_type_for_type is None
else [python_type_for_type]
),
)
# so to have it skip the right side totally, we can do this:
# stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
# however, if we instead manufacture a new node that uses the old
# one, then we can still get type checking for the call itself,
# e.g. the Column, relationship() call, etc.
# rewrite the node as:
# <attr> : Mapped[<typ>] =
# _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
# the original right-hand side is maintained so it gets type checked
# internally
stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue)
if stmt.type is not None and python_type_for_type is not None:
stmt.type = python_type_for_type
def add_additional_orm_attributes(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Apply __init__, __table__ and other attributes to the mapped class."""
info = util.info_for_cls(cls, api)
if info is None:
return
is_base = util.get_is_base(info)
if "__init__" not in info.names and not is_base:
mapped_attr_names = {attr.name: attr.type for attr in attributes}
for base in info.mro[1:-1]:
if "sqlalchemy" not in info.metadata:
continue
base_cls_attributes = util.get_mapped_attributes(base, api)
if base_cls_attributes is None:
continue
for attr in base_cls_attributes:
mapped_attr_names.setdefault(attr.name, attr.type)
arguments = []
for name, typ in mapped_attr_names.items():
if typ is None:
typ = AnyType(TypeOfAny.special_form)
arguments.append(
Argument(
variable=Var(name, typ),
type_annotation=typ,
initializer=TempNode(typ),
kind=ARG_NAMED_OPT,
)
)
add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
if "__table__" not in info.names and util.get_has_table(info):
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.sql.schema.Table", "__table__"
)
if not is_base:
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
)
def _apply_placeholder_attr_to_class(
api: SemanticAnalyzerPluginInterface,
cls: ClassDef,
qualified_name: str,
attrname: str,
) -> None:
sym = api.lookup_fully_qualified_or_none(qualified_name)
if sym:
assert isinstance(sym.node, TypeInfo)
type_: ProperType = Instance(sym.node, [])
else:
type_ = AnyType(TypeOfAny.special_form)
var = Var(attrname)
var._fullname = cls.fullname + "." + attrname
var.info = cls.info
var.type = type_
cls.info.names[attrname] = SymbolTableNode(MDEF, var)
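# Net effect of add_additional_orm_attributes(), roughly: given a mapped class
# such as
#
#     class User(Base):
#         __tablename__ = "user"
#         id: Mapped[int] = Column(Integer, primary_key=True)
#         name: Mapped[Optional[str]] = Column(String(50))
#
# mypy is taught an __init__ along the lines of
#
#     def __init__(self, *, id: int = ..., name: Optional[str] = ...) -> None: ...
#
# (keyword-only, all optional, per ARG_NAMED_OPT above), plus placeholder
# __table__ / __mapper__ attributes.  The signature shown is illustrative,
# not generated source.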

View File

@ -0,0 +1,515 @@
# ext/mypy/decl_class.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import LambdaExpr
from mypy.nodes import ListExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import PlaceholderNode
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolNode
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import AnyType
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import ProperType
from mypy.types import Type
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
from . import apply
from . import infer
from . import names
from . import util
def scan_declarative_assignments_and_apply_types(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
is_mixin_scan: bool = False,
) -> Optional[List[util.SQLAlchemyAttribute]]:
info = util.info_for_cls(cls, api)
if info is None:
# this can occur during cached passes
return None
elif cls.fullname.startswith("builtins"):
return None
mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = (
util.get_mapped_attributes(info, api)
)
# used by assign.add_additional_orm_attributes among others
util.establish_as_sqlalchemy(info)
if mapped_attributes is not None:
# ensure that a class that's mapped is always picked up by
# its mapped() decorator or declarative metaclass before
# it would be detected as an unmapped mixin class
if not is_mixin_scan:
# mypy can call us more than once. it then *may* have reset the
# left hand side of everything, but not the right that we removed,
# removing our ability to re-scan. but we have the types
# here, so lets re-apply them, or if we have an UnboundType,
# we can re-scan
apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
return mapped_attributes
mapped_attributes = []
if not cls.defs.body:
# when we get a mixin class from another file, the body is
# empty (!) but the names are in the symbol table. so use that.
for sym_name, sym in info.names.items():
_scan_symbol_table_entry(
cls, api, sym_name, sym, mapped_attributes
)
else:
for stmt in util.flatten_typechecking(cls.defs.body):
if isinstance(stmt, AssignmentStmt):
_scan_declarative_assignment_stmt(
cls, api, stmt, mapped_attributes
)
elif isinstance(stmt, Decorator):
_scan_declarative_decorator_stmt(
cls, api, stmt, mapped_attributes
)
_scan_for_mapped_bases(cls, api)
if not is_mixin_scan:
apply.add_additional_orm_attributes(cls, api, mapped_attributes)
util.set_mapped_attributes(info, mapped_attributes)
return mapped_attributes
def _scan_symbol_table_entry(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
name: str,
value: SymbolTableNode,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a SymbolTableNode that's in the
type.names dictionary.
"""
value_type = get_proper_type(value.type)
if not isinstance(value_type, Instance):
return
left_hand_explicit_type = None
type_id = names.type_id_for_named_node(value_type.type)
# type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
err = False
# TODO: this is nearly the same logic as that of
# _scan_declarative_decorator_stmt, likely can be merged
if type_id in {
names.MAPPED,
names.RELATIONSHIP,
names.COMPOSITE_PROPERTY,
names.MAPPER_PROPERTY,
names.SYNONYM_PROPERTY,
names.COLUMN_PROPERTY,
}:
if value_type.args:
left_hand_explicit_type = get_proper_type(value_type.args[0])
else:
err = True
elif type_id is names.COLUMN:
if not value_type.args:
err = True
else:
typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
value_type.args[0]
)
if isinstance(typeengine_arg, Instance):
typeengine_arg = typeengine_arg.type
if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
]
)
else:
util.fail(
api,
"Column type should be a TypeEngine "
"subclass not '{}'".format(sym.node.fullname),
value_type,
)
if err:
msg = (
"Can't infer type from attribute {} on class {}. "
"please specify a return type from this function that is "
"one of: Mapped[<python type>], relationship[<target class>], "
"Column[<TypeEngine>], MapperProperty[<python type>]"
)
util.fail(api, msg.format(name, cls.name), cls)
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
if left_hand_explicit_type is not None:
assert value.node is not None
attributes.append(
util.SQLAlchemyAttribute(
name=name,
line=value.node.line,
column=value.node.column,
typ=left_hand_explicit_type,
info=cls.info,
)
)
def _scan_declarative_decorator_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: Decorator,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a @declared_attr in a declarative
class.
E.g.::
@reg.mapped
class MyClass:
# ...
@declared_attr
def updated_at(cls) -> Column[DateTime]:
return Column(DateTime)
Will resolve in mypy as::
@reg.mapped
class MyClass:
# ...
updated_at: Mapped[Optional[datetime.datetime]]
"""
for dec in stmt.decorators:
if (
isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
):
break
else:
return
dec_index = cls.defs.body.index(stmt)
left_hand_explicit_type: Optional[ProperType] = None
if util.name_is_dunder(stmt.name):
# for dunder names like __table_args__, __tablename__,
# __mapper_args__ etc., rewrite these as simple assignment
# statements; otherwise mypy doesn't like if the decorated
# function has an annotation like ``cls: Type[Foo]`` because
# it isn't @classmethod
any_ = AnyType(TypeOfAny.special_form)
left_node = NameExpr(stmt.var.name)
left_node.node = stmt.var
new_stmt = AssignmentStmt([left_node], TempNode(any_))
new_stmt.type = left_node.node.type
cls.defs.body[dec_index] = new_stmt
return
elif isinstance(stmt.func.type, CallableType):
func_type = stmt.func.type.ret_type
if isinstance(func_type, UnboundType):
type_id = names.type_id_for_unbound_type(func_type, cls, api)
else:
# this does not seem to occur unless the type argument is
# incorrect
return
if (
type_id
in {
names.MAPPED,
names.RELATIONSHIP,
names.COMPOSITE_PROPERTY,
names.MAPPER_PROPERTY,
names.SYNONYM_PROPERTY,
names.COLUMN_PROPERTY,
}
and func_type.args
):
left_hand_explicit_type = get_proper_type(func_type.args[0])
elif type_id is names.COLUMN and func_type.args:
typeengine_arg = func_type.args[0]
if isinstance(typeengine_arg, UnboundType):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
]
)
else:
util.fail(
api,
"Column type should be a TypeEngine "
"subclass not '{}'".format(sym.node.fullname),
func_type,
)
if left_hand_explicit_type is None:
# no type on the decorated function. our option here is to
# dig into the function body and get the return type, but they
# should just have an annotation.
msg = (
"Can't infer type from @declared_attr on function '{}'; "
"please specify a return type from this function that is "
"one of: Mapped[<python type>], relationship[<target class>], "
"Column[<TypeEngine>], MapperProperty[<python type>]"
)
util.fail(api, msg.format(stmt.var.name), stmt)
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
left_node = NameExpr(stmt.var.name)
left_node.node = stmt.var
# totally feeling around in the dark here as I don't totally understand
# the significance of UnboundType. It seems to be something that is
# not going to do what's expected when it is applied as the type of
# an AssignmentStatement. So do a feeling-around-in-the-dark version
# of converting it to the regular Instance/TypeInfo/UnionType structures
# we see everywhere else.
if isinstance(left_hand_explicit_type, UnboundType):
left_hand_explicit_type = get_proper_type(
util.unbound_to_instance(api, left_hand_explicit_type)
)
left_node.node.type = api.named_type(
names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
)
# this will ignore the rvalue entirely
# rvalue = TempNode(AnyType(TypeOfAny.special_form))
# rewrite the node as:
# <attr> : Mapped[<typ>] =
# _sa_Mapped._empty_constructor(lambda: <function body>)
# the function body is maintained so it gets type checked internally
rvalue = names.expr_to_mapped_constructor(
LambdaExpr(stmt.func.arguments, stmt.func.body)
)
new_stmt = AssignmentStmt([left_node], rvalue)
new_stmt.type = left_node.node.type
attributes.append(
util.SQLAlchemyAttribute(
name=left_node.name,
line=stmt.line,
column=stmt.column,
typ=left_hand_explicit_type,
info=cls.info,
)
)
cls.defs.body[dec_index] = new_stmt
def _scan_declarative_assignment_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from an assignment statement in a
declarative class.
"""
lvalue = stmt.lvalues[0]
if not isinstance(lvalue, NameExpr):
return
sym = cls.info.names.get(lvalue.name)
# this establishes that semantic analysis has taken place, which
# means the nodes are populated and we are called from an appropriate
# hook.
assert sym is not None
node = sym.node
if isinstance(node, PlaceholderNode):
return
assert node is lvalue.node
assert isinstance(node, Var)
if node.name == "__abstract__":
if api.parse_bool(stmt.rvalue) is True:
util.set_is_base(cls.info)
return
elif node.name == "__tablename__":
util.set_has_table(cls.info)
elif node.name.startswith("__"):
return
elif node.name == "_mypy_mapped_attrs":
if not isinstance(stmt.rvalue, ListExpr):
util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
else:
for item in stmt.rvalue.items:
if isinstance(item, (NameExpr, StrExpr)):
apply.apply_mypy_mapped_attr(cls, api, item, attributes)
left_hand_mapped_type: Optional[Type] = None
left_hand_explicit_type: Optional[ProperType] = None
if node.is_inferred or node.type is None:
if isinstance(stmt.type, UnboundType):
# look for an explicit Mapped[] type annotation on the left
# side with nothing on the right
# print(stmt.type)
# Mapped?[Optional?[A?]]
left_hand_explicit_type = stmt.type
if stmt.type.name == "Mapped":
mapped_sym = api.lookup_qualified("Mapped", cls)
if (
mapped_sym is not None
and mapped_sym.node is not None
and names.type_id_for_named_node(mapped_sym.node)
is names.MAPPED
):
left_hand_explicit_type = get_proper_type(
stmt.type.args[0]
)
left_hand_mapped_type = stmt.type
# TODO: do we need to convert from unbound for this case?
# left_hand_explicit_type = util._unbound_to_instance(
# api, left_hand_explicit_type
# )
else:
node_type = get_proper_type(node.type)
if (
isinstance(node_type, Instance)
and names.type_id_for_named_node(node_type.type) is names.MAPPED
):
# print(node.type)
# sqlalchemy.orm.attributes.Mapped[<python type>]
left_hand_explicit_type = get_proper_type(node_type.args[0])
left_hand_mapped_type = node_type
else:
# print(node.type)
# <python type>
left_hand_explicit_type = node_type
left_hand_mapped_type = None
if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
# annotation without assignment and Mapped is present
# as type annotation
# equivalent to using _infer_type_from_left_hand_type_only.
python_type_for_type = left_hand_explicit_type
elif isinstance(stmt.rvalue, CallExpr) and isinstance(
stmt.rvalue.callee, RefExpr
):
python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
)
if python_type_for_type is None:
return
else:
return
assert python_type_for_type is not None
attributes.append(
util.SQLAlchemyAttribute(
name=node.name,
line=stmt.line,
column=stmt.column,
typ=python_type_for_type,
info=cls.info,
)
)
apply.apply_type_to_mapped_statement(
api,
stmt,
lvalue,
left_hand_explicit_type,
python_type_for_type,
)
def _scan_for_mapped_bases(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
) -> None:
"""Given a class, iterate through its superclass hierarchy to find
all other classes that are considered as ORM-significant.
Locates non-mapped mixins and scans them for mapped attributes to be
applied to subclasses.
"""
info = util.info_for_cls(cls, api)
if info is None:
return
for base_info in info.mro[1:-1]:
if base_info.fullname.startswith("builtins"):
continue
# scan each base for mapped attributes. if they are not already
# scanned (but have all their type info), that means they are unmapped
# mixins
scan_declarative_assignments_and_apply_types(
base_info.defn, api, is_mixin_scan=True
)
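# Rough illustration (example user code, not part of the original module) of
# the mixin scan above: attributes declared on an unmapped mixin are collected
# when a mapped subclass is processed, even though the mixin itself is never
# mapped.
#
#     class HasId:                      # unmapped mixin, reached via the MRO walk
#         id: Mapped[int] = Column(Integer, primary_key=True)
#
#     class User(HasId, Base):          # scanning User also scans HasId
#         __tablename__ = "user"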

View File

@ -0,0 +1,590 @@
# ext/mypy/infer.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Optional
from typing import Sequence
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import LambdaExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.subtypes import is_subtype
from mypy.types import AnyType
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnionType
from . import names
from . import util
def infer_type_from_right_hand_nameexpr(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
type_id = names.type_id_for_callee(infer_from_right_side)
if type_id is None:
return None
elif type_id is names.MAPPED:
python_type_for_type = _infer_type_from_mapped(
api, stmt, node, left_hand_explicit_type, infer_from_right_side
)
elif type_id is names.COLUMN:
python_type_for_type = _infer_type_from_decl_column(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.RELATIONSHIP:
python_type_for_type = _infer_type_from_relationship(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.COLUMN_PROPERTY:
python_type_for_type = _infer_type_from_decl_column_property(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.SYNONYM_PROPERTY:
python_type_for_type = infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif type_id is names.COMPOSITE_PROPERTY:
python_type_for_type = _infer_type_from_decl_composite_property(
api, stmt, node, left_hand_explicit_type
)
else:
return None
return python_type_for_type
def _infer_type_from_relationship(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Infer the type of mapping from a relationship.
E.g.::
@reg.mapped
class MyClass:
# ...
addresses = relationship(Address, uselist=True)
order: Mapped["Order"] = relationship("Order")
Will resolve in mypy as::
@reg.mapped
class MyClass:
# ...
addresses: Mapped[List[Address]]
order: Mapped["Order"]
"""
assert isinstance(stmt.rvalue, CallExpr)
target_cls_arg = stmt.rvalue.args[0]
python_type_for_type: Optional[ProperType] = None
if isinstance(target_cls_arg, NameExpr) and isinstance(
target_cls_arg.node, TypeInfo
):
# type
related_object_type = target_cls_arg.node
python_type_for_type = Instance(related_object_type, [])
# other cases not covered - an error message directs the user
# to set an explicit type annotation
#
# node.type == str, it's a string
# if isinstance(target_cls_arg, NameExpr) and isinstance(
# target_cls_arg.node, Var
# )
# points to a type
# isinstance(target_cls_arg, NameExpr) and isinstance(
# target_cls_arg.node, TypeAlias
# )
# string expression
# isinstance(target_cls_arg, StrExpr)
uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
stmt.rvalue, "collection_class"
)
type_is_a_collection = False
# this can be used to determine Optional for a many-to-one
# in the same way nullable=False could be used, if we start supporting
# that.
# innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
if (
uselist_arg is not None
and api.parse_bool(uselist_arg) is True
and collection_cls_arg is None
):
type_is_a_collection = True
if python_type_for_type is not None:
python_type_for_type = api.named_type(
names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
)
elif (
uselist_arg is None or api.parse_bool(uselist_arg) is True
) and collection_cls_arg is not None:
type_is_a_collection = True
if isinstance(collection_cls_arg, CallExpr):
collection_cls_arg = collection_cls_arg.callee
if isinstance(collection_cls_arg, NameExpr) and isinstance(
collection_cls_arg.node, TypeInfo
):
if python_type_for_type is not None:
# this can still be overridden by the left hand side
# within _infer_type_from_left_and_inferred_right
python_type_for_type = Instance(
collection_cls_arg.node, [python_type_for_type]
)
elif (
isinstance(collection_cls_arg, NameExpr)
and isinstance(collection_cls_arg.node, FuncDef)
and collection_cls_arg.node.type is not None
):
if python_type_for_type is not None:
# this can still be overridden by the left hand side
# within _infer_type_from_left_and_inferred_right
# TODO: handle mypy.types.Overloaded
if isinstance(collection_cls_arg.node.type, CallableType):
rt = get_proper_type(collection_cls_arg.node.type.ret_type)
if isinstance(rt, CallableType):
callable_ret_type = get_proper_type(rt.ret_type)
if isinstance(callable_ret_type, Instance):
python_type_for_type = Instance(
callable_ret_type.type,
[python_type_for_type],
)
else:
util.fail(
api,
"Expected Python collection type for "
"collection_class parameter",
stmt.rvalue,
)
python_type_for_type = None
elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
if collection_cls_arg is not None:
util.fail(
api,
"Sending uselist=False and collection_class at the same time "
"does not make sense",
stmt.rvalue,
)
if python_type_for_type is not None:
python_type_for_type = UnionType(
[python_type_for_type, NoneType()]
)
else:
if left_hand_explicit_type is None:
msg = (
"Can't infer scalar or collection for ORM mapped expression "
"assigned to attribute '{}' if both 'uselist' and "
"'collection_class' arguments are absent from the "
"relationship(); please specify a "
"type annotation on the left hand side."
)
util.fail(api, msg.format(node.name), node)
if python_type_for_type is None:
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
if type_is_a_collection:
assert isinstance(left_hand_explicit_type, Instance)
assert isinstance(python_type_for_type, Instance)
return _infer_collection_type_from_left_and_inferred_right(
api, node, left_hand_explicit_type, python_type_for_type
)
else:
return _infer_type_from_left_and_inferred_right(
api,
node,
left_hand_explicit_type,
python_type_for_type,
)
else:
return python_type_for_type
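# Illustrative inputs/outputs for the branches above (renderings approximate;
# the exact types come from the Instance / UnionType construction):
#
#     addresses = relationship(Address, uselist=True)
#         -> Mapped[List[Address]]          # builtins.list wrapping
#     addresses = relationship(Address, collection_class=set)
#         -> Mapped[Set[Address]]           # collection_class Instance wrapping
#     parent = relationship(Parent, uselist=False)
#         -> Mapped[Optional[Parent]]       # scalar, unioned with None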
def _infer_type_from_decl_composite_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Infer the type of mapping from a Composite."""
assert isinstance(stmt.rvalue, CallExpr)
target_cls_arg = stmt.rvalue.args[0]
python_type_for_type = None
if isinstance(target_cls_arg, NameExpr) and isinstance(
target_cls_arg.node, TypeInfo
):
related_object_type = target_cls_arg.node
python_type_for_type = Instance(related_object_type, [])
else:
python_type_for_type = None
if python_type_for_type is None:
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
return _infer_type_from_left_and_inferred_right(
api, node, left_hand_explicit_type, python_type_for_type
)
else:
return python_type_for_type
def _infer_type_from_mapped(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
"""Infer the type of mapping from a right side expression
that returns Mapped.
"""
assert isinstance(stmt.rvalue, CallExpr)
# (Pdb) print(stmt.rvalue.callee)
# NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501
# (Pdb) stmt.rvalue.callee.node
# <mypy.nodes.FuncDef object at 0x7f8d92fb5940>
# (Pdb) stmt.rvalue.callee.node.type
# def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501
# sqlalchemy.orm.base.Mapped[_T`-1]
# the_mapped_type = stmt.rvalue.callee.node.type.ret_type
# TODO: look at generic ref and either use that,
# or reconcile w/ what's present, etc.
the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
def _infer_type_from_decl_column_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Infer the type of mapping from a ColumnProperty.
This includes mappings against ``column_property()`` as well as the
``deferred()`` function.
"""
assert isinstance(stmt.rvalue, CallExpr)
if stmt.rvalue.args:
first_prop_arg = stmt.rvalue.args[0]
if isinstance(first_prop_arg, CallExpr):
type_id = names.type_id_for_callee(first_prop_arg.callee)
# look for column_property() / deferred() etc with Column as first
# argument
if type_id is names.COLUMN:
return _infer_type_from_decl_column(
api,
stmt,
node,
left_hand_explicit_type,
right_hand_expression=first_prop_arg,
)
if isinstance(stmt.rvalue, CallExpr):
type_id = names.type_id_for_callee(stmt.rvalue.callee)
# this is probably not strictly necessary as we have to use the left
# hand type for query expression in any case. any other no-arg
# column prop objects would go here also
if type_id is names.QUERY_EXPRESSION:
return _infer_type_from_decl_column(
api,
stmt,
node,
left_hand_explicit_type,
)
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
def _infer_type_from_decl_column(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
right_hand_expression: Optional[CallExpr] = None,
) -> Optional[ProperType]:
"""Infer the type of mapping from a Column.
E.g.::
@reg.mapped
class MyClass:
# ...
a = Column(Integer)
b = Column("b", String)
c: Mapped[int] = Column(Integer)
d: bool = Column(Boolean)
Will resolve in MyPy as::
@reg.mapped
class MyClass:
# ...
a: Mapped[int]
b: Mapped[str]
c: Mapped[int]
d: Mapped[bool]
"""
assert isinstance(node, Var)
callee = None
if right_hand_expression is None:
if not isinstance(stmt.rvalue, CallExpr):
return None
right_hand_expression = stmt.rvalue
for column_arg in right_hand_expression.args[0:2]:
if isinstance(column_arg, CallExpr):
if isinstance(column_arg.callee, RefExpr):
# x = Column(String(50))
callee = column_arg.callee
type_args: Sequence[Expression] = column_arg.args
break
elif isinstance(column_arg, (NameExpr, MemberExpr)):
if isinstance(column_arg.node, TypeInfo):
# x = Column(String)
callee = column_arg
type_args = ()
break
else:
# x = Column(some_name, String), go to next argument
continue
elif isinstance(column_arg, (StrExpr,)):
# x = Column("name", String), go to next argument
continue
elif isinstance(column_arg, (LambdaExpr,)):
# x = Column("name", String, default=lambda: uuid.uuid4())
# go to next argument
continue
else:
assert False
if callee is None:
return None
if isinstance(callee.node, TypeInfo) and names.mro_has_id(
callee.node.mro, names.TYPEENGINE
):
python_type_for_type = extract_python_type_from_typeengine(
api, callee.node, type_args
)
if left_hand_explicit_type is not None:
return _infer_type_from_left_and_inferred_right(
api, node, left_hand_explicit_type, python_type_for_type
)
else:
return UnionType([python_type_for_type, NoneType()])
else:
# it's not TypeEngine, it's typically implicitly typed
# like ForeignKey. we can't infer from the right side.
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
def _infer_type_from_left_and_inferred_right(
api: SemanticAnalyzerPluginInterface,
node: Var,
left_hand_explicit_type: ProperType,
python_type_for_type: ProperType,
orig_left_hand_type: Optional[ProperType] = None,
orig_python_type_for_type: Optional[ProperType] = None,
) -> Optional[ProperType]:
"""Validate type when a left hand annotation is present and we also
could infer the right hand side::
attrname: SomeType = Column(SomeDBType)
"""
if orig_left_hand_type is None:
orig_left_hand_type = left_hand_explicit_type
if orig_python_type_for_type is None:
orig_python_type_for_type = python_type_for_type
if not is_subtype(left_hand_explicit_type, python_type_for_type):
effective_type = api.named_type(
names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
)
msg = (
"Left hand assignment '{}: {}' not compatible "
"with ORM mapped expression of type {}"
)
util.fail(
api,
msg.format(
node.name,
util.format_type(orig_left_hand_type, api.options),
util.format_type(effective_type, api.options),
),
node,
)
return orig_left_hand_type
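# Example of the failure path above (user code): an explicit annotation that
# disagrees with what the right-hand side implies, e.g.
#
#     name: Mapped[int] = Column(String(50))
#
# produces "Left hand assignment 'name: int' not compatible with ORM mapped
# expression of type Mapped[str]" (wording approximate), and the user's
# annotation is kept as the effective type.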
def _infer_collection_type_from_left_and_inferred_right(
api: SemanticAnalyzerPluginInterface,
node: Var,
left_hand_explicit_type: Instance,
python_type_for_type: Instance,
) -> Optional[ProperType]:
orig_left_hand_type = left_hand_explicit_type
orig_python_type_for_type = python_type_for_type
if left_hand_explicit_type.args:
left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
python_type_arg = get_proper_type(python_type_for_type.args[0])
else:
left_hand_arg = left_hand_explicit_type
python_type_arg = python_type_for_type
assert isinstance(left_hand_arg, (Instance, UnionType))
assert isinstance(python_type_arg, (Instance, UnionType))
return _infer_type_from_left_and_inferred_right(
api,
node,
left_hand_arg,
python_type_arg,
orig_left_hand_type=orig_left_hand_type,
orig_python_type_for_type=orig_python_type_for_type,
)
def infer_type_from_left_hand_type_only(
api: SemanticAnalyzerPluginInterface,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Determine the type based on explicit annotation only.
if no annotation were present, note that we need one there to know
the type.
"""
if left_hand_explicit_type is None:
msg = (
"Can't infer type from ORM mapped expression "
"assigned to attribute '{}'; please specify a "
"Python type or "
"Mapped[<python type>] on the left hand side."
)
util.fail(api, msg.format(node.name), node)
return api.named_type(
names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
)
else:
# use type from the left hand side
return left_hand_explicit_type
def extract_python_type_from_typeengine(
api: SemanticAnalyzerPluginInterface,
node: TypeInfo,
type_args: Sequence[Expression],
) -> ProperType:
if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
first_arg = type_args[0]
if isinstance(first_arg, RefExpr) and isinstance(
first_arg.node, TypeInfo
):
for base_ in first_arg.node.mro:
if base_.fullname == "enum.Enum":
return Instance(first_arg.node, [])
# TODO: support other pep-435 types here
else:
return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])
assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
"could not extract Python type from node: %s" % node
)
type_engine_sym = api.lookup_fully_qualified_or_none(
"sqlalchemy.sql.type_api.TypeEngine"
)
assert type_engine_sym is not None and isinstance(
type_engine_sym.node, TypeInfo
)
type_engine = map_instance_to_supertype(
Instance(node, []),
type_engine_sym.node,
)
return get_proper_type(type_engine.args[-1])
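# Rough examples of what the extraction above yields (illustrative only):
#
#     Column(Integer)            -> builtins.int   # TypeEngine[_T] supertype mapping
#     Column(Enum(MyPythonEnum)) -> MyPythonEnum   # pep-435 branch
#     Column(Enum("a", "b"))     -> builtins.str   # string-based Enum falls back to str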

View File

@ -0,0 +1,335 @@
# ext/mypy/names.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
from mypy.nodes import ARG_POS
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import OverloadedFuncDef
from mypy.nodes import SymbolNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import UnboundType
from ... import util
COLUMN: int = util.symbol("COLUMN")
RELATIONSHIP: int = util.symbol("RELATIONSHIP")
REGISTRY: int = util.symbol("REGISTRY")
COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY")
TYPEENGINE: int = util.symbol("TYPEENGINE")
MAPPED: int = util.symbol("MAPPED")
DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE")
DECLARATIVE_META: int = util.symbol("DECLARATIVE_META")
MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR")
SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY")
COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY")
DECLARED_ATTR: int = util.symbol("DECLARED_ATTR")
MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY")
AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE")
AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE")
DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN")
QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION")
# names that must succeed with mypy.api.named_type
NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
NAMED_TYPE_BUILTINS_STR = "builtins.str"
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped"
_RelFullNames = {
"sqlalchemy.orm.relationships.Relationship",
"sqlalchemy.orm.relationships.RelationshipProperty",
"sqlalchemy.orm.relationships._RelationshipDeclared",
"sqlalchemy.orm.Relationship",
"sqlalchemy.orm.RelationshipProperty",
}
_lookup: Dict[str, Tuple[int, Set[str]]] = {
"Column": (
COLUMN,
{
"sqlalchemy.sql.schema.Column",
"sqlalchemy.sql.Column",
},
),
"Relationship": (RELATIONSHIP, _RelFullNames),
"RelationshipProperty": (RELATIONSHIP, _RelFullNames),
"_RelationshipDeclared": (RELATIONSHIP, _RelFullNames),
"registry": (
REGISTRY,
{
"sqlalchemy.orm.decl_api.registry",
"sqlalchemy.orm.registry",
},
),
"ColumnProperty": (
COLUMN_PROPERTY,
{
"sqlalchemy.orm.properties.MappedSQLExpression",
"sqlalchemy.orm.MappedSQLExpression",
"sqlalchemy.orm.properties.ColumnProperty",
"sqlalchemy.orm.ColumnProperty",
},
),
"MappedSQLExpression": (
COLUMN_PROPERTY,
{
"sqlalchemy.orm.properties.MappedSQLExpression",
"sqlalchemy.orm.MappedSQLExpression",
"sqlalchemy.orm.properties.ColumnProperty",
"sqlalchemy.orm.ColumnProperty",
},
),
"Synonym": (
SYNONYM_PROPERTY,
{
"sqlalchemy.orm.descriptor_props.Synonym",
"sqlalchemy.orm.Synonym",
"sqlalchemy.orm.descriptor_props.SynonymProperty",
"sqlalchemy.orm.SynonymProperty",
},
),
"SynonymProperty": (
SYNONYM_PROPERTY,
{
"sqlalchemy.orm.descriptor_props.Synonym",
"sqlalchemy.orm.Synonym",
"sqlalchemy.orm.descriptor_props.SynonymProperty",
"sqlalchemy.orm.SynonymProperty",
},
),
"Composite": (
COMPOSITE_PROPERTY,
{
"sqlalchemy.orm.descriptor_props.Composite",
"sqlalchemy.orm.Composite",
"sqlalchemy.orm.descriptor_props.CompositeProperty",
"sqlalchemy.orm.CompositeProperty",
},
),
"CompositeProperty": (
COMPOSITE_PROPERTY,
{
"sqlalchemy.orm.descriptor_props.Composite",
"sqlalchemy.orm.Composite",
"sqlalchemy.orm.descriptor_props.CompositeProperty",
"sqlalchemy.orm.CompositeProperty",
},
),
"MapperProperty": (
MAPPER_PROPERTY,
{
"sqlalchemy.orm.interfaces.MapperProperty",
"sqlalchemy.orm.MapperProperty",
},
),
"TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
"Mapped": (MAPPED, {NAMED_TYPE_SQLA_MAPPED}),
"declarative_base": (
DECLARATIVE_BASE,
{
"sqlalchemy.ext.declarative.declarative_base",
"sqlalchemy.orm.declarative_base",
"sqlalchemy.orm.decl_api.declarative_base",
},
),
"DeclarativeMeta": (
DECLARATIVE_META,
{
"sqlalchemy.ext.declarative.DeclarativeMeta",
"sqlalchemy.orm.DeclarativeMeta",
"sqlalchemy.orm.decl_api.DeclarativeMeta",
},
),
"mapped": (
MAPPED_DECORATOR,
{
"sqlalchemy.orm.decl_api.registry.mapped",
"sqlalchemy.orm.registry.mapped",
},
),
"as_declarative": (
AS_DECLARATIVE,
{
"sqlalchemy.ext.declarative.as_declarative",
"sqlalchemy.orm.decl_api.as_declarative",
"sqlalchemy.orm.as_declarative",
},
),
"as_declarative_base": (
AS_DECLARATIVE_BASE,
{
"sqlalchemy.orm.decl_api.registry.as_declarative_base",
"sqlalchemy.orm.registry.as_declarative_base",
},
),
"declared_attr": (
DECLARED_ATTR,
{
"sqlalchemy.orm.decl_api.declared_attr",
"sqlalchemy.orm.declared_attr",
},
),
"declarative_mixin": (
DECLARATIVE_MIXIN,
{
"sqlalchemy.orm.decl_api.declarative_mixin",
"sqlalchemy.orm.declarative_mixin",
},
),
"query_expression": (
QUERY_EXPRESSION,
{
"sqlalchemy.orm.query_expression",
"sqlalchemy.orm._orm_constructors.query_expression",
},
),
}
def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
for mr in info.mro:
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
if check_type_id == type_id:
break
else:
return False
if fullnames is None:
return False
return mr.fullname in fullnames
def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
for mr in mro:
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
if check_type_id == type_id:
break
else:
return False
if fullnames is None:
return False
return mr.fullname in fullnames
def type_id_for_unbound_type(
type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[int]:
sym = api.lookup_qualified(type_.name, type_)
if sym is not None:
if isinstance(sym.node, TypeAlias):
target_type = get_proper_type(sym.node.target)
if isinstance(target_type, Instance):
return type_id_for_named_node(target_type.type)
elif isinstance(sym.node, TypeInfo):
return type_id_for_named_node(sym.node)
return None
def type_id_for_callee(callee: Expression) -> Optional[int]:
if isinstance(callee, (MemberExpr, NameExpr)):
if isinstance(callee.node, Decorator) and isinstance(
callee.node.func, FuncDef
):
if callee.node.func.type and isinstance(
callee.node.func.type, CallableType
):
ret_type = get_proper_type(callee.node.func.type.ret_type)
if isinstance(ret_type, Instance):
return type_id_for_fullname(ret_type.type.fullname)
return None
elif isinstance(callee.node, OverloadedFuncDef):
if (
callee.node.impl
and callee.node.impl.type
and isinstance(callee.node.impl.type, CallableType)
):
ret_type = get_proper_type(callee.node.impl.type.ret_type)
if isinstance(ret_type, Instance):
return type_id_for_fullname(ret_type.type.fullname)
return None
elif isinstance(callee.node, FuncDef):
if callee.node.type and isinstance(callee.node.type, CallableType):
ret_type = get_proper_type(callee.node.type.ret_type)
if isinstance(ret_type, Instance):
return type_id_for_fullname(ret_type.type.fullname)
return None
elif isinstance(callee.node, TypeAlias):
target_type = get_proper_type(callee.node.target)
if isinstance(target_type, Instance):
return type_id_for_fullname(target_type.type.fullname)
elif isinstance(callee.node, TypeInfo):
return type_id_for_named_node(callee)
return None
def type_id_for_named_node(
node: Union[NameExpr, MemberExpr, SymbolNode]
) -> Optional[int]:
type_id, fullnames = _lookup.get(node.name, (None, None))
if type_id is None or fullnames is None:
return None
elif node.fullname in fullnames:
return type_id
else:
return None
def type_id_for_fullname(fullname: str) -> Optional[int]:
tokens = fullname.split(".")
immediate = tokens[-1]
type_id, fullnames = _lookup.get(immediate, (None, None))
if type_id is None or fullnames is None:
return None
elif fullname in fullnames:
return type_id
else:
return None
def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
column_descriptor = NameExpr("__sa_Mapped")
column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED
member_expr = MemberExpr(column_descriptor, "_empty_constructor")
return CallExpr(
member_expr,
[expr],
[ARG_POS],
["arg1"],
)
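# In effect the rewrite produced here (and consumed by apply.py and
# decl_class.py) turns a right-hand expression such as
#
#     Column(Integer)
#
# into the equivalent of
#
#     __sa_Mapped._empty_constructor(Column(Integer))
#
# so the original call is still type checked internally, while the assignment
# as a whole takes on the Mapped[...] type applied to the left-hand side.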

View File

@ -0,0 +1,303 @@
# ext/mypy/plugin.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
Mypy plugin for SQLAlchemy ORM.
"""
from __future__ import annotations
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type as TypingType
from typing import Union
from mypy import nodes
from mypy.mro import calculate_mro
from mypy.mro import MroError
from mypy.nodes import Block
from mypy.nodes import ClassDef
from mypy.nodes import GDEF
from mypy.nodes import MypyFile
from mypy.nodes import NameExpr
from mypy.nodes import SymbolTable
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeInfo
from mypy.plugin import AttributeContext
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import Plugin
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import Type
from . import decl_class
from . import names
from . import util
try:
__import__("sqlalchemy-stubs")
except ImportError:
pass
else:
raise ImportError(
"The SQLAlchemy mypy plugin in SQLAlchemy "
"2.0 does not work with sqlalchemy-stubs or "
"sqlalchemy2-stubs installed, as well as with any other third party "
"SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs "
"packages."
)
class SQLAlchemyPlugin(Plugin):
def get_dynamic_class_hook(
self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
return _dynamic_class_hook
return None
def get_customize_class_mro_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return _fill_in_decorators
def get_class_decorator_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
sym = self.lookup_fully_qualified(fullname)
if sym is not None and sym.node is not None:
type_id = names.type_id_for_named_node(sym.node)
if type_id is names.MAPPED_DECORATOR:
return _cls_decorator_hook
elif type_id in (
names.AS_DECLARATIVE,
names.AS_DECLARATIVE_BASE,
):
return _base_cls_decorator_hook
elif type_id is names.DECLARATIVE_MIXIN:
return _declarative_mixin_hook
return None
def get_metaclass_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
# Set any classes that explicitly have metaclass=DeclarativeMeta
# as declarative so the check in `get_base_class_hook()` works
return _metaclass_cls_hook
return None
def get_base_class_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
sym = self.lookup_fully_qualified(fullname)
if (
sym
and isinstance(sym.node, TypeInfo)
and util.has_declarative_base(sym.node)
):
return _base_cls_hook
return None
def get_attribute_hook(
self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
if fullname.startswith(
"sqlalchemy.orm.attributes.QueryableAttribute."
):
return _queryable_getattr_hook
return None
def get_additional_deps(
self, file: MypyFile
) -> List[Tuple[int, str, int]]:
return [
#
(10, "sqlalchemy.orm", -1),
(10, "sqlalchemy.orm.attributes", -1),
(10, "sqlalchemy.orm.decl_api", -1),
]
def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
return SQLAlchemyPlugin
def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
"""Generate a declarative Base class when the declarative_base() function
is encountered."""
_add_globals(ctx)
cls = ClassDef(ctx.name, Block([]))
cls.fullname = ctx.api.qualified_name(ctx.name)
info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
cls.info = info
_set_declarative_metaclass(ctx.api, cls)
cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
util.set_is_base(cls_arg.node)
decl_class.scan_declarative_assignments_and_apply_types(
cls_arg.node.defn, ctx.api, is_mixin_scan=True
)
info.bases = [Instance(cls_arg.node, [])]
else:
obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
info.bases = [obj]
try:
calculate_mro(info)
except MroError:
util.fail(
ctx.api, "Not able to calculate MRO for declarative base", ctx.call
)
obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
info.bases = [obj]
info.fallback_to_any = True
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
util.set_is_base(info)
def _fill_in_decorators(ctx: ClassDefContext) -> None:
for decorator in ctx.cls.decorators:
# set the ".fullname" attribute of a class decorator
# that is a MemberExpr. This causes the logic in
# semanal.py->apply_class_plugin_hooks to invoke the
# get_class_decorator_hook for our "registry.map_class()"
# and "registry.as_declarative_base()" methods.
# this seems like a bug in mypy that these decorators are otherwise
# skipped.
if (
isinstance(decorator, nodes.CallExpr)
and isinstance(decorator.callee, nodes.MemberExpr)
and decorator.callee.name == "as_declarative_base"
):
target = decorator.callee
elif (
isinstance(decorator, nodes.MemberExpr)
and decorator.name == "mapped"
):
target = decorator
else:
continue
if isinstance(target.expr, NameExpr):
sym = ctx.api.lookup_qualified(
target.expr.name, target, suppress_errors=True
)
else:
continue
if sym and sym.node:
sym_type = get_proper_type(sym.type)
if isinstance(sym_type, Instance):
target.fullname = f"{sym_type.type.fullname}.{target.name}"
else:
# if the registry is in the same file as where the
# decorator is used, it might not have semantic
# symbols applied and we can't get a fully qualified
# name or an inferred type, so we are actually going to
# flag an error in this case that they need to annotate
# it. The "registry" is declared just
# once (or a few times), so they just have to avoid using
# type inference for its assignment in this one case.
util.fail(
ctx.api,
"Class decorator called %s(), but we can't "
"tell if it's from an ORM registry. Please "
"annotate the registry assignment, e.g. "
"my_registry: registry = registry()" % target.name,
sym.node,
)
def _cls_decorator_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
assert isinstance(ctx.reason, nodes.MemberExpr)
expr = ctx.reason.expr
assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
node_type = get_proper_type(expr.node.type)
assert (
isinstance(node_type, Instance)
and names.type_id_for_named_node(node_type.type) is names.REGISTRY
)
decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
cls = ctx.cls
_set_declarative_metaclass(ctx.api, cls)
util.set_is_base(ctx.cls.info)
decl_class.scan_declarative_assignments_and_apply_types(
cls, ctx.api, is_mixin_scan=True
)
def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
util.set_is_base(ctx.cls.info)
decl_class.scan_declarative_assignments_and_apply_types(
ctx.cls, ctx.api, is_mixin_scan=True
)
def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
util.set_is_base(ctx.cls.info)
def _base_cls_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
# how do I....tell it it has no attribute of a certain name?
# can't find any Type that seems to match that
return ctx.default_attr_type
def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
"""Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
for all class defs
"""
util.add_global(ctx, "sqlalchemy.orm", "Mapped", "__sa_Mapped")
def _set_declarative_metaclass(
api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
) -> None:
info = target_cls.info
sym = api.lookup_fully_qualified_or_none(
"sqlalchemy.orm.decl_api.DeclarativeMeta"
)
assert sym is not None and isinstance(sym.node, TypeInfo)
info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])
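The plugin is activated through mypy's ``plugins`` setting (e.g.
``plugins = sqlalchemy.ext.mypy.plugin`` in the mypy configuration file). A
minimal sketch of the decorator-based mapping style that the hooks above
recognize, including the explicit ``registry`` annotation that
``_fill_in_decorators()`` otherwise asks for (model and column names are
illustrative)::

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.orm import registry

    # annotated explicitly so the plugin can resolve the decorator's origin
    mapper_registry: registry = registry()

    @mapper_registry.mapped
    class User:
        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
        name = Column(String(50))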

View File

@ -0,0 +1,357 @@
# ext/mypy/util.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import re
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Tuple
from typing import Type as TypingType
from typing import TypeVar
from typing import Union
from mypy import version
from mypy.messages import format_type as _mypy_format_type
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
from mypy.options import Options
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
from mypy.typeops import map_type_from_supertype
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import Type
from mypy.types import TypeVarType
from mypy.types import UnboundType
from mypy.types import UnionType
_vers = tuple(
[int(x) for x in version.__version__.split(".") if re.match(r"^\d+$", x)]
)
mypy_14 = _vers >= (1, 4)
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
class SQLAlchemyAttribute:
def __init__(
self,
name: str,
line: int,
column: int,
typ: Optional[Type],
info: TypeInfo,
) -> None:
self.name = name
self.line = line
self.column = column
self.type = typ
self.info = info
def serialize(self) -> JsonDict:
assert self.type
return {
"name": self.name,
"line": self.line,
"column": self.column,
"type": serialize_type(self.type),
}
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is
inherited from a generic super type.
"""
if not isinstance(self.type, TypeVarType):
return
self.type = map_type_from_supertype(self.type, sub_type, self.info)
@classmethod
def deserialize(
cls,
info: TypeInfo,
data: JsonDict,
api: SemanticAnalyzerPluginInterface,
) -> SQLAlchemyAttribute:
data = data.copy()
typ = deserialize_and_fixup_type(data.pop("type"), api)
return cls(typ=typ, info=info, **data)
def name_is_dunder(name: str) -> bool:
return bool(re.match(r"^__.+?__$", name))
def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
info.metadata.setdefault("sqlalchemy", {})[key] = data
def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
return info.metadata.get("sqlalchemy", {}).get(key, None)
def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
if info.mro:
for base in info.mro:
metadata = _get_info_metadata(base, key)
if metadata is not None:
return metadata
return None
def establish_as_sqlalchemy(info: TypeInfo) -> None:
info.metadata.setdefault("sqlalchemy", {})
def set_is_base(info: TypeInfo) -> None:
_set_info_metadata(info, "is_base", True)
def get_is_base(info: TypeInfo) -> bool:
is_base = _get_info_metadata(info, "is_base")
return is_base is True
def has_declarative_base(info: TypeInfo) -> bool:
is_base = _get_info_mro_metadata(info, "is_base")
return is_base is True
def set_has_table(info: TypeInfo) -> None:
_set_info_metadata(info, "has_table", True)
def get_has_table(info: TypeInfo) -> bool:
is_base = _get_info_metadata(info, "has_table")
return is_base is True
def get_mapped_attributes(
info: TypeInfo, api: SemanticAnalyzerPluginInterface
) -> Optional[List[SQLAlchemyAttribute]]:
mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
info, "mapped_attributes"
)
if mapped_attributes is None:
return None
attributes: List[SQLAlchemyAttribute] = []
for data in mapped_attributes:
attr = SQLAlchemyAttribute.deserialize(info, data, api)
attr.expand_typevar_from_subtype(info)
attributes.append(attr)
return attributes
def format_type(typ_: Type, options: Options) -> str:
if mypy_14:
return _mypy_format_type(typ_, options)
else:
return _mypy_format_type(typ_) # type: ignore
def set_mapped_attributes(
info: TypeInfo, attributes: List[SQLAlchemyAttribute]
) -> None:
_set_info_metadata(
info,
"mapped_attributes",
[attribute.serialize() for attribute in attributes],
)
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
msg = "[SQLAlchemy Mypy plugin] %s" % msg
return api.fail(msg, ctx)
def add_global(
ctx: Union[ClassDefContext, DynamicClassDefContext],
module: str,
symbol_name: str,
asname: str,
) -> None:
module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
if asname not in module_globals:
lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
symbol_name
]
module_globals[asname] = lookup_sym
@overload
def get_callexpr_kwarg(
callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]: ...
@overload
def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
expr_types: Tuple[TypingType[_TArgType], ...],
) -> Optional[_TArgType]: ...
def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
expr_types: Optional[Tuple[TypingType[Any], ...]] = None,
) -> Optional[Any]:
try:
arg_idx = callexpr.arg_names.index(name)
except ValueError:
return None
kwarg = callexpr.args[arg_idx]
if isinstance(
kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
):
return kwarg
return None
def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
for stmt in stmts:
if (
isinstance(stmt, IfStmt)
and isinstance(stmt.expr[0], NameExpr)
and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
):
yield from stmt.body[0].body
else:
yield stmt
def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]:
if isinstance(callee, (MemberExpr, NameExpr)):
if isinstance(callee.node, FuncDef):
if callee.node.type and isinstance(callee.node.type, CallableType):
ret_type = get_proper_type(callee.node.type.ret_type)
if isinstance(ret_type, Instance):
return ret_type
return None
elif isinstance(callee.node, TypeAlias):
target_type = get_proper_type(callee.node.target)
if isinstance(target_type, Instance):
return target_type
elif isinstance(callee.node, TypeInfo):
return callee.node
return None
def unbound_to_instance(
api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
"""Take the UnboundType that we seem to get as the ret_type from a FuncDef
and convert it into an Instance/TypeInfo kind of structure that seems
to work as the left-hand type of an AssignmentStatement.
"""
if not isinstance(typ, UnboundType):
return typ
# TODO: figure out a more robust way to check this. The node is some
# kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
# but I can't figure out how to get them to match up
if typ.name == "Optional":
# convert from "Optional?" to the more familiar
# UnionType[..., NoneType()]
return unbound_to_instance(
api,
UnionType(
[unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ [NoneType()]
),
)
node = api.lookup_qualified(typ.name, typ)
if (
node is not None
and isinstance(node, SymbolTableNode)
and isinstance(node.node, TypeInfo)
):
bound_type = node.node
return Instance(
bound_type,
[
(
unbound_to_instance(api, arg)
if isinstance(arg, UnboundType)
else arg
)
for arg in typ.args
],
)
else:
return typ
def info_for_cls(
cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[TypeInfo]:
if cls.info is CLASSDEF_NO_INFO:
sym = api.lookup_qualified(cls.name, cls)
if sym is None:
return None
assert sym and isinstance(sym.node, TypeInfo)
return sym.node
return cls.info
def serialize_type(typ: Type) -> Union[str, JsonDict]:
try:
return typ.serialize()
except Exception:
pass
if hasattr(typ, "args"):
typ.args = tuple(
(
a.resolve_string_annotation()
if hasattr(a, "resolve_string_annotation")
else a
)
for a in typ.args
)
elif hasattr(typ, "resolve_string_annotation"):
typ = typ.resolve_string_annotation()
return typ.serialize()
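As a small illustration of ``get_callexpr_kwarg()``, a mypy ``CallExpr`` can be
built by hand and queried for a keyword argument (purely a sketch; in normal
operation these nodes come from mypy's semantic analyzer)::

    from mypy.nodes import ARG_NAMED, CallExpr, NameExpr

    from sqlalchemy.ext.mypy.util import get_callexpr_kwarg

    # models the call ``declarative_base(cls=SomeBase)``
    callee = NameExpr("declarative_base")
    cls_arg = NameExpr("SomeBase")
    call = CallExpr(callee, [cls_arg], [ARG_NAMED], ["cls"])

    assert get_callexpr_kwarg(call, "cls", expr_types=(NameExpr,)) is cls_arg
    assert get_callexpr_kwarg(call, "metadata") is None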

View File

@ -0,0 +1,427 @@
# ext/orderinglist.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""A custom list that manages index/position information for contained
elements.
:author: Jason Kirtland
``orderinglist`` is a helper for mutable ordered relationships. It will
intercept list operations performed on a :func:`_orm.relationship`-managed
collection and
automatically synchronize changes in list position onto a target scalar
attribute.
Example: A ``slide`` table, where each row refers to zero or more entries
in a related ``bullet`` table. The bullets within a slide are
displayed in order based on the value of the ``position`` column in the
``bullet`` table. As entries are reordered in memory, the value of the
``position`` attribute should be updated to reflect the new sort order::
Base = declarative_base()
class Slide(Base):
__tablename__ = "slide"
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position")
class Bullet(Base):
__tablename__ = "bullet"
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey("slide.id"))
position = Column(Integer)
text = Column(String)
The standard relationship mapping will produce a list-like attribute on each
``Slide`` containing all related ``Bullet`` objects,
but coping with changes in ordering is not handled automatically.
When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
attribute will remain unset until manually assigned. When the ``Bullet``
is inserted into the middle of the list, the following ``Bullet`` objects
will also need to be renumbered.
The :class:`.OrderingList` object automates this task, managing the
``position`` attribute on all ``Bullet`` objects in the collection. It is
constructed using the :func:`.ordering_list` factory::
from sqlalchemy.ext.orderinglist import ordering_list
Base = declarative_base()
class Slide(Base):
__tablename__ = "slide"
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship(
"Bullet",
order_by="Bullet.position",
collection_class=ordering_list("position"),
)
class Bullet(Base):
__tablename__ = "bullet"
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey("slide.id"))
position = Column(Integer)
text = Column(String)
With the above mapping the ``Bullet.position`` attribute is managed::
s = Slide()
s.bullets.append(Bullet())
s.bullets.append(Bullet())
s.bullets[1].position
>>> 1
s.bullets.insert(1, Bullet())
s.bullets[2].position
>>> 2
The :class:`.OrderingList` construct only works with **changes** to a
collection, and not the initial load from the database, and requires that the
list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
:func:`_orm.relationship` against the target ordering attribute, so that the
ordering is correct when first loaded.
.. warning::
:class:`.OrderingList` only provides limited functionality when a primary
key column or unique column is the target of the sort. Operations
that are unsupported or are problematic include:
* two entries must trade values. This is not supported directly in the
case of a primary key or unique constraint because it means at least
one row would need to be temporarily removed first, or changed to
a third, neutral value while the switch occurs.
* an entry must be deleted in order to make room for a new entry.
SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
single flush. In the case of a primary key, it will trade
an INSERT/DELETE of the same primary key for an UPDATE statement in order
to lessen the impact of this limitation, however this does not take place
for a UNIQUE column.
A future feature will allow the "DELETE before INSERT" behavior to be
possible, alleviating this limitation, though this feature will require
explicit configuration at the mapper level for sets of columns that
are to be handled in this way.
:func:`.ordering_list` takes the name of the related object's ordering
attribute as an argument. By default, the zero-based integer index of the
object's position in the :func:`.ordering_list` is synchronized with the
ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
start numbering at 1 or some other integer, provide ``count_from=1``.
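For example, a 1-based numbering of the ``bullet`` rows from the mapping above::

    bullets = relationship(
        "Bullet",
        order_by="Bullet.position",
        collection_class=ordering_list("position", count_from=1),
    )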
"""
from __future__ import annotations
from typing import Callable
from typing import List
from typing import Optional
from typing import Sequence
from typing import TypeVar
from ..orm.collections import collection
from ..orm.collections import collection_adapter
_T = TypeVar("_T")
OrderingFunc = Callable[[int, Sequence[_T]], int]
__all__ = ["ordering_list"]
def ordering_list(
attr: str,
count_from: Optional[int] = None,
ordering_func: Optional[OrderingFunc] = None,
reorder_on_append: bool = False,
) -> Callable[[], OrderingList]:
"""Prepares an :class:`OrderingList` factory for use in mapper definitions.
Returns an object suitable for use as an argument to a Mapper
relationship's ``collection_class`` option. e.g.::
from sqlalchemy.ext.orderinglist import ordering_list
class Slide(Base):
__tablename__ = "slide"
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship(
"Bullet",
order_by="Bullet.position",
collection_class=ordering_list("position"),
)
:param attr:
Name of the mapped attribute to use for storage and retrieval of
ordering information
:param count_from:
Set up an integer-based ordering, starting at ``count_from``. For
example, ``ordering_list('pos', count_from=1)`` would create a 1-based
list in SQL, storing the value in the 'pos' column. Ignored if
``ordering_func`` is supplied.
Additional arguments are passed to the :class:`.OrderingList` constructor.
"""
kw = _unsugar_count_from(
count_from=count_from,
ordering_func=ordering_func,
reorder_on_append=reorder_on_append,
)
return lambda: OrderingList(attr, **kw)
# Ordering utility functions
def count_from_0(index, collection):
"""Numbering function: consecutive integers starting at 0."""
return index
def count_from_1(index, collection):
"""Numbering function: consecutive integers starting at 1."""
return index + 1
def count_from_n_factory(start):
"""Numbering function: consecutive integers starting at arbitrary start."""
def f(index, collection):
return index + start
try:
f.__name__ = "count_from_%i" % start
except TypeError:
pass
return f
def _unsugar_count_from(**kw):
"""Builds counting functions from keyword arguments.
A keyword argument filter that prepares a simple ``ordering_func`` from a
``count_from`` argument, otherwise passing ``ordering_func`` on unchanged.
"""
count_from = kw.pop("count_from", None)
if kw.get("ordering_func", None) is None and count_from is not None:
if count_from == 0:
kw["ordering_func"] = count_from_0
elif count_from == 1:
kw["ordering_func"] = count_from_1
else:
kw["ordering_func"] = count_from_n_factory(count_from)
return kw
class OrderingList(List[_T]):
"""A custom list that manages position information for its children.
The :class:`.OrderingList` object is normally set up using the
:func:`.ordering_list` factory function, used in conjunction with
the :func:`_orm.relationship` function.
"""
ordering_attr: str
ordering_func: OrderingFunc
reorder_on_append: bool
def __init__(
self,
ordering_attr: Optional[str] = None,
ordering_func: Optional[OrderingFunc] = None,
reorder_on_append: bool = False,
):
"""A custom list that manages position information for its children.
``OrderingList`` is a ``collection_class`` list implementation that
syncs position in a Python list with a position attribute on the
mapped objects.
This implementation relies on the list starting in the proper order,
so be **sure** to put an ``order_by`` on your relationship.
:param ordering_attr:
Name of the attribute that stores the object's order in the
relationship.
:param ordering_func: Optional. A function that maps the position in
the Python list to a value to store in the
``ordering_attr``. Values returned are usually (but need not be!)
integers.
An ``ordering_func`` is called with two positional parameters: the
index of the element in the list, and the list itself.
If omitted, Python list indexes are used for the attribute values.
Two basic pre-built numbering functions are provided in this module:
``count_from_0`` and ``count_from_1``. For more exotic examples
like stepped numbering, alphabetical and Fibonacci numbering, see
the unit tests.
:param reorder_on_append:
Default False. When appending an object with an existing (non-None)
ordering value, that value will be left untouched unless
``reorder_on_append`` is true. This is an optimization to avoid a
variety of dangerous unexpected database writes.
SQLAlchemy will add instances to the list via append() when your
object loads. If for some reason the result set from the database
skips a step in the ordering (say, row '1' is missing but you get
'2', '3', and '4'), reorder_on_append=True would immediately
renumber the items to '1', '2', '3'. If you have multiple sessions
making changes, any of whom happen to load this collection even in
passing, all of the sessions would try to "clean up" the numbering
in their commits, possibly causing all but one to fail with a
concurrent modification error.
Recommend leaving this with the default of False, and just call
``reorder()`` if you're doing ``append()`` operations with
previously ordered instances or when doing some housekeeping after
manual SQL operations.
"""
self.ordering_attr = ordering_attr
if ordering_func is None:
ordering_func = count_from_0
self.ordering_func = ordering_func
self.reorder_on_append = reorder_on_append
# More complex serialization schemes (multi column, e.g.) are possible by
# subclassing and reimplementing these two methods.
def _get_order_value(self, entity):
return getattr(entity, self.ordering_attr)
def _set_order_value(self, entity, value):
setattr(entity, self.ordering_attr, value)
def reorder(self) -> None:
"""Synchronize ordering for the entire collection.
Sweeps through the list and ensures that each object has accurate
ordering information set.
"""
for index, entity in enumerate(self):
self._order_entity(index, entity, True)
# As of 0.5, _reorder is no longer semi-private
_reorder = reorder
def _order_entity(self, index, entity, reorder=True):
have = self._get_order_value(entity)
# Don't disturb existing ordering if reorder is False
if have is not None and not reorder:
return
should_be = self.ordering_func(index, self)
if have != should_be:
self._set_order_value(entity, should_be)
def append(self, entity):
super().append(entity)
self._order_entity(len(self) - 1, entity, self.reorder_on_append)
def _raw_append(self, entity):
"""Append without any ordering behavior."""
super().append(entity)
_raw_append = collection.adds(1)(_raw_append)
def insert(self, index, entity):
super().insert(index, entity)
self._reorder()
def remove(self, entity):
super().remove(entity)
adapter = collection_adapter(self)
if adapter and adapter._referenced_by_owner:
self._reorder()
def pop(self, index=-1):
entity = super().pop(index)
self._reorder()
return entity
def __setitem__(self, index, entity):
if isinstance(index, slice):
step = index.step or 1
start = index.start or 0
if start < 0:
start += len(self)
stop = index.stop or len(self)
if stop < 0:
stop += len(self)
for i in range(start, stop, step):
self.__setitem__(i, entity[i])
else:
self._order_entity(index, entity, True)
super().__setitem__(index, entity)
def __delitem__(self, index):
super().__delitem__(index)
self._reorder()
def __setslice__(self, start, end, values):
super().__setslice__(start, end, values)
self._reorder()
def __delslice__(self, start, end):
super().__delslice__(start, end)
self._reorder()
def __reduce__(self):
return _reconstitute, (self.__class__, self.__dict__, list(self))
for func_name, func in list(locals().items()):
if (
callable(func)
and func.__name__ == func_name
and not func.__doc__
and hasattr(list, func_name)
):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
def _reconstitute(cls, dict_, items):
"""Reconstitute an :class:`.OrderingList`.
This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
unpickling :class:`.OrderingList` objects.
"""
obj = cls.__new__(cls)
obj.__dict__.update(dict_)
list.extend(obj, items)
return obj
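Because :class:`.OrderingList` is a plain ``list`` subclass, its position
management can also be observed in isolation, outside of a mapping; a minimal
sketch (the ``Bullet`` stand-in class here is illustrative only)::

    from sqlalchemy.ext.orderinglist import OrderingList

    class Bullet:
        position = None

    bullets = OrderingList("position")
    bullets.append(Bullet())
    bullets.append(Bullet())
    bullets.insert(1, Bullet())

    print([b.position for b in bullets])  # [0, 1, 2]

With ``ordering_list("position", count_from=1)`` the same operations would
produce positions 1, 2 and 3.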

View File

@ -0,0 +1,185 @@
# ext/serializer.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
allowing "contextual" deserialization.
.. legacy::
The serializer extension is **legacy** and should not be used for
new development.
Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session
etc. which are referenced by the structure are not persisted in serialized
form, but are instead re-associated with the query structure
when it is deserialized.
.. warning:: The serializer extension uses pickle to serialize and
deserialize objects, so the same security considerations mentioned
in the `python documentation
<https://docs.python.org/3/library/pickle.html>`_ apply.
Usage is nearly the same as that of the standard Python pickle module::
from sqlalchemy.ext.serializer import loads, dumps
metadata = MetaData(bind=some_engine)
Session = scoped_session(sessionmaker())
# ... define mappers
query = (
Session.query(MyClass)
.filter(MyClass.somedata == "foo")
.order_by(MyClass.sortkey)
)
# pickle the query
serialized = dumps(query)
# unpickle. Pass in metadata + scoped_session
query2 = loads(serialized, metadata, Session)
print(query2.all())
Restrictions similar to those of raw pickle apply; mapped classes must
themselves be pickleable, meaning they are importable from a module-level
namespace.
The serializer module is only appropriate for query structures. It is not
needed for:
* instances of user-defined classes. These contain no references to engines,
sessions or expression constructs in the typical case and can be serialized
directly.
* Table metadata that is to be loaded entirely from the serialized structure
(i.e. is not already declared in the application). Regular
pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
typically one which was reflected from an existing database at some previous
point in time. The serializer module is specifically for the opposite case,
where the Table metadata is already present in memory.
"""
from io import BytesIO
import pickle
import re
from .. import Column
from .. import Table
from ..engine import Engine
from ..orm import class_mapper
from ..orm.interfaces import MapperProperty
from ..orm.mapper import Mapper
from ..orm.session import Session
from ..util import b64decode
from ..util import b64encode
__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
class Serializer(pickle.Pickler):
def persistent_id(self, obj):
# print "serializing:", repr(obj)
if isinstance(obj, Mapper) and not obj.non_primary:
id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
id_ = (
"mapperprop:"
+ b64encode(pickle.dumps(obj.parent.class_))
+ ":"
+ obj.key
)
elif isinstance(obj, Table):
if "parententity" in obj._annotations:
id_ = "mapper_selectable:" + b64encode(
pickle.dumps(obj._annotations["parententity"].class_)
)
else:
id_ = f"table:{obj.key}"
elif isinstance(obj, Column) and isinstance(obj.table, Table):
id_ = f"column:{obj.table.key}:{obj.key}"
elif isinstance(obj, Session):
id_ = "session:"
elif isinstance(obj, Engine):
id_ = "engine:"
else:
return None
return id_
our_ids = re.compile(
r"(mapperprop|mapper|mapper_selectable|table|column|"
r"session|attribute|engine):(.*)"
)
class Deserializer(pickle.Unpickler):
def __init__(self, file, metadata=None, scoped_session=None, engine=None):
super().__init__(file)
self.metadata = metadata
self.scoped_session = scoped_session
self.engine = engine
def get_engine(self):
if self.engine:
return self.engine
elif self.scoped_session and self.scoped_session().bind:
return self.scoped_session().bind
else:
return None
def persistent_load(self, id_):
m = our_ids.match(str(id_))
if not m:
return None
else:
type_, args = m.group(1, 2)
if type_ == "attribute":
key, clsarg = args.split(":")
cls = pickle.loads(b64decode(clsarg))
return getattr(cls, key)
elif type_ == "mapper":
cls = pickle.loads(b64decode(args))
return class_mapper(cls)
elif type_ == "mapper_selectable":
cls = pickle.loads(b64decode(args))
return class_mapper(cls).__clause_element__()
elif type_ == "mapperprop":
mapper, keyname = args.split(":")
cls = pickle.loads(b64decode(mapper))
return class_mapper(cls).attrs[keyname]
elif type_ == "table":
return self.metadata.tables[args]
elif type_ == "column":
table, colname = args.split(":")
return self.metadata.tables[table].c[colname]
elif type_ == "session":
return self.scoped_session()
elif type_ == "engine":
return self.get_engine()
else:
raise Exception("Unknown token: %s" % type_)
def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
buf = BytesIO()
pickler = Serializer(buf, protocol)
pickler.dump(obj)
return buf.getvalue()
def loads(data, metadata=None, scoped_session=None, engine=None):
buf = BytesIO(data)
unpickler = Deserializer(buf, metadata, scoped_session, engine)
return unpickler.load()
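A minimal round-trip sketch with a Core ``select()`` construct; the table and
column names here are illustrative assumptions::

    from sqlalchemy import Column, Integer, MetaData, String, Table, select
    from sqlalchemy.ext.serializer import dumps, loads

    metadata = MetaData()
    user = Table(
        "user",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    stmt = select(user).where(user.c.name == "sandy")
    payload = dumps(stmt)

    # the Table is not re-created from the pickle; it is looked up again
    # in the ``metadata`` passed to loads()
    restored = loads(payload, metadata)
    print(restored)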