Update 2025-04-24_11:44:19

This commit is contained in:
oib
2025-04-24 11:44:23 +02:00
commit e748c737f4
3408 changed files with 717481 additions and 0 deletions

View File

@ -0,0 +1,160 @@
# util/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from collections import defaultdict as defaultdict
from functools import partial as partial
from functools import update_wrapper as update_wrapper
from typing import TYPE_CHECKING
from . import preloaded as preloaded
from ._collections import coerce_generator_arg as coerce_generator_arg
from ._collections import coerce_to_immutabledict as coerce_to_immutabledict
from ._collections import column_dict as column_dict
from ._collections import column_set as column_set
from ._collections import EMPTY_DICT as EMPTY_DICT
from ._collections import EMPTY_SET as EMPTY_SET
from ._collections import FacadeDict as FacadeDict
from ._collections import flatten_iterator as flatten_iterator
from ._collections import has_dupes as has_dupes
from ._collections import has_intersection as has_intersection
from ._collections import IdentitySet as IdentitySet
from ._collections import immutabledict as immutabledict
from ._collections import LRUCache as LRUCache
from ._collections import merge_lists_w_ordering as merge_lists_w_ordering
from ._collections import NONE_SET as NONE_SET
from ._collections import ordered_column_set as ordered_column_set
from ._collections import OrderedDict as OrderedDict
from ._collections import OrderedIdentitySet as OrderedIdentitySet
from ._collections import OrderedProperties as OrderedProperties
from ._collections import OrderedSet as OrderedSet
from ._collections import PopulateDict as PopulateDict
from ._collections import Properties as Properties
from ._collections import ReadOnlyContainer as ReadOnlyContainer
from ._collections import ReadOnlyProperties as ReadOnlyProperties
from ._collections import ScopedRegistry as ScopedRegistry
from ._collections import sort_dictionary as sort_dictionary
from ._collections import ThreadLocalRegistry as ThreadLocalRegistry
from ._collections import to_column_set as to_column_set
from ._collections import to_list as to_list
from ._collections import to_set as to_set
from ._collections import unique_list as unique_list
from ._collections import UniqueAppender as UniqueAppender
from ._collections import update_copy as update_copy
from ._collections import WeakPopulateDict as WeakPopulateDict
from ._collections import WeakSequence as WeakSequence
from .compat import anext_ as anext_
from .compat import arm as arm
from .compat import b as b
from .compat import b64decode as b64decode
from .compat import b64encode as b64encode
from .compat import cmp as cmp
from .compat import cpython as cpython
from .compat import dataclass_fields as dataclass_fields
from .compat import decode_backslashreplace as decode_backslashreplace
from .compat import dottedgetter as dottedgetter
from .compat import has_refcount_gc as has_refcount_gc
from .compat import inspect_getfullargspec as inspect_getfullargspec
from .compat import is64bit as is64bit
from .compat import local_dataclass_fields as local_dataclass_fields
from .compat import osx as osx
from .compat import py310 as py310
from .compat import py311 as py311
from .compat import py312 as py312
from .compat import py313 as py313
from .compat import py38 as py38
from .compat import py39 as py39
from .compat import pypy as pypy
from .compat import win32 as win32
from .concurrency import await_fallback as await_fallback
from .concurrency import await_only as await_only
from .concurrency import greenlet_spawn as greenlet_spawn
from .concurrency import is_exit_exception as is_exit_exception
from .deprecations import became_legacy_20 as became_legacy_20
from .deprecations import deprecated as deprecated
from .deprecations import deprecated_cls as deprecated_cls
from .deprecations import deprecated_params as deprecated_params
from .deprecations import moved_20 as moved_20
from .deprecations import warn_deprecated as warn_deprecated
from .langhelpers import add_parameter_text as add_parameter_text
from .langhelpers import as_interface as as_interface
from .langhelpers import asbool as asbool
from .langhelpers import asint as asint
from .langhelpers import assert_arg_type as assert_arg_type
from .langhelpers import attrsetter as attrsetter
from .langhelpers import bool_or_str as bool_or_str
from .langhelpers import chop_traceback as chop_traceback
from .langhelpers import class_hierarchy as class_hierarchy
from .langhelpers import classproperty as classproperty
from .langhelpers import clsname_as_plain_name as clsname_as_plain_name
from .langhelpers import coerce_kw_type as coerce_kw_type
from .langhelpers import constructor_copy as constructor_copy
from .langhelpers import constructor_key as constructor_key
from .langhelpers import counter as counter
from .langhelpers import create_proxy_methods as create_proxy_methods
from .langhelpers import decode_slice as decode_slice
from .langhelpers import decorator as decorator
from .langhelpers import dictlike_iteritems as dictlike_iteritems
from .langhelpers import duck_type_collection as duck_type_collection
from .langhelpers import ellipses_string as ellipses_string
from .langhelpers import EnsureKWArg as EnsureKWArg
from .langhelpers import FastIntFlag as FastIntFlag
from .langhelpers import format_argspec_init as format_argspec_init
from .langhelpers import format_argspec_plus as format_argspec_plus
from .langhelpers import generic_fn_descriptor as generic_fn_descriptor
from .langhelpers import generic_repr as generic_repr
from .langhelpers import get_annotations as get_annotations
from .langhelpers import get_callable_argspec as get_callable_argspec
from .langhelpers import get_cls_kwargs as get_cls_kwargs
from .langhelpers import get_func_kwargs as get_func_kwargs
from .langhelpers import getargspec_init as getargspec_init
from .langhelpers import has_compiled_ext as has_compiled_ext
from .langhelpers import HasMemoized as HasMemoized
from .langhelpers import (
HasMemoized_ro_memoized_attribute as HasMemoized_ro_memoized_attribute,
)
from .langhelpers import hybridmethod as hybridmethod
from .langhelpers import hybridproperty as hybridproperty
from .langhelpers import inject_docstring_text as inject_docstring_text
from .langhelpers import iterate_attributes as iterate_attributes
from .langhelpers import map_bits as map_bits
from .langhelpers import md5_hex as md5_hex
from .langhelpers import memoized_instancemethod as memoized_instancemethod
from .langhelpers import memoized_property as memoized_property
from .langhelpers import MemoizedSlots as MemoizedSlots
from .langhelpers import method_is_overridden as method_is_overridden
from .langhelpers import methods_equivalent as methods_equivalent
from .langhelpers import (
monkeypatch_proxied_specials as monkeypatch_proxied_specials,
)
from .langhelpers import non_memoized_property as non_memoized_property
from .langhelpers import NoneType as NoneType
from .langhelpers import only_once as only_once
from .langhelpers import (
parse_user_argument_for_enum as parse_user_argument_for_enum,
)
from .langhelpers import PluginLoader as PluginLoader
from .langhelpers import portable_instancemethod as portable_instancemethod
from .langhelpers import quoted_token_parser as quoted_token_parser
from .langhelpers import ro_memoized_property as ro_memoized_property
from .langhelpers import ro_non_memoized_property as ro_non_memoized_property
from .langhelpers import rw_hybridproperty as rw_hybridproperty
from .langhelpers import safe_reraise as safe_reraise
from .langhelpers import set_creation_order as set_creation_order
from .langhelpers import string_or_unprintable as string_or_unprintable
from .langhelpers import symbol as symbol
from .langhelpers import TypingOnly as TypingOnly
from .langhelpers import (
unbound_method_to_callable as unbound_method_to_callable,
)
from .langhelpers import walk_subclasses as walk_subclasses
from .langhelpers import warn as warn
from .langhelpers import warn_exception as warn_exception
from .langhelpers import warn_limited as warn_limited
from .langhelpers import wrap_callable as wrap_callable
from .preloaded import preload_module as preload_module
from .typing import is_non_string_iterable as is_non_string_iterable

View File

@ -0,0 +1,717 @@
# util/_collections.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""Collection classes and helpers."""
from __future__ import annotations
import operator
import threading
import types
import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Container
from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
from typing import ValuesView
import weakref
from ._has_cy import HAS_CYEXTENSION
from .typing import is_non_string_iterable
from .typing import Literal
from .typing import Protocol
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_collections import immutabledict as immutabledict
from ._py_collections import IdentitySet as IdentitySet
from ._py_collections import ReadOnlyContainer as ReadOnlyContainer
from ._py_collections import ImmutableDictBase as ImmutableDictBase
from ._py_collections import OrderedSet as OrderedSet
from ._py_collections import unique_list as unique_list
else:
from sqlalchemy.cyextension.immutabledict import (
ReadOnlyContainer as ReadOnlyContainer,
)
from sqlalchemy.cyextension.immutabledict import (
ImmutableDictBase as ImmutableDictBase,
)
from sqlalchemy.cyextension.immutabledict import (
immutabledict as immutabledict,
)
from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet
from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet
from sqlalchemy.cyextension.collections import ( # noqa
unique_list as unique_list,
)
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
_T_co = TypeVar("_T_co", covariant=True)
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])
def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
"""merge two lists, maintaining ordering as much as possible.
this is to reconcile vars(cls) with cls.__annotations__.
Example::
>>> a = ["__tablename__", "id", "x", "created_at"]
>>> b = ["id", "name", "data", "y", "created_at"]
>>> merge_lists_w_ordering(a, b)
['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']
This is not necessarily the ordering that things had on the class,
in this case the class is::
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
data: Mapped[Optional[str]]
x = Column(Integer)
y: Mapped[int]
created_at: Mapped[datetime.datetime] = mapped_column()
But things are *mostly* ordered.
The algorithm could also be done by creating a partial ordering for
all items in both lists and then using topological_sort(), but that
is too much overhead.
Background on how I came up with this is at:
https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae
"""
overlap = set(a).intersection(b)
result = []
current, other = iter(a), iter(b)
while True:
for element in current:
if element in overlap:
overlap.discard(element)
other, current = current, other
break
result.append(element)
else:
result.extend(other)
break
return result
def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
if not d:
return EMPTY_DICT
elif isinstance(d, immutabledict):
return d
else:
return immutabledict(d)
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
class FacadeDict(ImmutableDictBase[_KT, _VT]):
"""A dictionary that is not publicly mutable."""
def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
new = ImmutableDictBase.__new__(cls)
return new
def copy(self) -> NoReturn:
raise NotImplementedError(
"an immutabledict shouldn't need to be copied. use dict(d) "
"if you need a mutable dictionary."
)
def __reduce__(self) -> Any:
return FacadeDict, (dict(self),)
def _insert_item(self, key: _KT, value: _VT) -> None:
"""insert an item into the dictionary directly."""
dict.__setitem__(self, key, value)
def __repr__(self) -> str:
return "FacadeDict(%s)" % dict.__repr__(self)
_DT = TypeVar("_DT", bound=Any)
_F = TypeVar("_F", bound=Any)
class Properties(Generic[_T]):
"""Provide a __getattr__/__setattr__ interface over a dict."""
__slots__ = ("_data",)
_data: Dict[str, _T]
def __init__(self, data: Dict[str, _T]):
object.__setattr__(self, "_data", data)
def __len__(self) -> int:
return len(self._data)
def __iter__(self) -> Iterator[_T]:
return iter(list(self._data.values()))
def __dir__(self) -> List[str]:
return dir(super()) + [str(k) for k in self._data.keys()]
def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
return list(self) + list(other)
def __setitem__(self, key: str, obj: _T) -> None:
self._data[key] = obj
def __getitem__(self, key: str) -> _T:
return self._data[key]
def __delitem__(self, key: str) -> None:
del self._data[key]
def __setattr__(self, key: str, obj: _T) -> None:
self._data[key] = obj
def __getstate__(self) -> Dict[str, Any]:
return {"_data": self._data}
def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_data", state["_data"])
def __getattr__(self, key: str) -> _T:
try:
return self._data[key]
except KeyError:
raise AttributeError(key)
def __contains__(self, key: str) -> bool:
return key in self._data
def as_readonly(self) -> ReadOnlyProperties[_T]:
"""Return an immutable proxy for this :class:`.Properties`."""
return ReadOnlyProperties(self._data)
def update(self, value: Dict[str, _T]) -> None:
self._data.update(value)
@overload
def get(self, key: str) -> Optional[_T]: ...
@overload
def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...
def get(
self, key: str, default: Optional[Union[_DT, _T]] = None
) -> Optional[Union[_T, _DT]]:
if key in self:
return self[key]
else:
return default
def keys(self) -> List[str]:
return list(self._data)
def values(self) -> List[_T]:
return list(self._data.values())
def items(self) -> List[Tuple[str, _T]]:
return list(self._data.items())
def has_key(self, key: str) -> bool:
return key in self._data
def clear(self) -> None:
self._data.clear()
class OrderedProperties(Properties[_T]):
"""Provide a __getattr__/__setattr__ interface with an OrderedDict
as backing store."""
__slots__ = ()
def __init__(self):
Properties.__init__(self, OrderedDict())
class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
"""Provide immutable dict/object attribute to an underlying dictionary."""
__slots__ = ()
def _ordered_dictionary_sort(d, key=None):
"""Sort an OrderedDict in-place."""
items = [(k, d[k]) for k in sorted(d, key=key)]
d.clear()
d.update(items)
OrderedDict = dict
sort_dictionary = _ordered_dictionary_sort
class WeakSequence(Sequence[_T]):
def __init__(self, __elements: Sequence[_T] = ()):
# adapted from weakref.WeakKeyDictionary, prevent reference
# cycles in the collection itself
def _remove(item, selfref=weakref.ref(self)):
self = selfref()
if self is not None:
self._storage.remove(item)
self._remove = _remove
self._storage = [
weakref.ref(element, _remove) for element in __elements
]
def append(self, item):
self._storage.append(weakref.ref(item, self._remove))
def __len__(self):
return len(self._storage)
def __iter__(self):
return (
obj for obj in (ref() for ref in self._storage) if obj is not None
)
def __getitem__(self, index):
try:
obj = self._storage[index]
except KeyError:
raise IndexError("Index %s out of range" % index)
else:
return obj()
class OrderedIdentitySet(IdentitySet):
def __init__(self, iterable: Optional[Iterable[Any]] = None):
IdentitySet.__init__(self)
self._members = OrderedDict()
if iterable:
for o in iterable:
self.add(o)
class PopulateDict(Dict[_KT, _VT]):
"""A dict which populates missing values via a creation function.
Note the creation function takes a key, unlike
collections.defaultdict.
"""
def __init__(self, creator: Callable[[_KT], _VT]):
self.creator = creator
def __missing__(self, key: Any) -> Any:
self[key] = val = self.creator(key)
return val
class WeakPopulateDict(Dict[_KT, _VT]):
"""Like PopulateDict, but assumes a self + a method and does not create
a reference cycle.
"""
def __init__(self, creator_method: types.MethodType):
self.creator = creator_method.__func__
weakself = creator_method.__self__
self.weakself = weakref.ref(weakself)
def __missing__(self, key: Any) -> Any:
self[key] = val = self.creator(self.weakself(), key)
return val
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point, these are mostly historical, things
# used to be more complicated.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet
class UniqueAppender(Generic[_T]):
"""Appends items to a collection ensuring uniqueness.
Additional appends() of the same object are ignored. Membership is
determined by identity (``is a``) not equality (``==``).
"""
__slots__ = "data", "_data_appender", "_unique"
data: Union[Iterable[_T], Set[_T], List[_T]]
_data_appender: Callable[[_T], None]
_unique: Dict[int, Literal[True]]
def __init__(
self,
data: Union[Iterable[_T], Set[_T], List[_T]],
via: Optional[str] = None,
):
self.data = data
self._unique = {}
if via:
self._data_appender = getattr(data, via)
elif hasattr(data, "append"):
self._data_appender = cast("List[_T]", data).append
elif hasattr(data, "add"):
self._data_appender = cast("Set[_T]", data).add
def append(self, item: _T) -> None:
id_ = id(item)
if id_ not in self._unique:
self._data_appender(item)
self._unique[id_] = True
def __iter__(self) -> Iterator[_T]:
return iter(self.data)
def coerce_generator_arg(arg: Any) -> List[Any]:
if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
return list(arg[0])
else:
return cast("List[Any]", arg)
def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
if x is None:
return default # type: ignore
if not is_non_string_iterable(x):
return [x]
elif isinstance(x, list):
return x
else:
return list(x)
def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
r"""return True if any items of set\_ are present in iterable.
Goes through special effort to ensure __hash__ is not called
on items in iterable that don't support it.
"""
return any(i in set_ for i in iterable if i.__hash__)
def to_set(x):
if x is None:
return set()
if not isinstance(x, set):
return set(to_list(x))
else:
return x
def to_column_set(x: Any) -> Set[Any]:
if x is None:
return column_set()
if not isinstance(x, column_set):
return column_set(to_list(x))
else:
return x
def update_copy(
d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
"""Copy the given dict and update with the given values."""
d = d.copy()
if _new:
d.update(_new)
d.update(**kw)
return d
def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
"""Given an iterator of which further sub-elements may also be
iterators, flatten the sub-elements into a single iterator.
"""
elem: _T
for elem in x:
if not isinstance(elem, str) and hasattr(elem, "__iter__"):
yield from flatten_iterator(elem)
else:
yield elem
class LRUCache(typing.MutableMapping[_KT, _VT]):
"""Dictionary with 'squishy' removal of least
recently used items.
Note that either get() or [] should be used here, but
generally its not safe to do an "in" check first as the dictionary
can change subsequent to that call.
"""
__slots__ = (
"capacity",
"threshold",
"size_alert",
"_data",
"_counter",
"_mutex",
)
capacity: int
threshold: float
size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]
def __init__(
self,
capacity: int = 100,
threshold: float = 0.5,
size_alert: Optional[Callable[..., None]] = None,
):
self.capacity = capacity
self.threshold = threshold
self.size_alert = size_alert
self._counter = 0
self._mutex = threading.Lock()
self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}
def _inc_counter(self):
self._counter += 1
return self._counter
@overload
def get(self, key: _KT) -> Optional[_VT]: ...
@overload
def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...
def get(
self, key: _KT, default: Optional[Union[_VT, _T]] = None
) -> Optional[Union[_VT, _T]]:
item = self._data.get(key)
if item is not None:
item[2][0] = self._inc_counter()
return item[1]
else:
return default
def __getitem__(self, key: _KT) -> _VT:
item = self._data[key]
item[2][0] = self._inc_counter()
return item[1]
def __iter__(self) -> Iterator[_KT]:
return iter(self._data)
def __len__(self) -> int:
return len(self._data)
def values(self) -> ValuesView[_VT]:
return typing.ValuesView({k: i[1] for k, i in self._data.items()})
def __setitem__(self, key: _KT, value: _VT) -> None:
self._data[key] = (key, value, [self._inc_counter()])
self._manage_size()
def __delitem__(self, __v: _KT) -> None:
del self._data[__v]
@property
def size_threshold(self) -> float:
return self.capacity + self.capacity * self.threshold
def _manage_size(self) -> None:
if not self._mutex.acquire(False):
return
try:
size_alert = bool(self.size_alert)
while len(self) > self.capacity + self.capacity * self.threshold:
if size_alert:
size_alert = False
self.size_alert(self) # type: ignore
by_counter = sorted(
self._data.values(),
key=operator.itemgetter(2),
reverse=True,
)
for item in by_counter[self.capacity :]:
try:
del self._data[item[0]]
except KeyError:
# deleted elsewhere; skip
continue
finally:
self._mutex.release()
class _CreateFuncType(Protocol[_T_co]):
def __call__(self) -> _T_co: ...
class _ScopeFuncType(Protocol):
def __call__(self) -> Any: ...
class ScopedRegistry(Generic[_T]):
"""A Registry that can store one or multiple instances of a single
class on the basis of a "scope" function.
The object implements ``__call__`` as the "getter", so by
calling ``myregistry()`` the contained object is returned
for the current scope.
:param createfunc:
a callable that returns a new object to be placed in the registry
:param scopefunc:
a callable that will return a key to store/retrieve an object.
"""
__slots__ = "createfunc", "scopefunc", "registry"
createfunc: _CreateFuncType[_T]
scopefunc: _ScopeFuncType
registry: Any
def __init__(
self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
):
"""Construct a new :class:`.ScopedRegistry`.
:param createfunc: A creation function that will generate
a new value for the current scope, if none is present.
:param scopefunc: A function that returns a hashable
token representing the current scope (such as, current
thread identifier).
"""
self.createfunc = createfunc
self.scopefunc = scopefunc
self.registry = {}
def __call__(self) -> _T:
key = self.scopefunc()
try:
return self.registry[key] # type: ignore[no-any-return]
except KeyError:
return self.registry.setdefault(key, self.createfunc()) # type: ignore[no-any-return] # noqa: E501
def has(self) -> bool:
"""Return True if an object is present in the current scope."""
return self.scopefunc() in self.registry
def set(self, obj: _T) -> None:
"""Set the value for the current scope."""
self.registry[self.scopefunc()] = obj
def clear(self) -> None:
"""Clear the current scope, if any."""
try:
del self.registry[self.scopefunc()]
except KeyError:
pass
class ThreadLocalRegistry(ScopedRegistry[_T]):
"""A :class:`.ScopedRegistry` that uses a ``threading.local()``
variable for storage.
"""
def __init__(self, createfunc: Callable[[], _T]):
self.createfunc = createfunc
self.registry = threading.local()
def __call__(self) -> _T:
try:
return self.registry.value # type: ignore[no-any-return]
except AttributeError:
val = self.registry.value = self.createfunc()
return val
def has(self) -> bool:
return hasattr(self.registry, "value")
def set(self, obj: _T) -> None:
self.registry.value = obj
def clear(self) -> None:
try:
del self.registry.value
except AttributeError:
pass
def has_dupes(sequence, target):
"""Given a sequence and search object, return True if there's more
than one, False if zero or one of them.
"""
# compare to .index version below, this version introduces less function
# overhead and is usually the same speed. At 15000 items (way bigger than
# a relationship-bound collection in memory usually is) it begins to
# fall behind the other version only by microseconds.
c = 0
for item in sequence:
if item is target:
c += 1
if c > 1:
return True
return False
# .index version. the two __contains__ calls as well
# as .index() and isinstance() slow this down.
# def has_dupes(sequence, target):
# if target not in sequence:
# return False
# elif not isinstance(sequence, collections_abc.Sequence):
# return False
#
# idx = sequence.index(target)
# return target in sequence[idx + 1:]

View File

@ -0,0 +1,288 @@
# util/_concurrency_py3k.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
from __future__ import annotations
import asyncio
from contextvars import Context
import sys
import typing
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Coroutine
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .langhelpers import memoized_property
from .. import exc
from ..util import py311
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import TypeGuard
_T = TypeVar("_T")
if typing.TYPE_CHECKING:
class greenlet(Protocol):
dead: bool
gr_context: Optional[Context]
def __init__(self, fn: Callable[..., Any], driver: greenlet): ...
def throw(self, *arg: Any) -> Any:
return None
def switch(self, value: Any) -> Any:
return None
def getcurrent() -> greenlet: ...
else:
from greenlet import getcurrent
from greenlet import greenlet
# If greenlet.gr_context is present in current version of greenlet,
# it will be set with the current context on creation.
# Refs: https://github.com/python-greenlet/greenlet/pull/198
_has_gr_context = hasattr(getcurrent(), "gr_context")
def is_exit_exception(e: BaseException) -> bool:
# note asyncio.CancelledError is already BaseException
# so was an exit exception in any case
return not isinstance(e, Exception) or isinstance(
e, (asyncio.TimeoutError, asyncio.CancelledError)
)
# implementation based on snaury gist at
# https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
# Issue for context: https://github.com/python-greenlet/greenlet/issues/173
class _AsyncIoGreenlet(greenlet):
dead: bool
__sqlalchemy_greenlet_provider__ = True
def __init__(self, fn: Callable[..., Any], driver: greenlet):
greenlet.__init__(self, fn, driver)
if _has_gr_context:
self.gr_context = driver.gr_context
_T_co = TypeVar("_T_co", covariant=True)
if TYPE_CHECKING:
def iscoroutine(
awaitable: Awaitable[_T_co],
) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ...
else:
iscoroutine = asyncio.iscoroutine
def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
# https://docs.python.org/3/reference/datamodel.html#coroutine.close
if iscoroutine(awaitable):
awaitable.close()
def in_greenlet() -> bool:
current = getcurrent()
return getattr(current, "__sqlalchemy_greenlet_provider__", False)
def await_only(awaitable: Awaitable[_T]) -> _T:
"""Awaits an async function in a sync method.
The sync method must be inside a :func:`greenlet_spawn` context.
:func:`await_only` calls cannot be nested.
:param awaitable: The coroutine to call.
"""
# this is called in the context greenlet while running fn
current = getcurrent()
if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
_safe_cancel_awaitable(awaitable)
raise exc.MissingGreenlet(
"greenlet_spawn has not been called; can't call await_only() "
"here. Was IO attempted in an unexpected place?"
)
# returns the control to the driver greenlet passing it
# a coroutine to run. Once the awaitable is done, the driver greenlet
# switches back to this greenlet with the result of awaitable that is
# then returned to the caller (or raised as error)
return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
def await_fallback(awaitable: Awaitable[_T]) -> _T:
"""Awaits an async function in a sync method.
The sync method must be inside a :func:`greenlet_spawn` context.
:func:`await_fallback` calls cannot be nested.
:param awaitable: The coroutine to call.
.. deprecated:: 2.0.24 The ``await_fallback()`` function will be removed
in SQLAlchemy 2.1. Use :func:`_util.await_only` instead, running the
function / program / etc. within a top-level greenlet that is set up
using :func:`_util.greenlet_spawn`.
"""
# this is called in the context greenlet while running fn
current = getcurrent()
if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
loop = get_event_loop()
if loop.is_running():
_safe_cancel_awaitable(awaitable)
raise exc.MissingGreenlet(
"greenlet_spawn has not been called and asyncio event "
"loop is already running; can't call await_fallback() here. "
"Was IO attempted in an unexpected place?"
)
return loop.run_until_complete(awaitable)
return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
async def greenlet_spawn(
fn: Callable[..., _T],
*args: Any,
_require_await: bool = False,
**kwargs: Any,
) -> _T:
"""Runs a sync function ``fn`` in a new greenlet.
The sync function can then use :func:`await_only` to wait for async
functions.
:param fn: The sync callable to call.
:param \\*args: Positional arguments to pass to the ``fn`` callable.
:param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
"""
result: Any
context = _AsyncIoGreenlet(fn, getcurrent())
# runs the function synchronously in gl greenlet. If the execution
# is interrupted by await_only, context is not dead and result is a
# coroutine to wait. If the context is dead the function has
# returned, and its result can be returned.
switch_occurred = False
result = context.switch(*args, **kwargs)
while not context.dead:
switch_occurred = True
try:
# wait for a coroutine from await_only and then return its
# result back to it.
value = await result
except BaseException:
# this allows an exception to be raised within
# the moderated greenlet so that it can continue
# its expected flow.
result = context.throw(*sys.exc_info())
else:
result = context.switch(value)
if _require_await and not switch_occurred:
raise exc.AwaitRequired(
"The current operation required an async execution but none was "
"detected. This will usually happen when using a non compatible "
"DBAPI driver. Please ensure that an async DBAPI is used."
)
return result # type: ignore[no-any-return]
class AsyncAdaptedLock:
@memoized_property
def mutex(self) -> asyncio.Lock:
# there should not be a race here for coroutines creating the
# new lock as we are not using await, so therefore no concurrency
return asyncio.Lock()
def __enter__(self) -> bool:
# await is used to acquire the lock only after the first calling
# coroutine has created the mutex.
return await_fallback(self.mutex.acquire())
def __exit__(self, *arg: Any, **kw: Any) -> None:
self.mutex.release()
def get_event_loop() -> asyncio.AbstractEventLoop:
"""vendor asyncio.get_event_loop() for python 3.7 and above.
Python 3.10 deprecates get_event_loop() as a standalone.
"""
try:
return asyncio.get_running_loop()
except RuntimeError:
# avoid "During handling of the above exception, another exception..."
pass
return asyncio.get_event_loop_policy().get_event_loop()
if not TYPE_CHECKING and py311:
_Runner = asyncio.Runner
else:
class _Runner:
"""Runner implementation for test only"""
_loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]
def __init__(self) -> None:
self._loop = None
def __enter__(self) -> Self:
self._lazy_init()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def close(self) -> None:
if self._loop:
try:
self._loop.run_until_complete(
self._loop.shutdown_asyncgens()
)
finally:
self._loop.close()
self._loop = False
def get_loop(self) -> asyncio.AbstractEventLoop:
"""Return embedded event loop."""
self._lazy_init()
assert self._loop
return self._loop
def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
self._lazy_init()
assert self._loop
return self._loop.run_until_complete(coro)
def _lazy_init(self) -> None:
if self._loop is False:
raise RuntimeError("Runner is closed")
if self._loop is None:
self._loop = asyncio.new_event_loop()

View File

@ -0,0 +1,40 @@
# util/_has_cy.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import os
import typing
def _import_cy_extensions():
# all cython extension extension modules are treated as optional by the
# setup, so to ensure that all are compiled, all should be imported here
from ..cyextension import collections
from ..cyextension import immutabledict
from ..cyextension import processors
from ..cyextension import resultproxy
from ..cyextension import util
return (collections, immutabledict, processors, resultproxy, util)
_CYEXTENSION_MSG: str
if not typing.TYPE_CHECKING:
if os.environ.get("DISABLE_SQLALCHEMY_CEXT_RUNTIME"):
HAS_CYEXTENSION = False
_CYEXTENSION_MSG = "DISABLE_SQLALCHEMY_CEXT_RUNTIME is set"
else:
try:
_import_cy_extensions()
except ImportError as err:
HAS_CYEXTENSION = False
_CYEXTENSION_MSG = str(err)
else:
_CYEXTENSION_MSG = "Loaded"
HAS_CYEXTENSION = True
else:
HAS_CYEXTENSION = False

View File

@ -0,0 +1,541 @@
# util/_py_collections.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
from __future__ import annotations
from itertools import filterfalse
from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from ..util.typing import Self
_T = TypeVar("_T", bound=Any)
_S = TypeVar("_S", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
class ReadOnlyContainer:
__slots__ = ()
def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
raise TypeError(
"%s object is immutable and/or readonly" % self.__class__.__name__
)
def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
raise TypeError("%s object is immutable" % self.__class__.__name__)
def __delitem__(self, key: Any) -> NoReturn:
self._readonly()
def __setitem__(self, key: Any, value: Any) -> NoReturn:
self._readonly()
def __setattr__(self, key: str, value: Any) -> NoReturn:
self._readonly()
class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]):
if TYPE_CHECKING:
def __new__(cls, *args: Any) -> Self: ...
def __init__(cls, *args: Any): ...
def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
self._immutable()
def clear(self) -> NoReturn:
self._readonly()
def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
self._readonly()
def popitem(self) -> NoReturn:
self._readonly()
def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
self._readonly()
def update(self, *arg: Any, **kw: Any) -> NoReturn:
self._readonly()
class immutabledict(ImmutableDictBase[_KT, _VT]):
def __new__(cls, *args):
new = ImmutableDictBase.__new__(cls)
dict.__init__(new, *args)
return new
def __init__(
self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]]
):
pass
def __reduce__(self):
return immutabledict, (dict(self),)
def union(
self, __d: Optional[Mapping[_KT, _VT]] = None
) -> immutabledict[_KT, _VT]:
if not __d:
return self
new = ImmutableDictBase.__new__(self.__class__)
dict.__init__(new, self)
dict.update(new, __d) # type: ignore
return new
def _union_w_kw(
self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT
) -> immutabledict[_KT, _VT]:
# not sure if C version works correctly w/ this yet
if not __d and not kw:
return self
new = ImmutableDictBase.__new__(self.__class__)
dict.__init__(new, self)
if __d:
dict.update(new, __d) # type: ignore
dict.update(new, kw) # type: ignore
return new
def merge_with(
self, *dicts: Optional[Mapping[_KT, _VT]]
) -> immutabledict[_KT, _VT]:
new = None
for d in dicts:
if d:
if new is None:
new = ImmutableDictBase.__new__(self.__class__)
dict.__init__(new, self)
dict.update(new, d) # type: ignore
if new is None:
return self
return new
def __repr__(self) -> str:
return "immutabledict(%s)" % dict.__repr__(self)
# PEP 584
def __ior__(self, __value: Any) -> NoReturn: # type: ignore
self._readonly()
def __or__( # type: ignore[override]
self, __value: Mapping[_KT, _VT]
) -> immutabledict[_KT, _VT]:
return immutabledict(
super().__or__(__value), # type: ignore[call-overload]
)
def __ror__( # type: ignore[override]
self, __value: Mapping[_KT, _VT]
) -> immutabledict[_KT, _VT]:
return immutabledict(
super().__ror__(__value), # type: ignore[call-overload]
)
class OrderedSet(Set[_T]):
__slots__ = ("_list",)
_list: List[_T]
def __init__(self, d: Optional[Iterable[_T]] = None) -> None:
if d is not None:
self._list = unique_list(d)
super().update(self._list)
else:
self._list = []
def copy(self) -> OrderedSet[_T]:
cp = self.__class__()
cp._list = self._list.copy()
set.update(cp, cp._list)
return cp
def add(self, element: _T) -> None:
if element not in self:
self._list.append(element)
super().add(element)
def remove(self, element: _T) -> None:
super().remove(element)
self._list.remove(element)
def pop(self) -> _T:
try:
value = self._list.pop()
except IndexError:
raise KeyError("pop from an empty set") from None
super().remove(value)
return value
def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
super().add(element)
def discard(self, element: _T) -> None:
if element in self:
self._list.remove(element)
super().remove(element)
def clear(self) -> None:
super().clear()
self._list = []
def __getitem__(self, key: int) -> _T:
return self._list[key]
def __iter__(self) -> Iterator[_T]:
return iter(self._list)
def __add__(self, other: Iterator[_T]) -> OrderedSet[_T]:
return self.union(other)
def __repr__(self) -> str:
return "%s(%r)" % (self.__class__.__name__, self._list)
__str__ = __repr__
def update(self, *iterables: Iterable[_T]) -> None:
for iterable in iterables:
for e in iterable:
if e not in self:
self._list.append(e)
super().add(e)
def __ior__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
self.update(other)
return self
def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
result: OrderedSet[Union[_T, _S]] = self.copy()
result.update(*other)
return result
def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
return self.union(other)
def intersection(self, *other: Iterable[Any]) -> OrderedSet[_T]:
other_set: Set[Any] = set()
other_set.update(*other)
return self.__class__(a for a in self if a in other_set)
def __and__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
return self.intersection(other)
def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]:
collection: Collection[_T]
if isinstance(other, set):
collection = other_set = other
elif isinstance(other, Collection):
collection = other
other_set = set(other)
else:
collection = list(other)
other_set = set(collection)
result = self.__class__(a for a in self if a not in other_set)
result.update(a for a in collection if a not in self)
return result
def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
return cast(OrderedSet[Union[_T, _S]], self).symmetric_difference(
other
)
def difference(self, *other: Iterable[Any]) -> OrderedSet[_T]:
other_set = super().difference(*other)
return self.__class__(a for a in self._list if a in other_set)
def __sub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]:
return self.difference(other)
def intersection_update(self, *other: Iterable[Any]) -> None:
super().intersection_update(*other)
self._list = [a for a in self._list if a in self]
def __iand__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
self.intersection_update(other)
return self
def symmetric_difference_update(self, other: Iterable[Any]) -> None:
collection = other if isinstance(other, Collection) else list(other)
super().symmetric_difference_update(collection)
self._list = [a for a in self._list if a in self]
self._list += [a for a in collection if a in self]
def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
self.symmetric_difference_update(other)
return cast(OrderedSet[Union[_T, _S]], self)
def difference_update(self, *other: Iterable[Any]) -> None:
super().difference_update(*other)
self._list = [a for a in self._list if a in self]
def __isub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]: # type: ignore # noqa: E501
self.difference_update(other)
return self
class IdentitySet:
"""A set that considers only object id() for uniqueness.
This strategy has edge cases for builtin types- it's possible to have
two 'foo' strings in one of these sets, for example. Use sparingly.
"""
_members: Dict[int, Any]
def __init__(self, iterable: Optional[Iterable[Any]] = None):
self._members = dict()
if iterable:
self.update(iterable)
def add(self, value: Any) -> None:
self._members[id(value)] = value
def __contains__(self, value: Any) -> bool:
return id(value) in self._members
def remove(self, value: Any) -> None:
del self._members[id(value)]
def discard(self, value: Any) -> None:
try:
self.remove(value)
except KeyError:
pass
def pop(self) -> Any:
try:
pair = self._members.popitem()
return pair[1]
except KeyError:
raise KeyError("pop from an empty set")
def clear(self) -> None:
self._members.clear()
def __eq__(self, other: Any) -> bool:
if isinstance(other, IdentitySet):
return self._members == other._members
else:
return False
def __ne__(self, other: Any) -> bool:
if isinstance(other, IdentitySet):
return self._members != other._members
else:
return True
def issubset(self, iterable: Iterable[Any]) -> bool:
if isinstance(iterable, self.__class__):
other = iterable
else:
other = self.__class__(iterable)
if len(self) > len(other):
return False
for m in filterfalse(
other._members.__contains__, iter(self._members.keys())
):
return False
return True
def __le__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issubset(other)
def __lt__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) < len(other) and self.issubset(other)
def issuperset(self, iterable: Iterable[Any]) -> bool:
if isinstance(iterable, self.__class__):
other = iterable
else:
other = self.__class__(iterable)
if len(self) < len(other):
return False
for m in filterfalse(
self._members.__contains__, iter(other._members.keys())
):
return False
return True
def __ge__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issuperset(other)
def __gt__(self, other: Any) -> bool:
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) > len(other) and self.issuperset(other)
def union(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__class__()
members = self._members
result._members.update(members)
result._members.update((id(obj), obj) for obj in iterable)
return result
def __or__(self, other: Any) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.union(other)
def update(self, iterable: Iterable[Any]) -> None:
self._members.update((id(obj), obj) for obj in iterable)
def __ior__(self, other: Any) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.update(other)
return self
def difference(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
other: Collection[Any]
if isinstance(iterable, self.__class__):
other = iterable._members
else:
other = {id(obj) for obj in iterable}
result._members = {
k: v for k, v in self._members.items() if k not in other
}
return result
def __sub__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.difference(other)
def difference_update(self, iterable: Iterable[Any]) -> None:
self._members = self.difference(iterable)._members
def __isub__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.difference_update(other)
return self
def intersection(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
other: Collection[Any]
if isinstance(iterable, self.__class__):
other = iterable._members
else:
other = {id(obj) for obj in iterable}
result._members = {
k: v for k, v in self._members.items() if k in other
}
return result
def __and__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.intersection(other)
def intersection_update(self, iterable: Iterable[Any]) -> None:
self._members = self.intersection(iterable)._members
def __iand__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.intersection_update(other)
return self
def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet:
result = self.__new__(self.__class__)
if isinstance(iterable, self.__class__):
other = iterable._members
else:
other = {id(obj): obj for obj in iterable}
result._members = {
k: v for k, v in self._members.items() if k not in other
}
result._members.update(
(k, v) for k, v in other.items() if k not in self._members
)
return result
def __xor__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
return self.symmetric_difference(other)
def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
self._members = self.symmetric_difference(iterable)._members
def __ixor__(self, other: IdentitySet) -> IdentitySet:
if not isinstance(other, IdentitySet):
return NotImplemented
self.symmetric_difference(other)
return self
def copy(self) -> IdentitySet:
result = self.__new__(self.__class__)
result._members = self._members.copy()
return result
__copy__ = copy
def __len__(self) -> int:
return len(self._members)
def __iter__(self) -> Iterator[Any]:
return iter(self._members.values())
def __hash__(self) -> NoReturn:
raise TypeError("set objects are unhashable")
def __repr__(self) -> str:
return "%s(%r)" % (type(self).__name__, list(self._members.values()))
def unique_list(
seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
) -> List[_T]:
seen: Set[Any] = set()
seen_add = seen.add
if not hashfunc:
return [x for x in seq if x not in seen and not seen_add(x)]
else:
return [
x
for x in seq
if hashfunc(x) not in seen and not seen_add(hashfunc(x))
]

View File

@ -0,0 +1,301 @@
# util/compat.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""Handle Python version/platform incompatibilities."""
from __future__ import annotations
import base64
import dataclasses
import hashlib
import inspect
import operator
import platform
import sys
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
py313 = sys.version_info >= (3, 13)
py312 = sys.version_info >= (3, 12)
py311 = sys.version_info >= (3, 11)
py310 = sys.version_info >= (3, 10)
py39 = sys.version_info >= (3, 9)
py38 = sys.version_info >= (3, 8)
pypy = platform.python_implementation() == "PyPy"
cpython = platform.python_implementation() == "CPython"
win32 = sys.platform.startswith("win")
osx = sys.platform.startswith("darwin")
arm = "aarch" in platform.machine().lower()
is64bit = sys.maxsize > 2**32
has_refcount_gc = bool(cpython)
dottedgetter = operator.attrgetter
_T_co = TypeVar("_T_co", covariant=True)
class FullArgSpec(typing.NamedTuple):
args: List[str]
varargs: Optional[str]
varkw: Optional[str]
defaults: Optional[Tuple[Any, ...]]
kwonlyargs: List[str]
kwonlydefaults: Optional[Dict[str, Any]]
annotations: Dict[str, Any]
def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec:
"""Fully vendored version of getfullargspec from Python 3.3."""
if inspect.ismethod(func):
func = func.__func__
if not inspect.isfunction(func):
raise TypeError(f"{func!r} is not a Python function")
co = func.__code__
if not inspect.iscode(co):
raise TypeError(f"{co!r} is not a code object")
nargs = co.co_argcount
names = co.co_varnames
nkwargs = co.co_kwonlyargcount
args = list(names[:nargs])
kwonlyargs = list(names[nargs : nargs + nkwargs])
nargs += nkwargs
varargs = None
if co.co_flags & inspect.CO_VARARGS:
varargs = co.co_varnames[nargs]
nargs = nargs + 1
varkw = None
if co.co_flags & inspect.CO_VARKEYWORDS:
varkw = co.co_varnames[nargs]
return FullArgSpec(
args,
varargs,
varkw,
func.__defaults__,
kwonlyargs,
func.__kwdefaults__,
func.__annotations__,
)
if py39:
# python stubs don't have a public type for this. not worth
# making a protocol
def md5_not_for_security() -> Any:
return hashlib.md5(usedforsecurity=False)
else:
def md5_not_for_security() -> Any:
return hashlib.md5()
if typing.TYPE_CHECKING or py38:
from importlib import metadata as importlib_metadata
else:
import importlib_metadata # noqa
if typing.TYPE_CHECKING or py39:
# pep 584 dict union
dict_union = operator.or_ # noqa
else:
def dict_union(a: dict, b: dict) -> dict:
a = a.copy()
a.update(b)
return a
if py310:
anext_ = anext
else:
_NOT_PROVIDED = object()
from collections.abc import AsyncIterator
async def anext_(async_iterator, default=_NOT_PROVIDED):
"""vendored from https://github.com/python/cpython/pull/8895"""
if not isinstance(async_iterator, AsyncIterator):
raise TypeError(
f"anext expected an AsyncIterator, got {type(async_iterator)}"
)
anxt = type(async_iterator).__anext__
try:
return await anxt(async_iterator)
except StopAsyncIteration:
if default is _NOT_PROVIDED:
raise
return default
def importlib_metadata_get(group):
ep = importlib_metadata.entry_points()
if typing.TYPE_CHECKING or hasattr(ep, "select"):
return ep.select(group=group)
else:
return ep.get(group, ())
def b(s):
return s.encode("latin-1")
def b64decode(x: str) -> bytes:
return base64.b64decode(x.encode("ascii"))
def b64encode(x: bytes) -> str:
return base64.b64encode(x).decode("ascii")
def decode_backslashreplace(text: bytes, encoding: str) -> str:
return text.decode(encoding, errors="backslashreplace")
def cmp(a, b):
return (a > b) - (a < b)
def _formatannotation(annotation, base_module=None):
"""vendored from python 3.7"""
if isinstance(annotation, str):
return annotation
if getattr(annotation, "__module__", None) == "typing":
return repr(annotation).replace("typing.", "").replace("~", "")
if isinstance(annotation, type):
if annotation.__module__ in ("builtins", base_module):
return repr(annotation.__qualname__)
return annotation.__module__ + "." + annotation.__qualname__
elif isinstance(annotation, typing.TypeVar):
return repr(annotation).replace("~", "")
return repr(annotation).replace("~", "")
def inspect_formatargspec(
args: List[str],
varargs: Optional[str] = None,
varkw: Optional[str] = None,
defaults: Optional[Sequence[Any]] = None,
kwonlyargs: Optional[Sequence[str]] = (),
kwonlydefaults: Optional[Mapping[str, Any]] = {},
annotations: Mapping[str, Any] = {},
formatarg: Callable[[str], str] = str,
formatvarargs: Callable[[str], str] = lambda name: "*" + name,
formatvarkw: Callable[[str], str] = lambda name: "**" + name,
formatvalue: Callable[[Any], str] = lambda value: "=" + repr(value),
formatreturns: Callable[[Any], str] = lambda text: " -> " + str(text),
formatannotation: Callable[[Any], str] = _formatannotation,
) -> str:
"""Copy formatargspec from python 3.7 standard library.
Python 3 has deprecated formatargspec and requested that Signature
be used instead, however this requires a full reimplementation
of formatargspec() in terms of creating Parameter objects and such.
Instead of introducing all the object-creation overhead and having
to reinvent from scratch, just copy their compatibility routine.
Ultimately we would need to rewrite our "decorator" routine completely
which is not really worth it right now, until all Python 2.x support
is dropped.
"""
kwonlydefaults = kwonlydefaults or {}
annotations = annotations or {}
def formatargandannotation(arg):
result = formatarg(arg)
if arg in annotations:
result += ": " + formatannotation(annotations[arg])
return result
specs = []
if defaults:
firstdefault = len(args) - len(defaults)
else:
firstdefault = -1
for i, arg in enumerate(args):
spec = formatargandannotation(arg)
if defaults and i >= firstdefault:
spec = spec + formatvalue(defaults[i - firstdefault])
specs.append(spec)
if varargs is not None:
specs.append(formatvarargs(formatargandannotation(varargs)))
else:
if kwonlyargs:
specs.append("*")
if kwonlyargs:
for kwonlyarg in kwonlyargs:
spec = formatargandannotation(kwonlyarg)
if kwonlydefaults and kwonlyarg in kwonlydefaults:
spec += formatvalue(kwonlydefaults[kwonlyarg])
specs.append(spec)
if varkw is not None:
specs.append(formatvarkw(formatargandannotation(varkw)))
result = "(" + ", ".join(specs) + ")"
if "return" in annotations:
result += formatreturns(formatannotation(annotations["return"]))
return result
def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
with a class as an already processed dataclass.
The class must **already be a dataclass** for Field objects to be returned.
"""
if dataclasses.is_dataclass(cls):
return dataclasses.fields(cls)
else:
return []
def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
an already processed dataclass, excluding those that originate from a
superclass.
The class must **already be a dataclass** for Field objects to be returned.
"""
if dataclasses.is_dataclass(cls):
super_fields: Set[dataclasses.Field[Any]] = set()
for sup in cls.__bases__:
super_fields.update(dataclass_fields(sup))
return [f for f in dataclasses.fields(cls) if f not in super_fields]
else:
return []

View File

@ -0,0 +1,108 @@
# util/concurrency.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
from __future__ import annotations
import asyncio # noqa
import typing
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import TypeVar
have_greenlet = False
greenlet_error = None
try:
import greenlet # type: ignore[import-untyped,unused-ignore] # noqa: F401,E501
except ImportError as e:
greenlet_error = str(e)
pass
else:
have_greenlet = True
from ._concurrency_py3k import await_only as await_only
from ._concurrency_py3k import await_fallback as await_fallback
from ._concurrency_py3k import in_greenlet as in_greenlet
from ._concurrency_py3k import greenlet_spawn as greenlet_spawn
from ._concurrency_py3k import is_exit_exception as is_exit_exception
from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock
from ._concurrency_py3k import _Runner
_T = TypeVar("_T")
class _AsyncUtil:
"""Asyncio util for test suite/ util only"""
def __init__(self) -> None:
if have_greenlet:
self.runner = _Runner()
def run(
self,
fn: Callable[..., Coroutine[Any, Any, _T]],
*args: Any,
**kwargs: Any,
) -> _T:
"""Run coroutine on the loop"""
return self.runner.run(fn(*args, **kwargs))
def run_in_greenlet(
self, fn: Callable[..., _T], *args: Any, **kwargs: Any
) -> _T:
"""Run sync function in greenlet. Support nested calls"""
if have_greenlet:
if self.runner.get_loop().is_running():
return fn(*args, **kwargs)
else:
return self.runner.run(greenlet_spawn(fn, *args, **kwargs))
else:
return fn(*args, **kwargs)
def close(self) -> None:
if have_greenlet:
self.runner.close()
if not typing.TYPE_CHECKING and not have_greenlet:
def _not_implemented():
# this conditional is to prevent pylance from considering
# greenlet_spawn() etc as "no return" and dimming out code below it
if have_greenlet:
return None
raise ValueError(
"the greenlet library is required to use this function."
" %s" % greenlet_error
if greenlet_error
else ""
)
def is_exit_exception(e): # noqa: F811
return not isinstance(e, Exception)
def await_only(thing): # type: ignore # noqa: F811
_not_implemented()
def await_fallback(thing): # type: ignore # noqa: F811
return thing
def in_greenlet(): # type: ignore # noqa: F811
_not_implemented()
def greenlet_spawn(fn, *args, **kw): # type: ignore # noqa: F811
_not_implemented()
def AsyncAdaptedLock(*args, **kw): # type: ignore # noqa: F811
_not_implemented()
def _util_async_run(fn, *arg, **kw): # type: ignore # noqa: F811
return fn(*arg, **kw)
def _util_async_run_coroutine_function(fn, *arg, **kw): # type: ignore # noqa: F811,E501
_not_implemented()

View File

@ -0,0 +1,401 @@
# util/deprecations.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""Helpers related to deprecation of functions, methods, classes, other
functionality."""
from __future__ import annotations
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import Match
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from . import compat
from .langhelpers import _hash_limit_string
from .langhelpers import _warnings_warn
from .langhelpers import decorator
from .langhelpers import inject_docstring_text
from .langhelpers import inject_param_text
from .. import exc
_T = TypeVar("_T", bound=Any)
# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
_F = TypeVar("_F", bound="Callable[..., Any]")
def _warn_with_version(
msg: str,
version: str,
type_: Type[exc.SADeprecationWarning],
stacklevel: int,
code: Optional[str] = None,
) -> None:
warn = type_(msg, code=code)
warn.deprecated_since = version
_warnings_warn(warn, stacklevel=stacklevel + 1)
def warn_deprecated(
msg: str, version: str, stacklevel: int = 3, code: Optional[str] = None
) -> None:
_warn_with_version(
msg, version, exc.SADeprecationWarning, stacklevel, code=code
)
def warn_deprecated_limited(
msg: str,
args: Sequence[Any],
version: str,
stacklevel: int = 3,
code: Optional[str] = None,
) -> None:
"""Issue a deprecation warning with a parameterized string,
limiting the number of registrations.
"""
if args:
msg = _hash_limit_string(msg, 10, args)
_warn_with_version(
msg, version, exc.SADeprecationWarning, stacklevel, code=code
)
def deprecated_cls(
version: str, message: str, constructor: Optional[str] = "__init__"
) -> Callable[[Type[_T]], Type[_T]]:
header = ".. deprecated:: %s %s" % (version, (message or ""))
def decorate(cls: Type[_T]) -> Type[_T]:
return _decorate_cls_with_warning(
cls,
constructor,
exc.SADeprecationWarning,
message % dict(func=constructor),
version,
header,
)
return decorate
def deprecated(
version: str,
message: Optional[str] = None,
add_deprecation_to_docstring: bool = True,
warning: Optional[Type[exc.SADeprecationWarning]] = None,
enable_warnings: bool = True,
) -> Callable[[_F], _F]:
"""Decorates a function and issues a deprecation warning on use.
:param version:
Issue version in the warning.
:param message:
If provided, issue message in the warning. A sensible default
is used if not provided.
:param add_deprecation_to_docstring:
Default True. If False, the wrapped function's __doc__ is left
as-is. If True, the 'message' is prepended to the docs if
provided, or sensible default if message is omitted.
"""
if add_deprecation_to_docstring:
header = ".. deprecated:: %s %s" % (
version,
(message or ""),
)
else:
header = None
if message is None:
message = "Call to deprecated function %(func)s"
if warning is None:
warning = exc.SADeprecationWarning
message += " (deprecated since: %s)" % version
def decorate(fn: _F) -> _F:
assert message is not None
assert warning is not None
return _decorate_with_warning(
fn,
warning,
message % dict(func=fn.__name__),
version,
header,
enable_warnings=enable_warnings,
)
return decorate
def moved_20(
message: str, **kw: Any
) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
return deprecated(
"2.0", message=message, warning=exc.MovedIn20Warning, **kw
)
def became_legacy_20(
api_name: str, alternative: Optional[str] = None, **kw: Any
) -> Callable[[_F], _F]:
type_reg = re.match("^:(attr|func|meth):", api_name)
if type_reg:
type_ = {"attr": "attribute", "func": "function", "meth": "method"}[
type_reg.group(1)
]
else:
type_ = "construct"
message = (
"The %s %s is considered legacy as of the "
"1.x series of SQLAlchemy and %s in 2.0."
% (
api_name,
type_,
"becomes a legacy construct",
)
)
if ":attr:" in api_name:
attribute_ok = kw.pop("warn_on_attribute_access", False)
if not attribute_ok:
assert kw.get("enable_warnings") is False, (
"attribute %s will emit a warning on read access. "
"If you *really* want this, "
"add warn_on_attribute_access=True. Otherwise please add "
"enable_warnings=False." % api_name
)
if alternative:
message += " " + alternative
warning_cls = exc.LegacyAPIWarning
return deprecated("2.0", message=message, warning=warning_cls, **kw)
def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]:
"""Decorates a function to warn on use of certain parameters.
e.g. ::
@deprecated_params(
weak_identity_map=(
"0.7",
"the :paramref:`.Session.weak_identity_map parameter "
"is deprecated.",
)
)
def some_function(**kwargs): ...
"""
messages: Dict[str, str] = {}
versions: Dict[str, str] = {}
version_warnings: Dict[str, Type[exc.SADeprecationWarning]] = {}
for param, (version, message) in specs.items():
versions[param] = version
messages[param] = _sanitize_restructured_text(message)
version_warnings[param] = exc.SADeprecationWarning
def decorate(fn: _F) -> _F:
spec = compat.inspect_getfullargspec(fn)
check_defaults: Union[Set[str], Tuple[()]]
if spec.defaults is not None:
defaults = dict(
zip(
spec.args[(len(spec.args) - len(spec.defaults)) :],
spec.defaults,
)
)
check_defaults = set(defaults).intersection(messages)
check_kw = set(messages).difference(defaults)
elif spec.kwonlydefaults is not None:
defaults = spec.kwonlydefaults
check_defaults = set(defaults).intersection(messages)
check_kw = set(messages).difference(defaults)
else:
check_defaults = ()
check_kw = set(messages)
check_any_kw = spec.varkw
# latest mypy has opinions here, not sure if they implemented
# Concatenate or something
@decorator
def warned(fn: _F, *args: Any, **kwargs: Any) -> _F:
for m in check_defaults:
if (defaults[m] is None and kwargs[m] is not None) or (
defaults[m] is not None and kwargs[m] != defaults[m]
):
_warn_with_version(
messages[m],
versions[m],
version_warnings[m],
stacklevel=3,
)
if check_any_kw in messages and set(kwargs).difference(
check_defaults
):
assert check_any_kw is not None
_warn_with_version(
messages[check_any_kw],
versions[check_any_kw],
version_warnings[check_any_kw],
stacklevel=3,
)
for m in check_kw:
if m in kwargs:
_warn_with_version(
messages[m],
versions[m],
version_warnings[m],
stacklevel=3,
)
return fn(*args, **kwargs) # type: ignore[no-any-return]
doc = fn.__doc__ is not None and fn.__doc__ or ""
if doc:
doc = inject_param_text(
doc,
{
param: ".. deprecated:: %s %s"
% ("1.4" if version == "2.0" else version, (message or ""))
for param, (version, message) in specs.items()
},
)
decorated = warned(fn)
decorated.__doc__ = doc
return decorated
return decorate
def _sanitize_restructured_text(text: str) -> str:
def repl(m: Match[str]) -> str:
type_, name = m.group(1, 2)
if type_ in ("func", "meth"):
name += "()"
return name
text = re.sub(r":ref:`(.+) <.*>`", lambda m: '"%s"' % m.group(1), text)
return re.sub(r"\:(\w+)\:`~?(?:_\w+)?\.?(.+?)`", repl, text)
def _decorate_cls_with_warning(
cls: Type[_T],
constructor: Optional[str],
wtype: Type[exc.SADeprecationWarning],
message: str,
version: str,
docstring_header: Optional[str] = None,
) -> Type[_T]:
doc = cls.__doc__ is not None and cls.__doc__ or ""
if docstring_header is not None:
if constructor is not None:
docstring_header %= dict(func=constructor)
if issubclass(wtype, exc.Base20DeprecationWarning):
docstring_header += (
" (Background on SQLAlchemy 2.0 at: "
":ref:`migration_20_toplevel`)"
)
doc = inject_docstring_text(doc, docstring_header, 1)
constructor_fn = None
if type(cls) is type:
clsdict = dict(cls.__dict__)
clsdict["__doc__"] = doc
clsdict.pop("__dict__", None)
clsdict.pop("__weakref__", None)
cls = type(cls.__name__, cls.__bases__, clsdict)
if constructor is not None:
constructor_fn = clsdict[constructor]
else:
cls.__doc__ = doc
if constructor is not None:
constructor_fn = getattr(cls, constructor)
if constructor is not None:
assert constructor_fn is not None
assert wtype is not None
setattr(
cls,
constructor,
_decorate_with_warning(
constructor_fn, wtype, message, version, None
),
)
return cls
def _decorate_with_warning(
func: _F,
wtype: Type[exc.SADeprecationWarning],
message: str,
version: str,
docstring_header: Optional[str] = None,
enable_warnings: bool = True,
) -> _F:
"""Wrap a function with a warnings.warn and augmented docstring."""
message = _sanitize_restructured_text(message)
if issubclass(wtype, exc.Base20DeprecationWarning):
doc_only = (
" (Background on SQLAlchemy 2.0 at: "
":ref:`migration_20_toplevel`)"
)
else:
doc_only = ""
@decorator
def warned(fn: _F, *args: Any, **kwargs: Any) -> _F:
skip_warning = not enable_warnings or kwargs.pop(
"_sa_skip_warning", False
)
if not skip_warning:
_warn_with_version(message, version, wtype, stacklevel=3)
return fn(*args, **kwargs) # type: ignore[no-any-return]
doc = func.__doc__ is not None and func.__doc__ or ""
if docstring_header is not None:
docstring_header %= dict(func=func.__name__)
docstring_header += doc_only
doc = inject_docstring_text(doc, docstring_header, 1)
decorated = warned(func)
decorated.__doc__ = doc
decorated._sa_warn = lambda: _warn_with_version( # type: ignore
message, version, wtype, stacklevel=3
)
return decorated

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,150 @@
# util/preloaded.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""supplies the "preloaded" registry to resolve circular module imports at
runtime.
"""
from __future__ import annotations
import sys
from typing import Any
from typing import Callable
from typing import TYPE_CHECKING
from typing import TypeVar
_FN = TypeVar("_FN", bound=Callable[..., Any])
if TYPE_CHECKING:
from sqlalchemy import dialects as _dialects
from sqlalchemy import orm as _orm
from sqlalchemy.engine import cursor as _engine_cursor
from sqlalchemy.engine import default as _engine_default
from sqlalchemy.engine import reflection as _engine_reflection
from sqlalchemy.engine import result as _engine_result
from sqlalchemy.engine import url as _engine_url
from sqlalchemy.orm import attributes as _orm_attributes
from sqlalchemy.orm import base as _orm_base
from sqlalchemy.orm import clsregistry as _orm_clsregistry
from sqlalchemy.orm import decl_api as _orm_decl_api
from sqlalchemy.orm import decl_base as _orm_decl_base
from sqlalchemy.orm import dependency as _orm_dependency
from sqlalchemy.orm import descriptor_props as _orm_descriptor_props
from sqlalchemy.orm import mapperlib as _orm_mapper
from sqlalchemy.orm import properties as _orm_properties
from sqlalchemy.orm import relationships as _orm_relationships
from sqlalchemy.orm import session as _orm_session
from sqlalchemy.orm import state as _orm_state
from sqlalchemy.orm import strategies as _orm_strategies
from sqlalchemy.orm import strategy_options as _orm_strategy_options
from sqlalchemy.orm import util as _orm_util
from sqlalchemy.sql import default_comparator as _sql_default_comparator
from sqlalchemy.sql import dml as _sql_dml
from sqlalchemy.sql import elements as _sql_elements
from sqlalchemy.sql import functions as _sql_functions
from sqlalchemy.sql import naming as _sql_naming
from sqlalchemy.sql import schema as _sql_schema
from sqlalchemy.sql import selectable as _sql_selectable
from sqlalchemy.sql import sqltypes as _sql_sqltypes
from sqlalchemy.sql import traversals as _sql_traversals
from sqlalchemy.sql import util as _sql_util
# sigh, appease mypy 0.971 which does not accept imports as instance
# variables of a module
dialects = _dialects
engine_cursor = _engine_cursor
engine_default = _engine_default
engine_reflection = _engine_reflection
engine_result = _engine_result
engine_url = _engine_url
orm_clsregistry = _orm_clsregistry
orm_base = _orm_base
orm = _orm
orm_attributes = _orm_attributes
orm_decl_api = _orm_decl_api
orm_decl_base = _orm_decl_base
orm_descriptor_props = _orm_descriptor_props
orm_dependency = _orm_dependency
orm_mapper = _orm_mapper
orm_properties = _orm_properties
orm_relationships = _orm_relationships
orm_session = _orm_session
orm_strategies = _orm_strategies
orm_strategy_options = _orm_strategy_options
orm_state = _orm_state
orm_util = _orm_util
sql_default_comparator = _sql_default_comparator
sql_dml = _sql_dml
sql_elements = _sql_elements
sql_functions = _sql_functions
sql_naming = _sql_naming
sql_selectable = _sql_selectable
sql_traversals = _sql_traversals
sql_schema = _sql_schema
sql_sqltypes = _sql_sqltypes
sql_util = _sql_util
class _ModuleRegistry:
"""Registry of modules to load in a package init file.
To avoid potential thread safety issues for imports that are deferred
in a function, like https://bugs.python.org/issue38884, these modules
are added to the system module cache by importing them after the packages
has finished initialization.
A global instance is provided under the name :attr:`.preloaded`. Use
the function :func:`.preload_module` to register modules to load and
:meth:`.import_prefix` to load all the modules that start with the
given path.
While the modules are loaded in the global module cache, it's advisable
to access them using :attr:`.preloaded` to ensure that it was actually
registered. Each registered module is added to the instance ``__dict__``
in the form `<package>_<module>`, omitting ``sqlalchemy`` from the package
name. Example: ``sqlalchemy.sql.util`` becomes ``preloaded.sql_util``.
"""
def __init__(self, prefix="sqlalchemy."):
self.module_registry = set()
self.prefix = prefix
def preload_module(self, *deps: str) -> Callable[[_FN], _FN]:
"""Adds the specified modules to the list to load.
This method can be used both as a normal function and as a decorator.
No change is performed to the decorated object.
"""
self.module_registry.update(deps)
return lambda fn: fn
def import_prefix(self, path: str) -> None:
"""Resolve all the modules in the registry that start with the
specified path.
"""
for module in self.module_registry:
if self.prefix:
key = module.split(self.prefix)[-1].replace(".", "_")
else:
key = module
if (
not path or module.startswith(path)
) and key not in self.__dict__:
__import__(module, globals(), locals())
self.__dict__[key] = globals()[key] = sys.modules[module]
_reg = _ModuleRegistry()
preload_module = _reg.preload_module
import_prefix = _reg.import_prefix
# this appears to do absolutely nothing for any version of mypy
# if TYPE_CHECKING:
# def __getattr__(key: str) -> ModuleType:
# ...

View File

@ -0,0 +1,322 @@
# util/queue.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""An adaptation of Py2.3/2.4's Queue module which supports reentrant
behavior, using RLock instead of Lock for its mutex object. The
Queue object is used exclusively by the sqlalchemy.pool.QueuePool
class.
This is to support the connection pool's usage of weakref callbacks to return
connections to the underlying Queue, which can in extremely
rare cases be invoked within the ``get()`` method of the Queue itself,
producing a ``put()`` inside the ``get()`` and therefore a reentrant
condition.
"""
from __future__ import annotations
import asyncio
from collections import deque
import threading
from time import time as _time
import typing
from typing import Any
from typing import Awaitable
from typing import Deque
from typing import Generic
from typing import Optional
from typing import TypeVar
from .concurrency import await_fallback
from .concurrency import await_only
from .langhelpers import memoized_property
_T = TypeVar("_T", bound=Any)
__all__ = ["Empty", "Full", "Queue"]
class Empty(Exception):
"Exception raised by Queue.get(block=0)/get_nowait()."
pass
class Full(Exception):
"Exception raised by Queue.put(block=0)/put_nowait()."
pass
class QueueCommon(Generic[_T]):
maxsize: int
use_lifo: bool
def __init__(self, maxsize: int = 0, use_lifo: bool = False): ...
def empty(self) -> bool:
raise NotImplementedError()
def full(self) -> bool:
raise NotImplementedError()
def qsize(self) -> int:
raise NotImplementedError()
def put_nowait(self, item: _T) -> None:
raise NotImplementedError()
def put(
self, item: _T, block: bool = True, timeout: Optional[float] = None
) -> None:
raise NotImplementedError()
def get_nowait(self) -> _T:
raise NotImplementedError()
def get(self, block: bool = True, timeout: Optional[float] = None) -> _T:
raise NotImplementedError()
class Queue(QueueCommon[_T]):
queue: Deque[_T]
def __init__(self, maxsize: int = 0, use_lifo: bool = False):
"""Initialize a queue object with a given maximum size.
If `maxsize` is <= 0, the queue size is infinite.
If `use_lifo` is True, this Queue acts like a Stack (LIFO).
"""
self._init(maxsize)
# mutex must be held whenever the queue is mutating. All methods
# that acquire mutex must release it before returning. mutex
# is shared between the two conditions, so acquiring and
# releasing the conditions also acquires and releases mutex.
self.mutex = threading.RLock()
# Notify not_empty whenever an item is added to the queue; a
# thread waiting to get is notified then.
self.not_empty = threading.Condition(self.mutex)
# Notify not_full whenever an item is removed from the queue;
# a thread waiting to put is notified then.
self.not_full = threading.Condition(self.mutex)
# If this queue uses LIFO or FIFO
self.use_lifo = use_lifo
def qsize(self) -> int:
"""Return the approximate size of the queue (not reliable!)."""
with self.mutex:
return self._qsize()
def empty(self) -> bool:
"""Return True if the queue is empty, False otherwise (not
reliable!)."""
with self.mutex:
return self._empty()
def full(self) -> bool:
"""Return True if the queue is full, False otherwise (not
reliable!)."""
with self.mutex:
return self._full()
def put(
self, item: _T, block: bool = True, timeout: Optional[float] = None
) -> None:
"""Put an item into the queue.
If optional args `block` is True and `timeout` is None (the
default), block if necessary until a free slot is
available. If `timeout` is a positive number, it blocks at
most `timeout` seconds and raises the ``Full`` exception if no
free slot was available within that time. Otherwise (`block`
is false), put an item on the queue if a free slot is
immediately available, else raise the ``Full`` exception
(`timeout` is ignored in that case).
"""
with self.not_full:
if not block:
if self._full():
raise Full
elif timeout is None:
while self._full():
self.not_full.wait()
else:
if timeout < 0:
raise ValueError("'timeout' must be a positive number")
endtime = _time() + timeout
while self._full():
remaining = endtime - _time()
if remaining <= 0.0:
raise Full
self.not_full.wait(remaining)
self._put(item)
self.not_empty.notify()
def put_nowait(self, item: _T) -> None:
"""Put an item into the queue without blocking.
Only enqueue the item if a free slot is immediately available.
Otherwise raise the ``Full`` exception.
"""
return self.put(item, False)
def get(self, block: bool = True, timeout: Optional[float] = None) -> _T:
"""Remove and return an item from the queue.
If optional args `block` is True and `timeout` is None (the
default), block if necessary until an item is available. If
`timeout` is a positive number, it blocks at most `timeout`
seconds and raises the ``Empty`` exception if no item was
available within that time. Otherwise (`block` is false),
return an item if one is immediately available, else raise the
``Empty`` exception (`timeout` is ignored in that case).
"""
with self.not_empty:
if not block:
if self._empty():
raise Empty
elif timeout is None:
while self._empty():
self.not_empty.wait()
else:
if timeout < 0:
raise ValueError("'timeout' must be a positive number")
endtime = _time() + timeout
while self._empty():
remaining = endtime - _time()
if remaining <= 0.0:
raise Empty
self.not_empty.wait(remaining)
item = self._get()
self.not_full.notify()
return item
def get_nowait(self) -> _T:
"""Remove and return an item from the queue without blocking.
Only get an item if one is immediately available. Otherwise
raise the ``Empty`` exception.
"""
return self.get(False)
def _init(self, maxsize: int) -> None:
self.maxsize = maxsize
self.queue = deque()
def _qsize(self) -> int:
return len(self.queue)
def _empty(self) -> bool:
return not self.queue
def _full(self) -> bool:
return self.maxsize > 0 and len(self.queue) == self.maxsize
def _put(self, item: _T) -> None:
self.queue.append(item)
def _get(self) -> _T:
if self.use_lifo:
# LIFO
return self.queue.pop()
else:
# FIFO
return self.queue.popleft()
class AsyncAdaptedQueue(QueueCommon[_T]):
if typing.TYPE_CHECKING:
@staticmethod
def await_(coroutine: Awaitable[Any]) -> _T: ...
else:
await_ = staticmethod(await_only)
def __init__(self, maxsize: int = 0, use_lifo: bool = False):
self.use_lifo = use_lifo
self.maxsize = maxsize
def empty(self) -> bool:
return self._queue.empty()
def full(self):
return self._queue.full()
def qsize(self):
return self._queue.qsize()
@memoized_property
def _queue(self) -> asyncio.Queue[_T]:
# Delay creation of the queue until it is first used, to avoid
# binding it to a possibly wrong event loop.
# By delaying the creation of the pool we accommodate the common
# usage pattern of instantiating the engine at module level, where a
# different event loop is in present compared to when the application
# is actually run.
queue: asyncio.Queue[_T]
if self.use_lifo:
queue = asyncio.LifoQueue(maxsize=self.maxsize)
else:
queue = asyncio.Queue(maxsize=self.maxsize)
return queue
def put_nowait(self, item: _T) -> None:
try:
self._queue.put_nowait(item)
except asyncio.QueueFull as err:
raise Full() from err
def put(
self, item: _T, block: bool = True, timeout: Optional[float] = None
) -> None:
if not block:
return self.put_nowait(item)
try:
if timeout is not None:
self.await_(asyncio.wait_for(self._queue.put(item), timeout))
else:
self.await_(self._queue.put(item))
except (asyncio.QueueFull, asyncio.TimeoutError) as err:
raise Full() from err
def get_nowait(self) -> _T:
try:
return self._queue.get_nowait()
except asyncio.QueueEmpty as err:
raise Empty() from err
def get(self, block: bool = True, timeout: Optional[float] = None) -> _T:
if not block:
return self.get_nowait()
try:
if timeout is not None:
return self.await_(
asyncio.wait_for(self._queue.get(), timeout)
)
else:
return self.await_(self._queue.get())
except (asyncio.QueueEmpty, asyncio.TimeoutError) as err:
raise Empty() from err
class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue[_T]):
if not typing.TYPE_CHECKING:
await_ = staticmethod(await_fallback)

View File

@ -0,0 +1,201 @@
# util/tool_support.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""support routines for the helpers in tools/.
These aren't imported by the enclosing util package as the are not
needed for normal library use.
"""
from __future__ import annotations
from argparse import ArgumentParser
from argparse import Namespace
import contextlib
import difflib
import os
from pathlib import Path
import shlex
import shutil
import subprocess
import sys
from typing import Any
from typing import Dict
from typing import Iterator
from typing import Optional
from typing import Union
from . import compat
class code_writer_cmd:
parser: ArgumentParser
args: Namespace
suppress_output: bool
diffs_detected: bool
source_root: Path
pyproject_toml_path: Path
def __init__(self, tool_script: str):
self.source_root = Path(tool_script).parent.parent
self.pyproject_toml_path = self.source_root / Path("pyproject.toml")
assert self.pyproject_toml_path.exists()
self.parser = ArgumentParser()
self.parser.add_argument(
"--stdout",
action="store_true",
help="Write to stdout instead of saving to file",
)
self.parser.add_argument(
"-c",
"--check",
help="Don't write the files back, just return the "
"status. Return code 0 means nothing would change. "
"Return code 1 means some files would be reformatted",
action="store_true",
)
def run_zimports(self, tempfile: str) -> None:
self._run_console_script(
str(tempfile),
{
"entrypoint": "zimports",
"options": f"--toml-config {self.pyproject_toml_path}",
},
)
def run_black(self, tempfile: str) -> None:
self._run_console_script(
str(tempfile),
{
"entrypoint": "black",
"options": f"--config {self.pyproject_toml_path}",
},
)
def _run_console_script(self, path: str, options: Dict[str, Any]) -> None:
"""Run a Python console application from within the process.
Used for black, zimports
"""
is_posix = os.name == "posix"
entrypoint_name = options["entrypoint"]
for entry in compat.importlib_metadata_get("console_scripts"):
if entry.name == entrypoint_name:
impl = entry
break
else:
raise Exception(
f"Could not find entrypoint console_scripts.{entrypoint_name}"
)
cmdline_options_str = options.get("options", "")
cmdline_options_list = shlex.split(
cmdline_options_str, posix=is_posix
) + [path]
kw: Dict[str, Any] = {}
if self.suppress_output:
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
subprocess.run(
[
sys.executable,
"-c",
"import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
]
+ cmdline_options_list,
cwd=str(self.source_root),
**kw,
)
def write_status(self, *text: str) -> None:
if not self.suppress_output:
sys.stderr.write(" ".join(text))
def write_output_file_from_text(
self, text: str, destination_path: Union[str, Path]
) -> None:
if self.args.check:
self._run_diff(destination_path, source=text)
elif self.args.stdout:
print(text)
else:
self.write_status(f"Writing {destination_path}...")
Path(destination_path).write_text(
text, encoding="utf-8", newline="\n"
)
self.write_status("done\n")
def write_output_file_from_tempfile(
self, tempfile: str, destination_path: str
) -> None:
if self.args.check:
self._run_diff(destination_path, source_file=tempfile)
os.unlink(tempfile)
elif self.args.stdout:
with open(tempfile) as tf:
print(tf.read())
os.unlink(tempfile)
else:
self.write_status(f"Writing {destination_path}...")
shutil.move(tempfile, destination_path)
self.write_status("done\n")
def _run_diff(
self,
destination_path: Union[str, Path],
*,
source: Optional[str] = None,
source_file: Optional[str] = None,
) -> None:
if source_file:
with open(source_file, encoding="utf-8") as tf:
source_lines = list(tf)
elif source is not None:
source_lines = source.splitlines(keepends=True)
else:
assert False, "source or source_file is required"
with open(destination_path, encoding="utf-8") as dp:
d = difflib.unified_diff(
list(dp),
source_lines,
fromfile=Path(destination_path).as_posix(),
tofile="<proposed changes>",
n=3,
lineterm="\n",
)
d_as_list = list(d)
if d_as_list:
self.diffs_detected = True
print("".join(d_as_list))
@contextlib.contextmanager
def add_arguments(self) -> Iterator[ArgumentParser]:
yield self.parser
@contextlib.contextmanager
def run_program(self) -> Iterator[None]:
self.args = self.parser.parse_args()
if self.args.check:
self.diffs_detected = False
self.suppress_output = True
elif self.args.stdout:
self.suppress_output = True
else:
self.suppress_output = False
yield
if self.args.check and self.diffs_detected:
sys.exit(1)
else:
sys.exit(0)

View File

@ -0,0 +1,120 @@
# util/topological.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Topological sorting algorithms."""
from __future__ import annotations
from typing import Any
from typing import Collection
from typing import DefaultDict
from typing import Iterable
from typing import Iterator
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
from .. import util
from ..exc import CircularDependencyError
_T = TypeVar("_T", bound=Any)
__all__ = ["sort", "sort_as_subsets", "find_cycles"]
def sort_as_subsets(
tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T]
) -> Iterator[Sequence[_T]]:
edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[child].add(parent)
todo = list(allitems)
todo_set = set(allitems)
while todo_set:
output = []
for node in todo:
if todo_set.isdisjoint(edges[node]):
output.append(node)
if not output:
raise CircularDependencyError(
"Circular dependency detected.",
find_cycles(tuples, allitems),
_gen_edges(edges),
)
todo_set.difference_update(output)
todo = [t for t in todo if t in todo_set]
yield output
def sort(
tuples: Collection[Tuple[_T, _T]],
allitems: Collection[_T],
deterministic_order: bool = True,
) -> Iterator[_T]:
"""sort the given list of items by dependency.
'tuples' is a list of tuples representing a partial ordering.
deterministic_order is no longer used, the order is now always
deterministic given the order of "allitems". the flag is there
for backwards compatibility with Alembic.
"""
for set_ in sort_as_subsets(tuples, allitems):
yield from set_
def find_cycles(
tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
) -> Set[_T]:
# adapted from:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[parent].add(child)
nodes_to_test = set(edges)
output = set()
# we'd like to find all nodes that are
# involved in cycles, so we do the full
# pass through the whole thing for each
# node in the original list.
# we can go just through parent edge nodes.
# if a node is only a child and never a parent,
# by definition it can't be part of a cycle. same
# if it's not in the edges at all.
for node in nodes_to_test:
stack = [node]
todo = nodes_to_test.difference(stack)
while stack:
top = stack[-1]
for node in edges[top]:
if node in stack:
cyc = stack[stack.index(node) :]
todo.difference_update(cyc)
output.update(cyc)
if node in todo:
stack.append(node)
todo.remove(node)
break
else:
node = stack.pop()
return output
def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]:
return {(right, left) for left in edges for right in edges[left]}

View File

@ -0,0 +1,732 @@
# util/typing.py
# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
from __future__ import annotations
import builtins
from collections import deque
import collections.abc as collections_abc
import re
import sys
import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import ForwardRef
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import NewType
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import typing_extensions
from . import compat
if True: # zimports removes the tailing comments
from typing_extensions import Annotated as Annotated # 3.8
from typing_extensions import Concatenate as Concatenate # 3.10
from typing_extensions import (
dataclass_transform as dataclass_transform, # 3.11,
)
from typing_extensions import Final as Final # 3.8
from typing_extensions import final as final # 3.8
from typing_extensions import get_args as get_args # 3.10
from typing_extensions import get_origin as get_origin # 3.10
from typing_extensions import Literal as Literal # 3.8
from typing_extensions import NotRequired as NotRequired # 3.11
from typing_extensions import ParamSpec as ParamSpec # 3.10
from typing_extensions import Protocol as Protocol # 3.8
from typing_extensions import SupportsIndex as SupportsIndex # 3.8
from typing_extensions import TypeAlias as TypeAlias # 3.10
from typing_extensions import TypedDict as TypedDict # 3.8
from typing_extensions import TypeGuard as TypeGuard # 3.10
from typing_extensions import Self as Self # 3.11
from typing_extensions import TypeAliasType as TypeAliasType # 3.12
from typing_extensions import Never as Never # 3.11
from typing_extensions import LiteralString as LiteralString # 3.11
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
_KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)
if compat.py310:
# why they took until py310 to put this in stdlib is beyond me,
# I've been wanting it since py27
from types import NoneType as NoneType
else:
NoneType = type(None) # type: ignore
NoneFwd = ForwardRef("None")
_AnnotationScanType = Union[
Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
class ArgsTypeProtocol(Protocol):
"""protocol for types that have ``__args__``
there's no public interface for this AFAIK
"""
__args__: Tuple[_AnnotationScanType, ...]
class GenericProtocol(Protocol[_T]):
"""protocol for generic types.
this since Python.typing _GenericAlias is private
"""
__args__: Tuple[_AnnotationScanType, ...]
__origin__: Type[_T]
# Python's builtin _GenericAlias has this method, however builtins like
# list, dict, etc. do not, even though they have ``__origin__`` and
# ``__args__``
#
# def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
# ...
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]: ...
def __getitem__(self, __k: _KT) -> _VT_co: ...
# work around https://github.com/microsoft/pyright/issues/3025
_LiteralStar = Literal["*"]
def de_stringify_annotation(
cls: Type[Any],
annotation: _AnnotationScanType,
originating_module: str,
locals_: Mapping[str, Any],
*,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
include_generic: bool = False,
_already_seen: Optional[Set[Any]] = None,
) -> Type[Any]:
"""Resolve annotations that may be string based into real objects.
This is particularly important if a module defines "from __future__ import
annotations", as everything inside of __annotations__ is a string. We want
to at least have generic containers like ``Mapped``, ``Union``, ``List``,
etc.
"""
# looked at typing.get_type_hints(), looked at pydantic. We need much
# less here, and we here try to not use any private typing internals
# or construct ForwardRef objects which is documented as something
# that should be avoided.
original_annotation = annotation
if is_fwd_ref(annotation):
annotation = annotation.__forward_arg__
if isinstance(annotation, str):
if str_cleanup_fn:
annotation = str_cleanup_fn(annotation, originating_module)
annotation = eval_expression(
annotation, originating_module, locals_=locals_, in_class=cls
)
if (
include_generic
and is_generic(annotation)
and not is_literal(annotation)
):
if _already_seen is None:
_already_seen = set()
if annotation in _already_seen:
# only occurs recursively. outermost return type
# will always be Type.
# the element here will be either ForwardRef or
# Optional[ForwardRef]
return original_annotation # type: ignore
else:
_already_seen.add(annotation)
elements = tuple(
de_stringify_annotation(
cls,
elem,
originating_module,
locals_,
str_cleanup_fn=str_cleanup_fn,
include_generic=include_generic,
_already_seen=_already_seen,
)
for elem in annotation.__args__
)
return _copy_generic_annotation_with(annotation, elements)
return annotation # type: ignore
def fixup_container_fwd_refs(
type_: _AnnotationScanType,
) -> _AnnotationScanType:
"""Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
and similar for list, set
"""
if (
is_generic(type_)
and get_origin(type_)
in (
dict,
set,
list,
collections_abc.MutableSet,
collections_abc.MutableMapping,
collections_abc.MutableSequence,
collections_abc.Mapping,
collections_abc.Sequence,
)
# fight, kick and scream to struggle to tell the difference between
# dict[] and typing.Dict[] which DO NOT compare the same and DO NOT
# behave the same yet there is NO WAY to distinguish between which type
# it is using public attributes
and not re.match(
"typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
)
):
# compat with py3.10 and earlier
return get_origin(type_).__class_getitem__( # type: ignore
tuple(
[
ForwardRef(elem) if isinstance(elem, str) else elem
for elem in get_args(type_)
]
)
)
return type_
def _copy_generic_annotation_with(
annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
if hasattr(annotation, "copy_with"):
# List, Dict, etc. real generics
return annotation.copy_with(elements) # type: ignore
else:
# Python builtins list, dict, etc.
return annotation.__origin__[elements] # type: ignore
def eval_expression(
expression: str,
module_name: str,
*,
locals_: Optional[Mapping[str, Any]] = None,
in_class: Optional[Type[Any]] = None,
) -> Any:
try:
base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
except KeyError as ke:
raise NameError(
f"Module {module_name} isn't present in sys.modules; can't "
f"evaluate expression {expression}"
) from ke
try:
if in_class is not None:
cls_namespace = dict(in_class.__dict__)
cls_namespace.setdefault(in_class.__name__, in_class)
# see #10899. We want the locals/globals to take precedence
# over the class namespace in this context, even though this
# is not the usual way variables would resolve.
cls_namespace.update(base_globals)
annotation = eval(expression, cls_namespace, locals_)
else:
annotation = eval(expression, base_globals, locals_)
except Exception as err:
raise NameError(
f"Could not de-stringify annotation {expression!r}"
) from err
else:
return annotation
def eval_name_only(
name: str,
module_name: str,
*,
locals_: Optional[Mapping[str, Any]] = None,
) -> Any:
if "." in name:
return eval_expression(name, module_name, locals_=locals_)
try:
base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
except KeyError as ke:
raise NameError(
f"Module {module_name} isn't present in sys.modules; can't "
f"resolve name {name}"
) from ke
# name only, just look in globals. eval() works perfectly fine here,
# however we are seeking to have this be faster, as this occurs for
# every Mapper[] keyword, etc. depending on configuration
try:
return base_globals[name]
except KeyError as ke:
# check in builtins as well to handle `list`, `set` or `dict`, etc.
try:
return builtins.__dict__[name]
except KeyError:
pass
raise NameError(
f"Could not locate name {name} in module {module_name}"
) from ke
def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
try:
obj = eval_name_only(name, module_name)
except NameError:
return name
else:
return getattr(obj, "__name__", name)
def is_pep593(type_: Optional[Any]) -> bool:
return type_ is not None and get_origin(type_) in _type_tuples.Annotated
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
return isinstance(obj, collections_abc.Iterable) and not isinstance(
obj, (str, bytes)
)
def is_literal(type_: Any) -> bool:
return get_origin(type_) in _type_tuples.Literal
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
return hasattr(type_, "__supertype__")
# doesn't work in 3.8, 3.7 as it passes a closure, not an
# object instance
# isinstance(type, type_instances.NewType)
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
# NOTE: a generic TAT does not instance check as TypeAliasType outside of
# python 3.10. For sqlalchemy use cases it's fine to consider it a TAT
# though.
# NOTE: things seems to work also without this additional check
if is_generic(type_):
return is_pep695(type_.__origin__)
return isinstance(type_, _type_instances.TypeAliasType)
def flatten_newtype(type_: NewType) -> Type[Any]:
super_type = type_.__supertype__
while is_newtype(super_type):
super_type = super_type.__supertype__
return super_type # type: ignore[return-value]
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
"""Extracts the value from a TypeAliasType, recursively exploring unions
and inner TypeAliasType to flatten them into a single set.
Forward references are not evaluated, so no recursive exploration happens
into them.
"""
_seen = set()
def recursive_value(inner_type):
if inner_type in _seen:
# recursion are not supported (at least it's flagged as
# an error by pyright). Just avoid infinite loop
return inner_type
_seen.add(inner_type)
if not is_pep695(inner_type):
return inner_type
value = inner_type.__value__
if not is_union(value):
return value
return [recursive_value(t) for t in value.__args__]
res = recursive_value(type_)
if isinstance(res, list):
types = set()
stack = deque(res)
while stack:
t = stack.popleft()
if isinstance(t, list):
stack.extend(t)
else:
types.add(None if t in {NoneType, NoneFwd} else t)
return types
else:
return {res}
def is_fwd_ref(
type_: _AnnotationScanType,
check_generic: bool = False,
check_for_plain_string: bool = False,
) -> TypeGuard[ForwardRef]:
if check_for_plain_string and isinstance(type_, str):
return True
elif isinstance(type_, _type_instances.ForwardRef):
return True
elif check_generic and is_generic(type_):
return any(
is_fwd_ref(
arg, True, check_for_plain_string=check_for_plain_string
)
for arg in type_.__args__
)
else:
return False
@overload
def de_optionalize_union_types(type_: str) -> str: ...
@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...
@overload
def de_optionalize_union_types(
type_: _AnnotationScanType,
) -> _AnnotationScanType: ...
def de_optionalize_union_types(
type_: _AnnotationScanType,
) -> _AnnotationScanType:
"""Given a type, filter out ``Union`` types that include ``NoneType``
to not include the ``NoneType``.
Contains extra logic to work on non-flattened unions, unions that contain
``None`` (seen in py38, 37)
"""
if is_fwd_ref(type_):
return _de_optionalize_fwd_ref_union_types(type_, False)
elif is_union(type_) and includes_none(type_):
if compat.py39:
typ = set(type_.__args__)
else:
# py38, 37 - unions are not automatically flattened, can contain
# None rather than NoneType
stack_of_unions = deque([type_])
typ = set()
while stack_of_unions:
u_typ = stack_of_unions.popleft()
for elem in u_typ.__args__:
if is_union(elem):
stack_of_unions.append(elem)
else:
typ.add(elem)
typ.discard(None) # type: ignore
typ.discard(NoneType)
typ.discard(NoneFwd)
return make_union_type(*typ)
else:
return type_
@overload
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: Literal[True]
) -> bool: ...
@overload
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: Literal[False]
) -> _AnnotationScanType: ...
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: bool
) -> Union[_AnnotationScanType, bool]:
"""return the non-optional type for Optional[], Union[None, ...], x|None,
etc. without de-stringifying forward refs.
unfortunately this seems to require lots of hardcoded heuristics
"""
annotation = type_.__forward_arg__
mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
if mm:
g1 = mm.group(1).split(".")[-1]
if g1 == "Optional":
return True if return_has_none else ForwardRef(mm.group(2))
elif g1 == "Union":
if "[" in mm.group(2):
# cases like "Union[Dict[str, int], int, None]"
elements: list[str] = []
current: list[str] = []
ignore_comma = 0
for char in mm.group(2):
if char == "[":
ignore_comma += 1
elif char == "]":
ignore_comma -= 1
elif ignore_comma == 0 and char == ",":
elements.append("".join(current).strip())
current.clear()
continue
current.append(char)
else:
elements = re.split(r",\s*", mm.group(2))
parts = [ForwardRef(elem) for elem in elements if elem != "None"]
if return_has_none:
return len(elements) != len(parts)
else:
return make_union_type(*parts) if parts else Never # type: ignore[return-value] # noqa: E501
else:
return False if return_has_none else type_
pipe_tokens = re.split(r"\s*\|\s*", annotation)
has_none = "None" in pipe_tokens
if return_has_none:
return has_none
if has_none:
anno_str = "|".join(p for p in pipe_tokens if p != "None")
return ForwardRef(anno_str) if anno_str else Never # type: ignore[return-value] # noqa: E501
return type_
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
"""Make a Union type."""
return Union.__getitem__(types) # type: ignore
def includes_none(type_: Any) -> bool:
"""Returns if the type annotation ``type_`` allows ``None``.
This function supports:
* forward refs
* unions
* pep593 - Annotated
* pep695 - TypeAliasType (does not support looking into
fw reference of other pep695)
* NewType
* plain types like ``int``, ``None``, etc
"""
if is_fwd_ref(type_):
return _de_optionalize_fwd_ref_union_types(type_, True)
if is_union(type_):
return any(includes_none(t) for t in get_args(type_))
if is_pep593(type_):
return includes_none(get_args(type_)[0])
if is_pep695(type_):
return any(includes_none(t) for t in pep695_values(type_))
if is_newtype(type_):
return includes_none(type_.__supertype__)
try:
return type_ in (NoneFwd, NoneType, None)
except TypeError:
# if type_ is Column, mapped_column(), etc. the use of "in"
# resolves to ``__eq__()`` which then gives us an expression object
# that can't resolve to boolean. just catch it all via exception
return False
def is_a_type(type_: Any) -> bool:
return (
isinstance(type_, type)
or hasattr(type_, "__origin__")
or type_.__module__ in ("typing", "typing_extensions")
or type(type_).__mro__[0].__module__ in ("typing", "typing_extensions")
)
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
return is_origin_of(type_, "Union", "UnionType")
def is_origin_of_cls(
type_: Any, class_obj: Union[Tuple[Type[Any], ...], Type[Any]]
) -> bool:
"""return True if the given type has an __origin__ that shares a base
with the given class"""
origin = get_origin(type_)
if origin is None:
return False
return isinstance(origin, type) and issubclass(origin, class_obj)
def is_origin_of(
type_: Any, *names: str, module: Optional[str] = None
) -> bool:
"""return True if the given type has an __origin__ with the given name
and optional module."""
origin = get_origin(type_)
if origin is None:
return False
return _get_type_name(origin) in names and (
module is None or origin.__module__.startswith(module)
)
def _get_type_name(type_: Type[Any]) -> str:
if compat.py310:
return type_.__name__
else:
typ_name = getattr(type_, "__name__", None)
if typ_name is None:
typ_name = getattr(type_, "_name", None)
return typ_name # type: ignore
class DescriptorProto(Protocol):
def __get__(self, instance: object, owner: Any) -> Any: ...
def __set__(self, instance: Any, value: Any) -> None: ...
def __delete__(self, instance: Any) -> None: ...
_DESC = TypeVar("_DESC", bound=DescriptorProto)
class DescriptorReference(Generic[_DESC]):
"""a descriptor that refers to a descriptor.
used for cases where we need to have an instance variable referring to an
object that is itself a descriptor, which typically confuses typing tools
as they don't know when they should use ``__get__`` or not when referring
to the descriptor assignment as an instance variable. See
sqlalchemy.orm.interfaces.PropComparator.prop
"""
if TYPE_CHECKING:
def __get__(self, instance: object, owner: Any) -> _DESC: ...
def __set__(self, instance: Any, value: _DESC) -> None: ...
def __delete__(self, instance: Any) -> None: ...
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
class RODescriptorReference(Generic[_DESC_co]):
"""a descriptor that refers to a descriptor.
same as :class:`.DescriptorReference` but is read-only, so that subclasses
can define a subtype as the generically contained element
"""
if TYPE_CHECKING:
def __get__(self, instance: object, owner: Any) -> _DESC_co: ...
def __set__(self, instance: Any, value: Any) -> NoReturn: ...
def __delete__(self, instance: Any) -> NoReturn: ...
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
class CallableReference(Generic[_FN]):
"""a descriptor that refers to a callable.
works around mypy's limitation of not allowing callables assigned
as instance variables
"""
if TYPE_CHECKING:
def __get__(self, instance: object, owner: Any) -> _FN: ...
def __set__(self, instance: Any, value: _FN) -> None: ...
def __delete__(self, instance: Any) -> None: ...
class _TypingInstances:
def __getattr__(self, key: str) -> tuple[type, ...]:
types = tuple(
{
t
for t in [
getattr(typing, key, None),
getattr(typing_extensions, key, None),
]
if t is not None
}
)
if not types:
raise AttributeError(key)
self.__dict__[key] = types
return types
_type_tuples = _TypingInstances()
if TYPE_CHECKING:
_type_instances = typing_extensions
else:
_type_instances = _type_tuples
LITERAL_TYPES = _type_tuples.Literal