Update 2025-04-13_16:25:39
@@ -0,0 +1,96 @@
# testing/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from unittest import mock

from . import config
from .assertions import assert_raises
from .assertions import assert_raises_context_ok
from .assertions import assert_raises_message
from .assertions import assert_raises_message_context_ok
from .assertions import assert_warns
from .assertions import assert_warns_message
from .assertions import AssertsCompiledSQL
from .assertions import AssertsExecutionResults
from .assertions import ComparesIndexes
from .assertions import ComparesTables
from .assertions import emits_warning
from .assertions import emits_warning_on
from .assertions import eq_
from .assertions import eq_ignore_whitespace
from .assertions import eq_regex
from .assertions import expect_deprecated
from .assertions import expect_deprecated_20
from .assertions import expect_raises
from .assertions import expect_raises_message
from .assertions import expect_warnings
from .assertions import in_
from .assertions import int_within_variance
from .assertions import is_
from .assertions import is_false
from .assertions import is_instance_of
from .assertions import is_none
from .assertions import is_not
from .assertions import is_not_
from .assertions import is_not_none
from .assertions import is_true
from .assertions import le_
from .assertions import ne_
from .assertions import not_in
from .assertions import not_in_
from .assertions import startswith_
from .assertions import uses_deprecated
from .config import add_to_marker
from .config import async_test
from .config import combinations
from .config import combinations_list
from .config import db
from .config import fixture
from .config import requirements as requires
from .config import skip_test
from .config import Variation
from .config import variation
from .config import variation_fixture
from .exclusions import _is_excluded
from .exclusions import _server_version
from .exclusions import against as _against
from .exclusions import db_spec
from .exclusions import exclude
from .exclusions import fails
from .exclusions import fails_if
from .exclusions import fails_on
from .exclusions import fails_on_everything_except
from .exclusions import future
from .exclusions import only_if
from .exclusions import only_on
from .exclusions import skip
from .exclusions import skip_if
from .schema import eq_clause_element
from .schema import eq_type_affinity
from .util import adict
from .util import fail
from .util import flag_combinations
from .util import force_drop_names
from .util import lambda_combinations
from .util import metadata_fixture
from .util import provide_metadata
from .util import resolve_lambda
from .util import rowset
from .util import run_as_contextmanager
from .util import skip_if_timeout
from .util import teardown_events
from .warnings import assert_warnings
from .warnings import warn_test_suite


def against(*queries):
    return _against(config._current, *queries)


crashes = skip
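
# Illustrative sketch (editor's addition, not part of the original module):
# the names re-exported above form the public facade of the testing package,
# so a test module typically imports from ``sqlalchemy.testing`` directly.
# The test function below is hypothetical.
#
#     from sqlalchemy.testing import eq_, is_, ne_
#
#     def test_simple_comparisons():
#         eq_(2 + 2, 4)
#         ne_("a", "b")
#         is_(None, None)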
Binary files not shown. (15 files)
@@ -0,0 +1,989 @@
# testing/assertions.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

from collections import defaultdict
import contextlib
from copy import copy
from itertools import filterfalse
import re
import sys
import warnings

from . import assertsql
from . import config
from . import engines
from . import mock
from .exclusions import db_spec
from .util import fail
from .. import exc as sa_exc
from .. import schema
from .. import sql
from .. import types as sqltypes
from .. import util
from ..engine import default
from ..engine import url
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import decorator


def expect_warnings(*messages, **kw):
    """Context manager which expects one or more warnings.

    With no arguments, squelches all SAWarning emitted via
    sqlalchemy.util.warn and sqlalchemy.util.warn_limited.  Otherwise
    pass string expressions that will match selected warnings via regex;
    all non-matching warnings are sent through.

    The expect version **asserts** that the warnings were in fact seen.

    Note that the test suite sets SAWarning warnings to raise exceptions.

    """  # noqa
    return _expect_warnings_sqla_only(sa_exc.SAWarning, messages, **kw)
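
# Illustrative sketch (editor's addition, not in the original source): using
# expect_warnings() around code that emits a SAWarning; each positional
# argument is a regular expression matched against the warning message.
#
#     from sqlalchemy import util as sa_util
#     from sqlalchemy.testing.assertions import expect_warnings
#
#     with expect_warnings("connection .* is stale"):
#         sa_util.warn("connection <x> is stale; reconnecting")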


@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
    """Context manager which expects one or more warnings on specific
    dialects.

    The expect version **asserts** that the warnings were in fact seen.

    """
    spec = db_spec(db)

    if isinstance(db, str) and not spec(config._current):
        yield
    else:
        with expect_warnings(*messages, **kw):
            yield


def emits_warning(*messages):
    """Decorator form of expect_warnings().

    Note that emits_warning does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings(assert_=False, *messages):
            return fn(*args, **kw)

    return decorate
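
# Illustrative sketch (editor's addition): emits_warning() as a decorator;
# unlike expect_warnings() it only squelches matching warnings and does not
# assert that they occurred.  The test function here is hypothetical.
#
#     @emits_warning("Dialect .* will not make use of SQL compilation caching")
#     def test_legacy_dialect_connect():
#         ...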


def expect_deprecated(*messages, **kw):
    return _expect_warnings_sqla_only(
        sa_exc.SADeprecationWarning, messages, **kw
    )


def expect_deprecated_20(*messages, **kw):
    return _expect_warnings_sqla_only(
        sa_exc.Base20DeprecationWarning, messages, **kw
    )


def emits_warning_on(db, *messages):
    """Mark a test as emitting a warning on a specific dialect.

    With no arguments, squelches all SAWarning failures.  Or pass one or more
    strings; these will be matched to the root of the warning description by
    warnings.filterwarnings().

    Note that emits_warning_on does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings_on(db, assert_=False, *messages):
            return fn(*args, **kw)

    return decorate


def uses_deprecated(*messages):
    """Mark a test as immune from fatal deprecation warnings.

    With no arguments, squelches all SADeprecationWarning failures.
    Or pass one or more strings; these will be matched to the root
    of the warning description by warnings.filterwarnings().

    As a special case, you may pass a function name prefixed with //
    and it will be re-written as needed to match the standard warning
    verbiage emitted by the sqlalchemy.util.deprecated decorator.

    Note that uses_deprecated does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_deprecated(*messages, assert_=False):
            return fn(*args, **kw)

    return decorate


_FILTERS = None
_SEEN = None
_EXC_CLS = None


def _expect_warnings_sqla_only(
    exc_cls,
    messages,
    regex=True,
    search_msg=False,
    assert_=True,
):
    """SQLAlchemy internal use only _expect_warnings().

    Alembic is using _expect_warnings() directly, and should be updated
    to use this new interface.

    """
    return _expect_warnings(
        exc_cls,
        messages,
        regex=regex,
        search_msg=search_msg,
        assert_=assert_,
        raise_on_any_unexpected=True,
    )


@contextlib.contextmanager
def _expect_warnings(
    exc_cls,
    messages,
    regex=True,
    search_msg=False,
    assert_=True,
    raise_on_any_unexpected=False,
    squelch_other_warnings=False,
):
    global _FILTERS, _SEEN, _EXC_CLS

    if regex or search_msg:
        filters = [re.compile(msg, re.I | re.S) for msg in messages]
    else:
        filters = list(messages)

    if _FILTERS is not None:
        # nested call; update _FILTERS and _SEEN, return.  outer
        # block will assert our messages
        assert _SEEN is not None
        assert _EXC_CLS is not None
        _FILTERS.extend(filters)
        _SEEN.update(filters)
        _EXC_CLS += (exc_cls,)
        yield
    else:
        seen = _SEEN = set(filters)
        _FILTERS = filters
        _EXC_CLS = (exc_cls,)

        if raise_on_any_unexpected:

            def real_warn(msg, *arg, **kw):
                raise AssertionError("Got unexpected warning: %r" % msg)

        else:
            real_warn = warnings.warn

        def our_warn(msg, *arg, **kw):
            if isinstance(msg, _EXC_CLS):
                exception = type(msg)
                msg = str(msg)
            elif arg:
                exception = arg[0]
            else:
                exception = None

            if not exception or not issubclass(exception, _EXC_CLS):
                if not squelch_other_warnings:
                    return real_warn(msg, *arg, **kw)
                else:
                    return

            if not filters and not raise_on_any_unexpected:
                return

            for filter_ in filters:
                if (
                    (search_msg and filter_.search(msg))
                    or (regex and filter_.match(msg))
                    or (not regex and filter_ == msg)
                ):
                    seen.discard(filter_)
                    break
            else:
                if not squelch_other_warnings:
                    real_warn(msg, *arg, **kw)

        with mock.patch("warnings.warn", our_warn):
            try:
                yield
            finally:
                _SEEN = _FILTERS = _EXC_CLS = None

            if assert_:
                assert not seen, "Warnings were not seen: %s" % ", ".join(
                    "%r" % (s.pattern if regex else s) for s in seen
                )
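
# Illustrative sketch (editor's addition): _expect_warnings() supports
# nesting; an inner block extends the module-level _FILTERS / _SEEN /
# _EXC_CLS state, and only the outermost block performs the final
# "were all messages seen" assertion when assert_=True.
#
#     with _expect_warnings(sa_exc.SAWarning, ["outer message"]):
#         with _expect_warnings(sa_exc.SADeprecationWarning, ["inner message"]):
#             ...  # code emitting both warnings; asserted on outer exit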


def global_cleanup_assertions():
    """Check things that have to be finalized at the end of a test suite.

    Hardcoded at the moment, a modular system can be built here
    to support things like PG prepared transactions, tables all
    dropped, etc.

    """
    _assert_no_stray_pool_connections()


def _assert_no_stray_pool_connections():
    engines.testing_reaper.assert_all_closed()


def int_within_variance(expected, received, variance):
    deviance = int(expected * variance)
    assert (
        abs(received - expected) < deviance
    ), "Given int value %s is not within %d%% of expected value %s" % (
        received,
        variance * 100,
        expected,
    )
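
# Worked example (editor's addition): with expected=100 and variance=0.05,
# deviance is int(100 * 0.05) == 5, so a received value passes only if its
# absolute difference from 100 is strictly less than 5.
#
#     int_within_variance(100, 103, 0.05)  # passes: |103 - 100| == 3 < 5
#     # int_within_variance(100, 105, 0.05) would fail: |105 - 100| == 5
#     # is not strictly less than 5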


def eq_regex(a, b, msg=None, flags=0):
    assert re.match(b, a, flags), msg or "%r !~ %r" % (a, b)


def eq_(a, b, msg=None):
    """Assert a == b, with repr messaging on failure."""
    assert a == b, msg or "%r != %r" % (a, b)


def ne_(a, b, msg=None):
    """Assert a != b, with repr messaging on failure."""
    assert a != b, msg or "%r == %r" % (a, b)


def le_(a, b, msg=None):
    """Assert a <= b, with repr messaging on failure."""
    assert a <= b, msg or "%r != %r" % (a, b)


def is_instance_of(a, b, msg=None):
    assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)


def is_none(a, msg=None):
    is_(a, None, msg=msg)


def is_not_none(a, msg=None):
    is_not(a, None, msg=msg)


def is_true(a, msg=None):
    is_(bool(a), True, msg=msg)


def is_false(a, msg=None):
    is_(bool(a), False, msg=msg)


def is_(a, b, msg=None):
    """Assert a is b, with repr messaging on failure."""
    assert a is b, msg or "%r is not %r" % (a, b)


def is_not(a, b, msg=None):
    """Assert a is not b, with repr messaging on failure."""
    assert a is not b, msg or "%r is %r" % (a, b)


# deprecated. See #5429
is_not_ = is_not


def in_(a, b, msg=None):
    """Assert a in b, with repr messaging on failure."""
    assert a in b, msg or "%r not in %r" % (a, b)


def not_in(a, b, msg=None):
    """Assert a not in b, with repr messaging on failure."""
    assert a not in b, msg or "%r is in %r" % (a, b)


# deprecated. See #5429
not_in_ = not_in


def startswith_(a, fragment, msg=None):
    """Assert a.startswith(fragment), with repr messaging on failure."""
    assert a.startswith(fragment), msg or "%r does not start with %r" % (
        a,
        fragment,
    )


def eq_ignore_whitespace(a, b, msg=None):
    a = re.sub(r"^\s+?|\n", "", a)
    a = re.sub(r" {2,}", " ", a)
    a = re.sub(r"\t", "", a)
    b = re.sub(r"^\s+?|\n", "", b)
    b = re.sub(r" {2,}", " ", b)
    b = re.sub(r"\t", "", b)

    assert a == b, msg or "%r != %r" % (a, b)
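
# Illustrative sketch (editor's addition): eq_ignore_whitespace() strips
# newlines and tabs outright and collapses runs of spaces to one before
# comparing, which is handy for hand-wrapped SQL strings.
#
#     eq_ignore_whitespace(
#         "SELECT a,   b      FROM t",
#         "SELECT a, b FROM t",
#     )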


def _assert_proper_exception_context(exception):
    """assert that any exception we're catching does not have a __context__
    without a __cause__, and that __suppress_context__ is never set.

    Python 3 will report nested exceptions as "during the handling of
    error X, error Y occurred".  That's not what we want to do; we want
    these exceptions in a cause chain.

    """

    if (
        exception.__context__ is not exception.__cause__
        and not exception.__suppress_context__
    ):
        assert False, (
            "Exception %r was correctly raised but did not set a cause, "
            "within context %r as its cause."
            % (exception, exception.__context__)
        )


def assert_raises(except_cls, callable_, *args, **kw):
    return _assert_raises(except_cls, callable_, args, kw, check_context=True)


def assert_raises_context_ok(except_cls, callable_, *args, **kw):
    return _assert_raises(except_cls, callable_, args, kw)


def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    return _assert_raises(
        except_cls, callable_, args, kwargs, msg=msg, check_context=True
    )


def assert_warns(except_cls, callable_, *args, **kwargs):
    """legacy adapter function for functions that were previously using
    assert_raises with SAWarning or similar.

    has some workarounds to accommodate the fact that the callable completes
    with this approach rather than stopping at the exception raise.

    """
    with _expect_warnings_sqla_only(except_cls, [".*"]):
        return callable_(*args, **kwargs)


def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
    """legacy adapter function for functions that were previously using
    assert_raises with SAWarning or similar.

    has some workarounds to accommodate the fact that the callable completes
    with this approach rather than stopping at the exception raise.

    Also uses regex.search() to match the given message to the error string
    rather than regex.match().

    """
    with _expect_warnings_sqla_only(
        except_cls,
        [msg],
        search_msg=True,
        regex=False,
    ):
        return callable_(*args, **kwargs)


def assert_raises_message_context_ok(
    except_cls, msg, callable_, *args, **kwargs
):
    return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)


def _assert_raises(
    except_cls, callable_, args, kwargs, msg=None, check_context=False
):
    with _expect_raises(except_cls, msg, check_context) as ec:
        callable_(*args, **kwargs)
    return ec.error


class _ErrorContainer:
    error = None


@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
    if (
        isinstance(except_cls, type)
        and issubclass(except_cls, Warning)
        or isinstance(except_cls, Warning)
    ):
        raise TypeError(
            "Use expect_warnings for warnings, not "
            "expect_raises / assert_raises"
        )
    ec = _ErrorContainer()
    if check_context:
        are_we_already_in_a_traceback = sys.exc_info()[0]
    try:
        yield ec
        success = False
    except except_cls as err:
        ec.error = err
        success = True
        if msg is not None:
            # I'm often pdbing here, and "err" above isn't
            # in scope, so assign the string explicitly
            error_as_string = str(err)
            assert re.search(msg, error_as_string, re.UNICODE), "%r !~ %s" % (
                msg,
                error_as_string,
            )
        if check_context and not are_we_already_in_a_traceback:
            _assert_proper_exception_context(err)
        print(str(err).encode("utf-8"))

    # it's generally a good idea to not carry traceback objects outside
    # of the except: block, but in this case especially we seem to have
    # hit some bug in either python 3.10.0b2 or greenlet or both which
    # this seems to fix:
    # https://github.com/python-greenlet/greenlet/issues/242
    del ec

    # assert outside the block so it works for AssertionError too !
    assert success, "Callable did not raise an exception"


def expect_raises(except_cls, check_context=True):
    return _expect_raises(except_cls, check_context=check_context)


def expect_raises_message(except_cls, msg, check_context=True):
    return _expect_raises(except_cls, msg=msg, check_context=check_context)
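
# Illustrative sketch (editor's addition): expect_raises_message() is the
# context-manager counterpart to assert_raises_message(); the message is a
# regular expression applied with re.search() against the exception string.
#
#     from sqlalchemy import exc as sa_exc
#
#     with expect_raises_message(sa_exc.ArgumentError, "unexpected keyword"):
#         raise sa_exc.ArgumentError("unexpected keyword argument 'foo'")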


class AssertsCompiledSQL:
    def assert_compile(
        self,
        clause,
        result,
        params=None,
        checkparams=None,
        for_executemany=False,
        check_literal_execute=None,
        check_post_param=None,
        dialect=None,
        checkpositional=None,
        check_prefetch=None,
        use_default_dialect=False,
        allow_dialect_select=False,
        supports_default_values=True,
        supports_default_metavalue=True,
        literal_binds=False,
        render_postcompile=False,
        schema_translate_map=None,
        render_schema_translate=False,
        default_schema_name=None,
        from_linting=False,
        check_param_order=True,
        use_literal_execute_for_simple_int=False,
    ):
        if use_default_dialect:
            dialect = default.DefaultDialect()
            dialect.supports_default_values = supports_default_values
            dialect.supports_default_metavalue = supports_default_metavalue
        elif allow_dialect_select:
            dialect = None
        else:
            if dialect is None:
                dialect = getattr(self, "__dialect__", None)

            if dialect is None:
                dialect = config.db.dialect
            elif dialect == "default" or dialect == "default_qmark":
                if dialect == "default":
                    dialect = default.DefaultDialect()
                else:
                    dialect = default.DefaultDialect("qmark")
                dialect.supports_default_values = supports_default_values
                dialect.supports_default_metavalue = supports_default_metavalue
            elif dialect == "default_enhanced":
                dialect = default.StrCompileDialect()
            elif isinstance(dialect, str):
                dialect = url.URL.create(dialect).get_dialect()()

        if default_schema_name:
            dialect.default_schema_name = default_schema_name

        kw = {}
        compile_kwargs = {}

        if schema_translate_map:
            kw["schema_translate_map"] = schema_translate_map

        if params is not None:
            kw["column_keys"] = list(params)

        if literal_binds:
            compile_kwargs["literal_binds"] = True

        if render_postcompile:
            compile_kwargs["render_postcompile"] = True

        if use_literal_execute_for_simple_int:
            compile_kwargs["use_literal_execute_for_simple_int"] = True

        if for_executemany:
            kw["for_executemany"] = True

        if render_schema_translate:
            kw["render_schema_translate"] = True

        if from_linting or getattr(self, "assert_from_linting", False):
            kw["linting"] = sql.FROM_LINTING

        from sqlalchemy import orm

        if isinstance(clause, orm.Query):
            stmt = clause._statement_20()
            stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
            clause = stmt

        if compile_kwargs:
            kw["compile_kwargs"] = compile_kwargs

        class DontAccess:
            def __getattribute__(self, key):
                raise NotImplementedError(
                    "compiler accessed .statement; use "
                    "compiler.current_executable"
                )

        class CheckCompilerAccess:
            def __init__(self, test_statement):
                self.test_statement = test_statement
                self._annotations = {}
                self.supports_execution = getattr(
                    test_statement, "supports_execution", False
                )

                if self.supports_execution:
                    self._execution_options = test_statement._execution_options

                    if hasattr(test_statement, "_returning"):
                        self._returning = test_statement._returning
                    if hasattr(test_statement, "_inline"):
                        self._inline = test_statement._inline
                    if hasattr(test_statement, "_return_defaults"):
                        self._return_defaults = test_statement._return_defaults

            @property
            def _variant_mapping(self):
                return self.test_statement._variant_mapping

            def _default_dialect(self):
                return self.test_statement._default_dialect()

            def compile(self, dialect, **kw):
                return self.test_statement.compile.__func__(
                    self, dialect=dialect, **kw
                )

            def _compiler(self, dialect, **kw):
                return self.test_statement._compiler.__func__(
                    self, dialect, **kw
                )

            def _compiler_dispatch(self, compiler, **kwargs):
                if hasattr(compiler, "statement"):
                    with mock.patch.object(
                        compiler, "statement", DontAccess()
                    ):
                        return self.test_statement._compiler_dispatch(
                            compiler, **kwargs
                        )
                else:
                    return self.test_statement._compiler_dispatch(
                        compiler, **kwargs
                    )

        # no construct can assume it's the "top level" construct in all cases
        # as anything can be nested.  ensure constructs don't assume they
        # are the "self.statement" element
        c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)

        if isinstance(clause, sqltypes.TypeEngine):
            cache_key_no_warnings = clause._static_cache_key
            if cache_key_no_warnings:
                hash(cache_key_no_warnings)
        else:
            cache_key_no_warnings = clause._generate_cache_key()
            if cache_key_no_warnings:
                hash(cache_key_no_warnings[0])

        param_str = repr(getattr(c, "params", {}))
        param_str = param_str.encode("utf-8").decode("ascii", "ignore")
        print(("\nSQL String:\n" + str(c) + param_str).encode("utf-8"))

        cc = re.sub(r"[\n\t]", "", str(c))

        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))

        if checkparams is not None:
            if render_postcompile:
                expanded_state = c.construct_expanded_state(
                    params, escape_names=False
                )
                eq_(expanded_state.parameters, checkparams)
            else:
                eq_(c.construct_params(params), checkparams)
        if checkpositional is not None:
            if render_postcompile:
                expanded_state = c.construct_expanded_state(
                    params, escape_names=False
                )
                eq_(
                    tuple(
                        [
                            expanded_state.parameters[x]
                            for x in expanded_state.positiontup
                        ]
                    ),
                    checkpositional,
                )
            else:
                p = c.construct_params(params, escape_names=False)
                eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
        if check_prefetch is not None:
            eq_(c.prefetch, check_prefetch)
        if check_literal_execute is not None:
            eq_(
                {
                    c.bind_names[b]: b.effective_value
                    for b in c.literal_execute_params
                },
                check_literal_execute,
            )
        if check_post_param is not None:
            eq_(
                {
                    c.bind_names[b]: b.effective_value
                    for b in c.post_compile_params
                },
                check_post_param,
            )
        if check_param_order and getattr(c, "params", None):

            def get_dialect(paramstyle, positional):
                cp = copy(dialect)
                cp.paramstyle = paramstyle
                cp.positional = positional
                return cp

            pyformat_dialect = get_dialect("pyformat", False)
            pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
            stmt = re.sub(r"[\n\t]", "", str(pyformat_c))

            qmark_dialect = get_dialect("qmark", True)
            qmark_c = clause.compile(dialect=qmark_dialect, **kw)
            values = list(qmark_c.positiontup)
            escaped = qmark_c.escaped_bind_names

            for post_param in (
                qmark_c.post_compile_params | qmark_c.literal_execute_params
            ):
                name = qmark_c.bind_names[post_param]
                if name in values:
                    values = [v for v in values if v != name]
            positions = []
            pos_by_value = defaultdict(list)
            for v in values:
                try:
                    if v in pos_by_value:
                        start = pos_by_value[v][-1]
                    else:
                        start = 0
                    esc = escaped.get(v, v)
                    pos = stmt.index("%%(%s)s" % (esc,), start) + 2
                    positions.append(pos)
                    pos_by_value[v].append(pos)
                except ValueError:
                    msg = "Expected to find bindparam %r in %r" % (v, stmt)
                    assert False, msg

            ordered = all(
                positions[i - 1] < positions[i]
                for i in range(1, len(positions))
            )

            expected = [v for _, v in sorted(zip(positions, values))]

            msg = (
                "Order of parameters %s does not match the order "
                "in the statement %s. Statement %r" % (values, expected, stmt)
            )

            is_true(ordered, msg)
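
# Illustrative sketch (editor's addition): in a test class that mixes in
# AssertsCompiledSQL, assert_compile() compares a Core construct against its
# expected SQL string on the configured dialect.  The table and test names
# below are hypothetical.
#
#     from sqlalchemy import Column, Integer, MetaData, Table, select
#     from sqlalchemy.testing import fixtures
#
#     class MyCompileTest(fixtures.TestBase, AssertsCompiledSQL):
#         __dialect__ = "default"
#
#         def test_select(self):
#             t = Table("my_table", MetaData(), Column("id", Integer))
#             self.assert_compile(
#                 select(t).where(t.c.id == 5),
#                 "SELECT my_table.id FROM my_table "
#                 "WHERE my_table.id = :id_1",
#             )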


class ComparesTables:
    def assert_tables_equal(
        self,
        table,
        reflected_table,
        strict_types=False,
        strict_constraints=True,
    ):
        assert len(table.c) == len(reflected_table.c)
        for c, reflected_c in zip(table.c, reflected_table.c):
            eq_(c.name, reflected_c.name)
            assert reflected_c is reflected_table.c[c.name]

            if strict_constraints:
                eq_(c.primary_key, reflected_c.primary_key)
                eq_(c.nullable, reflected_c.nullable)

            if strict_types:
                msg = "Type '%s' doesn't correspond to type '%s'"
                assert isinstance(reflected_c.type, type(c.type)), msg % (
                    reflected_c.type,
                    c.type,
                )
            else:
                self.assert_types_base(reflected_c, c)

            if isinstance(c.type, sqltypes.String):
                eq_(c.type.length, reflected_c.type.length)

            if strict_constraints:
                eq_(
                    {f.column.name for f in c.foreign_keys},
                    {f.column.name for f in reflected_c.foreign_keys},
                )
            if c.server_default:
                assert isinstance(
                    reflected_c.server_default, schema.FetchedValue
                )

        if strict_constraints:
            assert len(table.primary_key) == len(reflected_table.primary_key)
            for c in table.primary_key:
                assert reflected_table.primary_key.columns[c.name] is not None

    def assert_types_base(self, c1, c2):
        assert c1.type._compare_type_affinity(
            c2.type
        ), "On column %r, type '%s' doesn't correspond to type '%s'" % (
            c1.name,
            c1.type,
            c2.type,
        )


class AssertsExecutionResults:
    def assert_result(self, result, class_, *objects):
        result = list(result)
        print(repr(result))
        self.assert_list(result, class_, objects)

    def assert_list(self, result, class_, list_):
        self.assert_(
            len(result) == len(list_),
            "result list is not the same size as test list, "
            + "for class "
            + class_.__name__,
        )
        for i in range(0, len(list_)):
            self.assert_row(class_, result[i], list_[i])

    def assert_row(self, class_, rowobj, desc):
        self.assert_(
            rowobj.__class__ is class_, "item class is not " + repr(class_)
        )
        for key, value in desc.items():
            if isinstance(value, tuple):
                if isinstance(value[1], list):
                    self.assert_list(getattr(rowobj, key), value[0], value[1])
                else:
                    self.assert_row(value[0], getattr(rowobj, key), value[1])
            else:
                self.assert_(
                    getattr(rowobj, key) == value,
                    "attribute %s value %s does not match %s"
                    % (key, getattr(rowobj, key), value),
                )

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        The algorithm is very expensive but not a big deal for the small
        numbers of rows that the test suite manipulates.
        """

        class immutabledict(dict):
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        expected = {immutabledict(e) for e in expected}

        for wrong in filterfalse(lambda o: isinstance(o, cls), found):
            fail(
                'Unexpected type "%s", expected "%s"'
                % (type(wrong).__name__, cls.__name__)
            )

        if len(found) != len(expected):
            fail(
                'Unexpected object count "%s", expected "%s"'
                % (len(found), len(expected))
            )

        NOVALUE = object()

        def _compare_item(obj, spec):
            for key, value in spec.items():
                if isinstance(value, tuple):
                    try:
                        self.assert_unordered_result(
                            getattr(obj, key), value[0], *value[1]
                        )
                    except AssertionError:
                        return False
                else:
                    if getattr(obj, key, NOVALUE) != value:
                        return False
            return True

        for expected_item in expected:
            for found_item in found:
                if _compare_item(found_item, expected_item):
                    found.remove(found_item)
                    break
            else:
                fail(
                    "Expected %s instance with attributes %s not found."
                    % (cls.__name__, repr(expected_item))
                )
        return True

    def sql_execution_asserter(self, db=None):
        if db is None:
            from . import db as db

        return assertsql.assert_engine(db)

    def assert_sql_execution(self, db, callable_, *rules):
        with self.sql_execution_asserter(db) as asserter:
            result = callable_()
        asserter.assert_(*rules)
        return result

    def assert_sql(self, db, callable_, rules):
        newrules = []
        for rule in rules:
            if isinstance(rule, dict):
                newrule = assertsql.AllOf(
                    *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
                )
            else:
                newrule = assertsql.CompiledSQL(*rule)
            newrules.append(newrule)

        return self.assert_sql_execution(db, callable_, *newrules)

    def assert_sql_count(self, db, callable_, count):
        return self.assert_sql_execution(
            db, callable_, assertsql.CountStatements(count)
        )
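
    # Illustrative sketch (editor's addition): assert_sql_count() wraps a
    # callable and fails unless exactly the given number of statements was
    # emitted; the session "sess" and mapped class "User" are hypothetical.
    #
    #     def go():
    #         sess.query(User).all()
    #
    #     self.assert_sql_count(testing.db, go, 1)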

    @contextlib.contextmanager
    def assert_execution(self, db, *rules):
        with self.sql_execution_asserter(db) as asserter:
            yield
        asserter.assert_(*rules)

    def assert_statement_count(self, db, count):
        return self.assert_execution(db, assertsql.CountStatements(count))

    @contextlib.contextmanager
    def assert_statement_count_multi_db(self, dbs, counts):
        recs = [
            (self.sql_execution_asserter(db), db, count)
            for (db, count) in zip(dbs, counts)
        ]
        asserters = []
        for ctx, db, count in recs:
            asserters.append(ctx.__enter__())
        try:
            yield
        finally:
            for asserter, (ctx, db, count) in zip(asserters, recs):
                ctx.__exit__(None, None, None)
                asserter.assert_(assertsql.CountStatements(count))


class ComparesIndexes:
    def compare_table_index_with_expected(
        self, table: schema.Table, expected: list, dialect_name: str
    ):
        eq_(len(table.indexes), len(expected))
        idx_dict = {idx.name: idx for idx in table.indexes}
        for exp in expected:
            idx = idx_dict[exp["name"]]
            eq_(idx.unique, exp["unique"])
            cols = [c for c in exp["column_names"] if c is not None]
            eq_(len(idx.columns), len(cols))
            for c in cols:
                is_true(c in idx.columns)
            exprs = exp.get("expressions")
            if exprs:
                eq_(len(idx.expressions), len(exprs))
                for idx_exp, expr, col in zip(
                    idx.expressions, exprs, exp["column_names"]
                ):
                    if col is None:
                        eq_(idx_exp.text, expr)
            if (
                exp.get("dialect_options")
                and f"{dialect_name}_include" in exp["dialect_options"]
            ):
                eq_(
                    idx.dialect_options[dialect_name]["include"],
                    exp["dialect_options"][f"{dialect_name}_include"],
                )
@@ -0,0 +1,516 @@
# testing/assertsql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import collections
import contextlib
import itertools
import re

from .. import event
from ..engine import url
from ..engine.default import DefaultDialect
from ..schema import BaseDDLElement


class AssertRule:
    is_consumed = False
    errormessage = None
    consume_statement = True

    def process_statement(self, execute_observed):
        pass

    def no_more_statements(self):
        assert False, (
            "All statements are complete, but pending "
            "assertion rules remain"
        )


class SQLMatchRule(AssertRule):
    pass


class CursorSQL(SQLMatchRule):
    def __init__(self, statement, params=None, consume_statement=True):
        self.statement = statement
        self.params = params
        self.consume_statement = consume_statement

    def process_statement(self, execute_observed):
        stmt = execute_observed.statements[0]
        if self.statement != stmt.statement or (
            self.params is not None and self.params != stmt.parameters
        ):
            self.consume_statement = True
            self.errormessage = (
                "Testing for exact SQL %s parameters %s received %s %s"
                % (
                    self.statement,
                    self.params,
                    stmt.statement,
                    stmt.parameters,
                )
            )
        else:
            execute_observed.statements.pop(0)
            self.is_consumed = True
            if not execute_observed.statements:
                self.consume_statement = True


class CompiledSQL(SQLMatchRule):
    def __init__(
        self, statement, params=None, dialect="default", enable_returning=True
    ):
        self.statement = statement
        self.params = params
        self.dialect = dialect
        self.enable_returning = enable_returning

    def _compare_sql(self, execute_observed, received_statement):
        stmt = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        if self.dialect == "default":
            dialect = DefaultDialect()
            # this is currently what tests are expecting
            # dialect.supports_default_values = True
            dialect.supports_default_metavalue = True

            if self.enable_returning:
                dialect.insert_returning = dialect.update_returning = (
                    dialect.delete_returning
                ) = True
                dialect.use_insertmanyvalues = True
                dialect.supports_multivalues_insert = True
                dialect.update_returning_multifrom = True
                dialect.delete_returning_multifrom = True
                # dialect.favor_returning_over_lastrowid = True
                # dialect.insert_null_pk_still_autoincrements = True

                # this is calculated but we need it to be True for this
                # to look like all the current RETURNING dialects
                assert dialect.insert_executemany_returning

            return dialect
        else:
            return url.URL.create(self.dialect).get_dialect()()

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect."""

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)

        # received_statement runs a full compile().  we should not need to
        # consider extracted_parameters; if we do this indicates some state
        # is being sent from a previous cached query, which some misbehaviors
        # in the ORM can cause, see #6881
        cache_key = None  # execute_observed.context.compiled.cache_key
        extracted_parameters = (
            None  # execute_observed.context.extracted_parameters
        )

        if "schema_translate_map" in context.execution_options:
            map_ = context.execution_options["schema_translate_map"]
        else:
            map_ = None

        if isinstance(execute_observed.clauseelement, BaseDDLElement):
            compiled = execute_observed.clauseelement.compile(
                dialect=compare_dialect,
                schema_translate_map=map_,
            )
        else:
            compiled = execute_observed.clauseelement.compile(
                cache_key=cache_key,
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                for_executemany=context.compiled.for_executemany,
                schema_translate_map=map_,
            )
        _received_statement = re.sub(r"[\n\t]", "", str(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [
                compiled.construct_params(
                    extracted_parameters=extracted_parameters
                )
            ]
        else:
            _received_parameters = [
                compiled.construct_params(
                    m, extracted_parameters=extracted_parameters
                )
                for m in parameters
            ]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        context = execute_observed.context

        _received_statement, _received_parameters = self._received_statement(
            execute_observed
        )
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if (
                                param_key not in received
                                or received[param_key] != param[param_key]
                            ):
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(
                execute_observed, params
            ) % {
                "received_statement": _received_statement,
                "received_parameters": _received_parameters,
            }

    def _all_params(self, context):
        if self.params:
            if callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, execute_observed, expected_params):
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.statement.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )


class RegexSQL(CompiledSQL):
    def __init__(
        self, regex, params=None, dialect="default", enable_returning=False
    ):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params
        self.dialect = dialect
        self.enable_returning = enable_returning

    def _failure_message(self, execute_observed, expected_params):
        return (
            "Testing for compiled statement ~%r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.orig_regex.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )

    def _compare_sql(self, execute_observed, received_statement):
        return bool(self.regex.match(received_statement))


class DialectSQL(CompiledSQL):
    def _compile_dialect(self, execute_observed):
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        stmt = re.sub(r"[\n\t]", "", real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        received_stmt, received_params = super()._received_statement(
            execute_observed
        )

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt
            )

        return received_stmt, execute_observed.context.compiled_parameters

    def _dialect_adjusted_statement(self, dialect):
        paramstyle = dialect.paramstyle
        stmt = re.sub(r"[\n\t]", "", self.statement)

        # temporarily escape out PG double colons
        stmt = stmt.replace("::", "!!")

        if paramstyle == "pyformat":
            stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
        else:
            # positional params
            repl = None
            if paramstyle == "qmark":
                repl = "?"
            elif paramstyle == "format":
                repl = r"%s"
            elif paramstyle.startswith("numeric"):
                counter = itertools.count(1)

                num_identifier = "$" if paramstyle == "numeric_dollar" else ":"

                def repl(m):
                    return f"{num_identifier}{next(counter)}"

            stmt = re.sub(r":([\w_]+)", repl, stmt)

        # put them back
        stmt = stmt.replace("!!", "::")

        return stmt

    def _compare_sql(self, execute_observed, received_statement):
        stmt = self._dialect_adjusted_statement(
            execute_observed.context.dialect
        )
        return received_statement == stmt

    def _failure_message(self, execute_observed, expected_params):
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self._dialect_adjusted_statement(
                    execute_observed.context.dialect
                ).replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )


class CountStatements(AssertRule):
    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        self._statement_count += 1

    def no_more_statements(self):
        if self.count != self._statement_count:
            assert False, "desired statement count %d does not match %d" % (
                self.count,
                self._statement_count,
            )


class AllOf(AssertRule):
    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        for rule in list(self.rules):
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage


class EachOf(AssertRule):
    def __init__(self, *rules):
        self.rules = list(rules)

    def process_statement(self, execute_observed):
        if not self.rules:
            self.is_consumed = True
            self.consume_statement = False

        while self.rules:
            rule = self.rules[0]
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.pop(0)
            elif rule.errormessage:
                self.errormessage = rule.errormessage
            if rule.consume_statement:
                break

        if not self.rules:
            self.is_consumed = True

    def no_more_statements(self):
        if self.rules and not self.rules[0].is_consumed:
            self.rules[0].no_more_statements()
        elif self.rules:
            super().no_more_statements()


class Conditional(EachOf):
    def __init__(self, condition, rules, else_rules):
        if condition:
            super().__init__(*rules)
        else:
            super().__init__(*else_rules)


class Or(AllOf):
    def process_statement(self, execute_observed):
        for rule in self.rules:
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.is_consumed = True
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage


class SQLExecuteObserved:
    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement

        if multiparams:
            self.parameters = multiparams
        elif params:
            self.parameters = [params]
        else:
            self.parameters = []
        self.statements = []

    def __repr__(self):
        return str(self.statements)


class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        ["statement", "parameters", "context", "executemany"],
    )
):
    pass


class SQLAsserter:
    def __init__(self):
        self.accumulated = []

    def _close(self):
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        rule = EachOf(*rules)

        observed = list(self._final)
        while observed:
            statement = observed.pop(0)
            rule.process_statement(statement)
            if rule.is_consumed:
                break
            elif rule.errormessage:
                assert False, rule.errormessage
        if observed:
            assert False, "Additional SQL statements remain:\n%s" % observed
        elif not rule.is_consumed:
            rule.no_more_statements()


@contextlib.contextmanager
def assert_engine(engine):
    asserter = SQLAsserter()

    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(
        conn, clauseelement, multiparams, params, execution_options
    ):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        if not context:
            return
        # then grab real cursor statements and associate them all
        # around a single context
        if (
            asserter.accumulated
            and asserter.accumulated[-1].context is context
        ):
            obs = asserter.accumulated[-1]
        else:
            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            asserter.accumulated.append(obs)
        obs.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany
            )
        )

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
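
# Illustrative sketch (editor's addition): assert_engine() is the low-level
# entry point behind sql_execution_asserter(); rules are checked in order
# against the statements the engine actually emitted.  The engine and table
# here are hypothetical.
#
#     with assert_engine(engine) as asserter:
#         with engine.connect() as conn:
#             conn.execute(my_table.insert(), {"id": 1})
#     asserter.assert_(
#         CompiledSQL("INSERT INTO my_table (id) VALUES (:id)", [{"id": 1}])
#     )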
venv/lib/python3.11/site-packages/sqlalchemy/testing/asyncio.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# testing/asyncio.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


# functions and wrappers to run tests, fixtures, provisioning and
# setup/teardown in an asyncio event loop, conditionally based on the
# current DB driver being used for a test.

# note that SQLAlchemy's asyncio integration also supports a method
# of running individual asyncio functions inside of separate event loops
# using "async_fallback" mode; however running whole functions in the event
# loop is a more accurate test for how SQLAlchemy's asyncio features
# would run in the real world.


from __future__ import annotations

from functools import wraps
import inspect

from . import config
from ..util.concurrency import _AsyncUtil

# may be set to False if the
# --disable-asyncio flag is passed to the test runner.
ENABLE_ASYNCIO = True
_async_util = _AsyncUtil()  # it has lazy init so just always create one


def _shutdown():
    """called when the test finishes"""
    _async_util.close()


def _run_coroutine_function(fn, *args, **kwargs):
    return _async_util.run(fn, *args, **kwargs)


def _assume_async(fn, *args, **kwargs):
    """Run a function in an asyncio loop unconditionally.

    This function is used for provisioning features like
    testing a database connection for server info.

    Note that for blocking IO database drivers, this means they block the
    event loop.

    """

    if not ENABLE_ASYNCIO:
        return fn(*args, **kwargs)

    return _async_util.run_in_greenlet(fn, *args, **kwargs)


def _maybe_async_provisioning(fn, *args, **kwargs):
    """Run a function in an asyncio loop if any current drivers might need it.

    This function is used for provisioning features that take
    place outside of a specific database driver being selected, so if the
    current driver that happens to be used for the provisioning operation
    is an async driver, it will run in asyncio and not fail.

    Note that for blocking IO database drivers, this means they block the
    event loop.

    """
    if not ENABLE_ASYNCIO:
        return fn(*args, **kwargs)

    if config.any_async:
        return _async_util.run_in_greenlet(fn, *args, **kwargs)
    else:
        return fn(*args, **kwargs)


def _maybe_async(fn, *args, **kwargs):
    """Run a function in an asyncio loop if the current selected driver is
    async.

    This function is used for test setup/teardown and tests themselves
    where the current DB driver is known.

    """
    if not ENABLE_ASYNCIO:
        return fn(*args, **kwargs)

    is_async = config._current.is_async

    if is_async:
        return _async_util.run_in_greenlet(fn, *args, **kwargs)
    else:
        return fn(*args, **kwargs)


def _maybe_async_wrapper(fn):
    """Apply the _maybe_async function to an existing function and return
    as a wrapped callable, supporting generator functions as well.

    This is currently used for pytest fixtures that support generator use.

    """

    if inspect.isgeneratorfunction(fn):
        _stop = object()

        def call_next(gen):
            try:
                return next(gen)
            # can't raise StopIteration in an awaitable.
            except StopIteration:
                return _stop

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            gen = fn(*args, **kwargs)
            while True:
                value = _maybe_async(call_next, gen)
                if value is _stop:
                    break
                yield value

    else:

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            return _maybe_async(fn, *args, **kwargs)

    return wrap_fixture
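
# Illustrative sketch (editor's addition): _maybe_async_wrapper() is applied
# to generator-style pytest fixtures so that each step of the generator runs
# inside the greenlet-mediated event loop when the selected driver is async.
# The setup/teardown helpers below are hypothetical.
#
#     def my_fixture():
#         connection = some_setup()
#         yield connection
#         some_teardown(connection)
#
#     my_fixture = _maybe_async_wrapper(my_fixture)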
venv/lib/python3.11/site-packages/sqlalchemy/testing/config.py (new file, 423 lines)
@@ -0,0 +1,423 @@
# testing/config.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

from argparse import Namespace
import collections
import inspect
import typing
from typing import Any
from typing import Callable
from typing import Iterable
from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

from . import mock
from . import requirements as _requirements
from .util import fail
from .. import util

# default requirements; this is replaced by plugin_base when pytest
# is run
requirements = _requirements.SuiteRequirements()

db = None
db_url = None
db_opts = None
file_config = None
test_schema = None
test_schema_2 = None
any_async = False
_current = None
ident = "main"
options: Namespace = None  # type: ignore

if typing.TYPE_CHECKING:
    from .plugin.plugin_base import FixtureFunctions

    _fixture_functions: FixtureFunctions
else:

    class _NullFixtureFunctions:
        def _null_decorator(self):
            def go(fn):
                return fn

            return go

        def skip_test_exception(self, *arg, **kw):
            return Exception()

        @property
        def add_to_marker(self):
            return mock.Mock()

        def mark_base_test_class(self):
            return self._null_decorator()

        def combinations(self, *arg_sets, **kw):
            return self._null_decorator()

        def param_ident(self, *parameters):
            return self._null_decorator()

        def fixture(self, *arg, **kw):
            return self._null_decorator()

        def get_current_test_name(self):
            return None

        def async_test(self, fn):
            return fn

    # default fixture functions; these are replaced by plugin_base when
    # pytest runs
    _fixture_functions = _NullFixtureFunctions()


_FN = TypeVar("_FN", bound=Callable[..., Any])


def combinations(
    *comb: Union[Any, Tuple[Any, ...]],
    argnames: Optional[str] = None,
    id_: Optional[str] = None,
    **kw: str,
) -> Callable[[_FN], _FN]:
    r"""Deliver multiple versions of a test based on positional combinations.

    This is a facade over pytest.mark.parametrize.


    :param \*comb: argument combinations. These are tuples that will be passed
     positionally to the decorated function.

    :param argnames: optional list of argument names. These are the names
     of the arguments in the test function that correspond to the entries
     in each argument tuple. pytest.mark.parametrize requires this, however
     the combinations function will derive it automatically if not present
     by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
     first argument is "self" which is discarded.

    :param id\_: optional id template. This is a string template that
     describes how the "id" for each parameter set should be defined, if any.
     The number of characters in the template should match the number of
     entries in each argument tuple. Each character describes how the
     corresponding entry in the argument tuple should be handled, as far as
     whether or not it is included in the arguments passed to the function, as
     well as if it is included in the tokens used to create the id of the
     parameter set.

     If omitted, the argument combinations are passed to parametrize as is. If
     passed, each argument combination is turned into a pytest.param() object,
     mapping the elements of the argument tuple to produce an id based on a
     character value in the same position within the string template using the
     following scheme:

     .. sourcecode:: text

        i - the given argument is a string that is part of the id only, don't
            pass it as an argument

        n - the given argument should be passed and it should be added to the
            id by calling the .__name__ attribute

        r - the given argument should be passed and it should be added to the
            id by calling repr()

        s - the given argument should be passed and it should be added to the
            id by calling str()

        a - (argument) the given argument should be passed and it should not
            be used to generate the id

     e.g.::

        @testing.combinations(
            (operator.eq, "eq"),
            (operator.ne, "ne"),
            (operator.gt, "gt"),
            (operator.lt, "lt"),
            id_="na",
        )
        def test_operator(self, opfunc, name):
            pass

    The above combination will call ``.__name__`` on the first member of
    each tuple and use that as the "id" to pytest.param().


    """
    return _fixture_functions.combinations(
        *comb, id_=id_, argnames=argnames, **kw
    )


def combinations_list(arg_iterable: Iterable[Tuple[Any, ...]], **kw):
    "As combinations, but takes a single iterable"
    return combinations(*arg_iterable, **kw)
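
A quick sanity check of the automatic ``argnames`` derivation described above (a standalone sketch, not library code):

    import inspect

    def test_operator(self, opfunc, name):
        pass

    # the names pytest.mark.parametrize would receive when argnames is omitted
    print(inspect.getfullargspec(test_operator).args[1:])  # ['opfunc', 'name']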
class Variation:
    __slots__ = ("_name", "_argname")

    def __init__(self, case, argname, case_names):
        self._name = case
        self._argname = argname
        for casename in case_names:
            setattr(self, casename, casename == case)

    if typing.TYPE_CHECKING:

        def __getattr__(self, key: str) -> bool: ...

    @property
    def name(self):
        return self._name

    def __bool__(self):
        return self._name == self._argname

    def __nonzero__(self):
        return not self.__bool__()

    def __str__(self):
        return f"{self._argname}={self._name!r}"

    def __repr__(self):
        return str(self)

    def fail(self) -> NoReturn:
        fail(f"Unknown {self}")

    @classmethod
    def idfn(cls, variation):
        return variation.name

    @classmethod
    def generate_cases(cls, argname, cases):
        case_names = [
            argname if c is True else "not_" + argname if c is False else c
            for c in cases
        ]

        typ = type(
            argname,
            (Variation,),
            {
                "__slots__": tuple(case_names),
            },
        )

        return [typ(casename, argname, case_names) for casename in case_names]


def variation(argname_or_fn, cases=None):
    """a helper around testing.combinations that provides a single namespace
    that can be used as a switch.

    e.g.::

        @testing.variation("querytyp", ["select", "subquery", "legacy_query"])
        @testing.variation("lazy", ["select", "raise", "raise_on_sql"])
        def test_thing(self, querytyp, lazy, decl_base):
            class Thing(decl_base):
                __tablename__ = "thing"

                # use name directly
                rel = relationship("Rel", lazy=lazy.name)

            # use as a switch
            if querytyp.select:
                stmt = select(Thing)
            elif querytyp.subquery:
                stmt = select(Thing).subquery()
            elif querytyp.legacy_query:
                stmt = Session.query(Thing)
            else:
                querytyp.fail()

    The variable provided is a slots object of boolean variables, as well
    as the name of the case itself under the attribute ".name"

    """

    if inspect.isfunction(argname_or_fn):
        argname = argname_or_fn.__name__
        cases = argname_or_fn(None)

        @variation_fixture(argname, cases)
        def go(self, request):
            yield request.param

        return go
    else:
        argname = argname_or_fn
        cases_plus_limitations = [
            (
                entry
                if (isinstance(entry, tuple) and len(entry) == 2)
                else (entry, None)
            )
            for entry in cases
        ]

        variations = Variation.generate_cases(
            argname, [c for c, l in cases_plus_limitations]
        )
        return combinations(
            *[
                (
                    (variation._name, variation, limitation)
                    if limitation is not None
                    else (variation._name, variation)
                )
                for variation, (case, limitation) in zip(
                    variations, cases_plus_limitations
                )
            ],
            id_="ia",
            argnames=argname,
        )


def variation_fixture(argname, cases, scope="function"):
    return fixture(
        params=Variation.generate_cases(argname, cases),
        ids=Variation.idfn,
        scope=scope,
    )
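
A minimal sketch of what ``Variation.generate_cases()`` produces, assuming only that ``sqlalchemy`` is importable:

    from sqlalchemy.testing.config import Variation

    cases = Variation.generate_cases("coltype", ["integer", "string", False])
    for v in cases:
        # each case object is truthy only for its own name; boolean case
        # literals map to "<argname>" / "not_<argname>"
        print(v, v.integer, v.string)
    # coltype='integer' True False
    # coltype='string' False True
    # coltype='not_coltype' False False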
def fixture(*arg: Any, **kw: Any) -> Any:
    return _fixture_functions.fixture(*arg, **kw)


def get_current_test_name() -> str:
    return _fixture_functions.get_current_test_name()


def mark_base_test_class() -> Any:
    return _fixture_functions.mark_base_test_class()


class _AddToMarker:
    def __getattr__(self, attr: str) -> Any:
        return getattr(_fixture_functions.add_to_marker, attr)


add_to_marker = _AddToMarker()


class Config:
    def __init__(self, db, db_opts, options, file_config):
        self._set_name(db)
        self.db = db
        self.db_opts = db_opts
        self.options = options
        self.file_config = file_config
        self.test_schema = "test_schema"
        self.test_schema_2 = "test_schema_2"

        self.is_async = db.dialect.is_async and not util.asbool(
            db.url.query.get("async_fallback", False)
        )

    _stack = collections.deque()
    _configs = set()

    def _set_name(self, db):
        suffix = "_async" if db.dialect.is_async else ""
        if db.dialect.server_version_info:
            svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
            self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi)
        else:
            self.name = "%s+%s%s" % (db.name, db.driver, suffix)

    @classmethod
    def register(cls, db, db_opts, options, file_config):
        """add a config as one of the global configs.

        If there are no configs set up yet, this config also
        gets set as the "_current".
        """
        global any_async

        cfg = Config(db, db_opts, options, file_config)

        # if any backends include an async driver, then ensure
        # all setup/teardown and tests are wrapped in the maybe_async()
        # decorator that will set up a greenlet context for async drivers.
        any_async = any_async or cfg.is_async

        cls._configs.add(cfg)
        return cfg

    @classmethod
    def set_as_current(cls, config, namespace):
        global db, _current, db_url, test_schema, test_schema_2, db_opts
        _current = config
        db_url = config.db.url
        db_opts = config.db_opts
        test_schema = config.test_schema
        test_schema_2 = config.test_schema_2
        namespace.db = db = config.db

    @classmethod
    def push_engine(cls, db, namespace):
        assert _current, "Can't push without a default Config set up"
        cls.push(
            Config(
                db, _current.db_opts, _current.options, _current.file_config
            ),
            namespace,
        )

    @classmethod
    def push(cls, config, namespace):
        cls._stack.append(_current)
        cls.set_as_current(config, namespace)

    @classmethod
    def pop(cls, namespace):
        if cls._stack:
            # a failed test w/ -x option can call reset() ahead of time
            _current = cls._stack[-1]
            del cls._stack[-1]
            cls.set_as_current(_current, namespace)

    @classmethod
    def reset(cls, namespace):
        if cls._stack:
            cls.set_as_current(cls._stack[0], namespace)
            cls._stack.clear()

    @classmethod
    def all_configs(cls):
        return cls._configs

    @classmethod
    def all_dbs(cls):
        for cfg in cls.all_configs():
            yield cfg.db

    def skip_test(self, msg):
        skip_test(msg)


def skip_test(msg):
    raise _fixture_functions.skip_test_exception(msg)


def async_test(fn):
    return _fixture_functions.async_test(fn)
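
An illustrative sketch of the ``Config`` stack (the namespace ``ns`` is an arbitrary attribute holder; a SQLite engine is used purely for demonstration):

    import types
    from sqlalchemy import create_engine
    from sqlalchemy.testing import config as testing_config

    ns = types.SimpleNamespace()
    main = testing_config.Config.register(
        create_engine("sqlite://"), {}, None, None
    )
    testing_config.Config.set_as_current(main, ns)

    testing_config.Config.push_engine(create_engine("sqlite://"), ns)
    print(ns.db is main.db)  # False: the pushed engine is now current
    testing_config.Config.pop(ns)
    print(ns.db is main.db)  # True: back to the registered default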
venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py (new file, 474 lines)
@@ -0,0 +1,474 @@
# testing/engines.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import collections
import re
import typing
from typing import Any
from typing import Dict
from typing import Optional
import warnings
import weakref

from . import config
from .util import decorator
from .util import gc_collect
from .. import event
from .. import pool
from ..util import await_only
from ..util.typing import Literal


if typing.TYPE_CHECKING:
    from ..engine import Engine
    from ..engine.url import URL
    from ..ext.asyncio import AsyncEngine


class ConnectionKiller:
    def __init__(self):
        self.proxy_refs = weakref.WeakKeyDictionary()
        self.testing_engines = collections.defaultdict(set)
        self.dbapi_connections = set()

    def add_pool(self, pool):
        event.listen(pool, "checkout", self._add_conn)
        event.listen(pool, "checkin", self._remove_conn)
        event.listen(pool, "close", self._remove_conn)
        event.listen(pool, "close_detached", self._remove_conn)
        # note we are keeping "invalidated" here, as those are still
        # opened connections we would like to roll back

    def _add_conn(self, dbapi_con, con_record, con_proxy):
        self.dbapi_connections.add(dbapi_con)
        self.proxy_refs[con_proxy] = True

    def _remove_conn(self, dbapi_conn, *arg):
        self.dbapi_connections.discard(dbapi_conn)

    def add_engine(self, engine, scope):
        self.add_pool(engine.pool)

        assert scope in ("class", "global", "function", "fixture")
        self.testing_engines[scope].add(engine)

    def _safe(self, fn):
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "testing_reaper couldn't rollback/close connection: %s" % e
            )

    def rollback_all(self):
        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self._safe(rec.rollback)

    def checkin_all(self):
        # run pool.checkin() for all ConnectionFairy instances we have
        # tracked.

        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self.dbapi_connections.discard(rec.dbapi_connection)
                self._safe(rec._checkin)

        # for fairy refs that were GCed and could not close the connection,
        # such as asyncio, roll back those remaining connections
        for con in self.dbapi_connections:
            self._safe(con.rollback)
        self.dbapi_connections.clear()

    def close_all(self):
        self.checkin_all()

    def prepare_for_drop_tables(self, connection):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        from . import provision

        provision.prepare_for_drop_tables(connection.engine.url, connection)

    def _drop_testing_engines(self, scope):
        eng = self.testing_engines[scope]
        for rec in list(eng):
            for proxy_ref in list(self.proxy_refs):
                if proxy_ref is not None and proxy_ref.is_valid:
                    if (
                        proxy_ref._pool is not None
                        and proxy_ref._pool is rec.pool
                    ):
                        self._safe(proxy_ref._checkin)

            if hasattr(rec, "sync_engine"):
                await_only(rec.dispose())
            else:
                rec.dispose()
        eng.clear()

    def after_test(self):
        self._drop_testing_engines("function")

    def after_test_outside_fixtures(self, test):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        if test.__class__.__leave_connections_for_teardown__:
            return

        self.checkin_all()

        # on PostgreSQL, this will test for any "idle in transaction"
        # connections. useful to identify tests with unusual patterns
        # that can't be cleaned up correctly.
        from . import provision

        with config.db.connect() as conn:
            provision.prepare_for_drop_tables(conn.engine.url, conn)

    def stop_test_class_inside_fixtures(self):
        self.checkin_all()
        self._drop_testing_engines("function")
        self._drop_testing_engines("class")

    def stop_test_class_outside_fixtures(self):
        # ensure no refs to checked out connections at all.

        if pool.base._strong_ref_connection_records:
            gc_collect()

            if pool.base._strong_ref_connection_records:
                ln = len(pool.base._strong_ref_connection_records)
                pool.base._strong_ref_connection_records.clear()
                assert (
                    False
                ), "%d connection recs not cleared after test suite" % (ln)

    def final_cleanup(self):
        self.checkin_all()
        for scope in self.testing_engines:
            self._drop_testing_engines(scope)

    def assert_all_closed(self):
        for rec in self.proxy_refs:
            if rec.is_valid:
                assert False


testing_reaper = ConnectionKiller()
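
The reaper's lifecycle in miniature (a sketch against an in-memory SQLite engine; in real use the pytest plugin wires these calls into setup/teardown hooks):

    from sqlalchemy import create_engine
    from sqlalchemy.testing.engines import testing_reaper

    eng = create_engine("sqlite://")
    testing_reaper.add_engine(eng, "function")
    with eng.connect() as conn:
        conn.exec_driver_sql("select 1")

    # at end-of-test the plugin calls after_test(), which checks in any
    # tracked connections and disposes every "function"-scoped engine
    testing_reaper.after_test()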
@decorator
def assert_conns_closed(fn, *args, **kw):
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.assert_all_closed()


@decorator
def rollback_open_connections(fn, *args, **kw):
    """Decorator that rolls back all open connections after fn execution."""

    try:
        fn(*args, **kw)
    finally:
        testing_reaper.rollback_all()


@decorator
def close_first(fn, *args, **kw):
    """Decorator that closes all connections before fn execution."""

    testing_reaper.checkin_all()
    fn(*args, **kw)


@decorator
def close_open_connections(fn, *args, **kw):
    """Decorator that closes all connections after fn execution."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.checkin_all()


def all_dialects(exclude=None):
    import sqlalchemy.dialects as d

    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None)
        if not mod:
            mod = getattr(
                __import__("sqlalchemy.dialects.%s" % name).dialects, name
            )
        yield mod.dialect()


class ReconnectFixture:
    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []
        self.is_stopped = False

    def __getattr__(self, key):
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        conn = self.dbapi.connect(*args, **kwargs)
        if self.is_stopped:
            self._safe(conn.close)
            curs = conn.cursor()  # should fail on Oracle etc.
            # should fail for everything that didn't fail
            # above, connection is closed
            curs.execute("select 1")
            assert False, "simulated connect failure didn't work"
        else:
            self.connections.append(conn)
            return conn

    def _safe(self, fn):
        try:
            fn()
        except Exception as e:
            warnings.warn("ReconnectFixture couldn't close connection: %s" % e)

    def shutdown(self, stop=False):
        # TODO: this doesn't cover all cases
        # as nicely as we'd like, namely MySQLdb.
        # would need to implement R. Brewer's
        # proxy server idea to get better
        # coverage.
        self.is_stopped = stop
        for c in list(self.connections):
            self._safe(c.close)
        self.connections = []

    def restart(self):
        self.is_stopped = False


def reconnecting_engine(url=None, options=None):
    url = url or config.db.url
    dbapi = config.db.dialect.dbapi
    if not options:
        options = {}
    options["module"] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    _dispose = engine.dispose

    def dispose():
        engine.dialect.dbapi.shutdown()
        engine.dialect.dbapi.is_stopped = False
        _dispose()

    engine.test_shutdown = engine.dialect.dbapi.shutdown
    engine.test_restart = engine.dialect.dbapi.restart
    engine.dispose = dispose
    return engine
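
A standalone sketch of the ``ReconnectFixture`` behavior using the stock ``sqlite3`` DBAPI (``reconnecting_engine()`` itself requires a configured ``config.db``):

    import sqlite3
    from sqlalchemy.testing.engines import ReconnectFixture

    dbapi = ReconnectFixture(sqlite3)
    conn = dbapi.connect(":memory:")
    dbapi.shutdown(stop=True)  # closes open connections, blocks new ones
    try:
        dbapi.connect(":memory:")
    except Exception as err:
        print("simulated outage:", type(err).__name__)
    dbapi.restart()  # connects succeed again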
@typing.overload
def testing_engine(
    url: Optional[URL] = None,
    options: Optional[Dict[str, Any]] = None,
    asyncio: Literal[False] = False,
    transfer_staticpool: bool = False,
) -> Engine: ...


@typing.overload
def testing_engine(
    url: Optional[URL] = None,
    options: Optional[Dict[str, Any]] = None,
    asyncio: Literal[True] = True,
    transfer_staticpool: bool = False,
) -> AsyncEngine: ...


def testing_engine(
    url=None,
    options=None,
    asyncio=False,
    transfer_staticpool=False,
    share_pool=False,
    _sqlite_savepoint=False,
):
    if asyncio:
        assert not _sqlite_savepoint
        from sqlalchemy.ext.asyncio import (
            create_async_engine as create_engine,
        )
    else:
        from sqlalchemy import create_engine
    from sqlalchemy.engine.url import make_url

    if not options:
        use_reaper = True
        scope = "function"
        sqlite_savepoint = False
    else:
        use_reaper = options.pop("use_reaper", True)
        scope = options.pop("scope", "function")
        sqlite_savepoint = options.pop("sqlite_savepoint", False)

    url = url or config.db.url

    url = make_url(url)

    if (
        config.db is None or url.drivername == config.db.url.drivername
    ) and config.db_opts:
        use_options = config.db_opts.copy()
    else:
        use_options = {}

    if options is not None:
        use_options.update(options)

    engine = create_engine(url, **use_options)

    if sqlite_savepoint and engine.name == "sqlite":
        # apply SQLite savepoint workaround
        @event.listens_for(engine, "connect")
        def do_connect(dbapi_connection, connection_record):
            dbapi_connection.isolation_level = None

        @event.listens_for(engine, "begin")
        def do_begin(conn):
            conn.exec_driver_sql("BEGIN")

    if transfer_staticpool:
        from sqlalchemy.pool import StaticPool

        if config.db is not None and isinstance(config.db.pool, StaticPool):
            use_reaper = False
            engine.pool._transfer_from(config.db.pool)
    elif share_pool:
        engine.pool = config.db.pool

    if scope == "global":
        if asyncio:
            engine.sync_engine._has_events = True
        else:
            engine._has_events = (
                True  # enable event blocks, helps with profiling
            )

    if (
        isinstance(engine.pool, pool.QueuePool)
        and "pool" not in use_options
        and "pool_timeout" not in use_options
        and "max_overflow" not in use_options
    ):
        engine.pool._timeout = 0
        engine.pool._max_overflow = 0
    if use_reaper:
        testing_reaper.add_engine(engine, scope)

    return engine
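
Typical standalone use of ``testing_engine()`` with an explicit URL (when no test database has been configured, ``config.db`` is ``None`` and the URL must be supplied):

    from sqlalchemy.testing.engines import testing_engine

    eng = testing_engine("sqlite://", options={"echo": False})
    with eng.connect() as conn:
        print(conn.exec_driver_sql("select 1").scalar())  # 1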
def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.

    """

    from sqlalchemy import create_mock_engine

    if not dialect_name:
        dialect_name = config.db.name

    buffer = []

    def executor(sql, *a, **kw):
        buffer.append(sql)

    def assert_sql(stmts):
        recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
        assert recv == stmts, recv

    def print_sql():
        d = engine.dialect
        return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)

    engine = create_mock_engine(dialect_name + "://", executor)
    assert not hasattr(engine, "mock")
    engine.mock = buffer
    engine.assert_sql = assert_sql
    engine.print_sql = print_sql
    return engine
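
Capturing DDL with ``mock_engine()``; passing the dialect name explicitly avoids needing a configured database:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.testing.engines import mock_engine

    eng = mock_engine("sqlite")
    t = Table("t", MetaData(), Column("id", Integer, primary_key=True))
    t.create(eng)
    print(eng.print_sql())  # the CREATE TABLE statement, compiled for SQLite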
class DBAPIProxyCursor:
    """Proxy a DBAPI cursor.

    Tests can provide subclasses of this to intercept
    DBAPI-level cursor operations.

    """

    def __init__(self, engine, conn, *args, **kwargs):
        self.engine = engine
        self.connection = conn
        self.cursor = conn.cursor(*args, **kwargs)

    def execute(self, stmt, parameters=None, **kw):
        if parameters:
            return self.cursor.execute(stmt, parameters, **kw)
        else:
            return self.cursor.execute(stmt, **kw)

    def executemany(self, stmt, params, **kw):
        return self.cursor.executemany(stmt, params, **kw)

    def __iter__(self):
        return iter(self.cursor)

    def __getattr__(self, key):
        return getattr(self.cursor, key)


class DBAPIProxyConnection:
    """Proxy a DBAPI connection.

    Tests can provide subclasses of this to intercept
    DBAPI-level connection operations.

    """

    def __init__(self, engine, conn, cursor_cls):
        self.conn = conn
        self.engine = engine
        self.cursor_cls = cursor_cls

    def cursor(self, *args, **kwargs):
        return self.cursor_cls(self.engine, self.conn, *args, **kwargs)

    def close(self):
        self.conn.close()

    def __getattr__(self, key):
        return getattr(self.conn, key)
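
An illustrative subclass (not library code) that counts statement executions through the proxy pair, using ``sqlite3`` directly:

    import sqlite3
    from sqlalchemy.testing.engines import DBAPIProxyConnection, DBAPIProxyCursor

    class CountingCursor(DBAPIProxyCursor):
        calls = 0

        def execute(self, stmt, parameters=None, **kw):
            CountingCursor.calls += 1
            return super().execute(stmt, parameters, **kw)

    conn = DBAPIProxyConnection(
        None, sqlite3.connect(":memory:"), CountingCursor
    )
    conn.cursor().execute("select 1")
    print(CountingCursor.calls)  # 1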
venv/lib/python3.11/site-packages/sqlalchemy/testing/entities.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# testing/entities.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import sqlalchemy as sa
from .. import exc as sa_exc
from ..orm.writeonly import WriteOnlyCollection

_repr_stack = set()


class BasicEntity:
    def __init__(self, **kw):
        for key, value in kw.items():
            setattr(self, key, value)

    def __repr__(self):
        if id(self) in _repr_stack:
            return object.__repr__(self)
        _repr_stack.add(id(self))
        try:
            return "%s(%s)" % (
                (self.__class__.__name__),
                ", ".join(
                    [
                        "%s=%r" % (key, getattr(self, key))
                        for key in sorted(self.__dict__.keys())
                        if not key.startswith("_")
                    ]
                ),
            )
        finally:
            _repr_stack.remove(id(self))


_recursion_stack = set()


class ComparableMixin:
    def __ne__(self, other):
        return not self.__eq__(other)

    def __eq__(self, other):
        """Deep, sparse compare.

        Deeply compare two entities, following the non-None attributes of the
        non-persisted object, if possible.

        """
        if other is self:
            return True
        elif not self.__class__ == other.__class__:
            return False

        if id(self) in _recursion_stack:
            return True
        _recursion_stack.add(id(self))

        try:
            # pick the entity that's not SA persisted as the source
            try:
                self_key = sa.orm.attributes.instance_state(self).key
            except sa.orm.exc.NO_STATE:
                self_key = None

            if other is None:
                a = self
                b = other
            elif self_key is not None:
                a = other
                b = self
            else:
                a = self
                b = other

            for attr in list(a.__dict__):
                if attr.startswith("_"):
                    continue

                value = getattr(a, attr)

                if isinstance(value, WriteOnlyCollection):
                    continue

                try:
                    # handle lazy loader errors
                    battr = getattr(b, attr)
                except (AttributeError, sa_exc.UnboundExecutionError):
                    return False

                if hasattr(value, "__iter__") and not isinstance(value, str):
                    if hasattr(value, "__getitem__") and not hasattr(
                        value, "keys"
                    ):
                        if list(value) != list(battr):
                            return False
                    else:
                        if set(value) != set(battr):
                            return False
                else:
                    if value is not None and value != battr:
                        return False
            return True
        finally:
            _recursion_stack.remove(id(self))


class ComparableEntity(ComparableMixin, BasicEntity):
    def __hash__(self):
        return hash(self.__class__)
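
The "sparse" part of the comparison in practice (a minimal sketch; neither object is ORM-persisted here):

    from sqlalchemy.testing.entities import ComparableEntity

    a = ComparableEntity(id=1, name="x", note=None)
    b = ComparableEntity(id=1, name="x", note="anything")
    print(a == b)  # True: None-valued attributes on the source side are skipped
    print(a == ComparableEntity(id=2, name="x"))  # False: id differs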
venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py (new file, 435 lines)
@@ -0,0 +1,435 @@
# testing/exclusions.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import contextlib
import operator
import re
import sys

from . import config
from .. import util
from ..util import decorator
from ..util.compat import inspect_getfullargspec


def skip_if(predicate, reason=None):
    rule = compound()
    pred = _as_predicate(predicate, reason)
    rule.skips.add(pred)
    return rule


def fails_if(predicate, reason=None):
    rule = compound()
    pred = _as_predicate(predicate, reason)
    rule.fails.add(pred)
    return rule


class compound:
    def __init__(self):
        self.fails = set()
        self.skips = set()

    def __add__(self, other):
        return self.add(other)

    def as_skips(self):
        rule = compound()
        rule.skips.update(self.skips)
        rule.skips.update(self.fails)
        return rule

    def add(self, *others):
        copy = compound()
        copy.fails.update(self.fails)
        copy.skips.update(self.skips)

        for other in others:
            copy.fails.update(other.fails)
            copy.skips.update(other.skips)
        return copy

    def not_(self):
        copy = compound()
        copy.fails.update(NotPredicate(fail) for fail in self.fails)
        copy.skips.update(NotPredicate(skip) for skip in self.skips)
        return copy

    @property
    def enabled(self):
        return self.enabled_for_config(config._current)

    def enabled_for_config(self, config):
        for predicate in self.skips.union(self.fails):
            if predicate(config):
                return False
        else:
            return True

    def matching_config_reasons(self, config):
        return [
            predicate._as_string(config)
            for predicate in self.skips.union(self.fails)
            if predicate(config)
        ]

    def _extend(self, other):
        self.skips.update(other.skips)
        self.fails.update(other.fails)

    def __call__(self, fn):
        if hasattr(fn, "_sa_exclusion_extend"):
            fn._sa_exclusion_extend._extend(self)
            return fn

        @decorator
        def decorate(fn, *args, **kw):
            return self._do(config._current, fn, *args, **kw)

        decorated = decorate(fn)
        decorated._sa_exclusion_extend = self
        return decorated

    @contextlib.contextmanager
    def fail_if(self):
        all_fails = compound()
        all_fails.fails.update(self.skips.union(self.fails))

        try:
            yield
        except Exception as ex:
            all_fails._expect_failure(config._current, ex)
        else:
            all_fails._expect_success(config._current)

    def _do(self, cfg, fn, *args, **kw):
        for skip in self.skips:
            if skip(cfg):
                msg = "'%s' : %s" % (
                    config.get_current_test_name(),
                    skip._as_string(cfg),
                )
                config.skip_test(msg)

        try:
            return_value = fn(*args, **kw)
        except Exception as ex:
            self._expect_failure(cfg, ex, name=fn.__name__)
        else:
            self._expect_success(cfg, name=fn.__name__)
            return return_value

    def _expect_failure(self, config, ex, name="block"):
        for fail in self.fails:
            if fail(config):
                print(
                    "%s failed as expected (%s): %s "
                    % (name, fail._as_string(config), ex)
                )
                break
        else:
            raise ex.with_traceback(sys.exc_info()[2])

    def _expect_success(self, config, name="block"):
        if not self.fails:
            return

        for fail in self.fails:
            if fail(config):
                raise AssertionError(
                    "Unexpected success for '%s' (%s)"
                    % (
                        name,
                        " and ".join(
                            fail._as_string(config) for fail in self.fails
                        ),
                    )
                )


def only_if(predicate, reason=None):
    predicate = _as_predicate(predicate)
    return skip_if(NotPredicate(predicate), reason)


def succeeds_if(predicate, reason=None):
    predicate = _as_predicate(predicate)
    return fails_if(NotPredicate(predicate), reason)


class Predicate:
    @classmethod
    def as_predicate(cls, predicate, description=None):
        if isinstance(predicate, compound):
            return cls.as_predicate(predicate.enabled_for_config, description)
        elif isinstance(predicate, Predicate):
            if description and predicate.description is None:
                predicate.description = description
            return predicate
        elif isinstance(predicate, (list, set)):
            return OrPredicate(
                [cls.as_predicate(pred) for pred in predicate], description
            )
        elif isinstance(predicate, tuple):
            return SpecPredicate(*predicate)
        elif isinstance(predicate, str):
            tokens = re.match(
                r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
            )
            if not tokens:
                raise ValueError(
                    "Couldn't locate DB name in predicate: %r" % predicate
                )
            db = tokens.group(1)
            op = tokens.group(2)
            spec = (
                tuple(int(d) for d in tokens.group(3).split("."))
                if tokens.group(3)
                else None
            )

            return SpecPredicate(db, op, spec, description=description)
        elif callable(predicate):
            return LambdaPredicate(predicate, description)
        else:
            assert False, "unknown predicate type: %s" % predicate

    def _format_description(self, config, negate=False):
        bool_ = self(config)
        if negate:
            bool_ = not negate
        return self.description % {
            "driver": (
                config.db.url.get_driver_name() if config else "<no driver>"
            ),
            "database": (
                config.db.url.get_backend_name() if config else "<no database>"
            ),
            "doesnt_support": "doesn't support" if bool_ else "does support",
            "does_support": "does support" if bool_ else "doesn't support",
        }

    def _as_string(self, config=None, negate=False):
        raise NotImplementedError()


class BooleanPredicate(Predicate):
    def __init__(self, value, description=None):
        self.value = value
        self.description = description or "boolean %s" % value

    def __call__(self, config):
        return self.value

    def _as_string(self, config, negate=False):
        return self._format_description(config, negate=negate)


class SpecPredicate(Predicate):
    def __init__(self, db, op=None, spec=None, description=None):
        self.db = db
        self.op = op
        self.spec = spec
        self.description = description

    _ops = {
        "<": operator.lt,
        ">": operator.gt,
        "==": operator.eq,
        "!=": operator.ne,
        "<=": operator.le,
        ">=": operator.ge,
        "in": operator.contains,
        "between": lambda val, pair: val >= pair[0] and val <= pair[1],
    }

    def __call__(self, config):
        if config is None:
            return False

        engine = config.db

        if "+" in self.db:
            dialect, driver = self.db.split("+")
        else:
            dialect, driver = self.db, None

        if dialect and engine.name != dialect:
            return False
        if driver is not None and engine.driver != driver:
            return False

        if self.op is not None:
            assert driver is None, "DBAPI version specs not supported yet"

            version = _server_version(engine)
            oper = (
                hasattr(self.op, "__call__") and self.op or self._ops[self.op]
            )
            return oper(version, self.spec)
        else:
            return True

    def _as_string(self, config, negate=False):
        if self.description is not None:
            return self._format_description(config)
        elif self.op is None:
            if negate:
                return "not %s" % self.db
            else:
                return "%s" % self.db
        else:
            if negate:
                return "not %s %s %s" % (self.db, self.op, self.spec)
            else:
                return "%s %s %s" % (self.db, self.op, self.spec)


class LambdaPredicate(Predicate):
    def __init__(self, lambda_, description=None, args=None, kw=None):
        spec = inspect_getfullargspec(lambda_)
        if not spec[0]:
            self.lambda_ = lambda db: lambda_()
        else:
            self.lambda_ = lambda_
        self.args = args or ()
        self.kw = kw or {}
        if description:
            self.description = description
        elif lambda_.__doc__:
            self.description = lambda_.__doc__
        else:
            self.description = "custom function"

    def __call__(self, config):
        return self.lambda_(config)

    def _as_string(self, config, negate=False):
        return self._format_description(config)


class NotPredicate(Predicate):
    def __init__(self, predicate, description=None):
        self.predicate = predicate
        self.description = description

    def __call__(self, config):
        return not self.predicate(config)

    def _as_string(self, config, negate=False):
        if self.description:
            return self._format_description(config, not negate)
        else:
            return self.predicate._as_string(config, not negate)


class OrPredicate(Predicate):
    def __init__(self, predicates, description=None):
        self.predicates = predicates
        self.description = description

    def __call__(self, config):
        for pred in self.predicates:
            if pred(config):
                return True
        return False

    def _eval_str(self, config, negate=False):
        if negate:
            conjunction = " and "
        else:
            conjunction = " or "
        return conjunction.join(
            p._as_string(config, negate=negate) for p in self.predicates
        )

    def _negation_str(self, config):
        if self.description is not None:
            return "Not " + self._format_description(config)
        else:
            return self._eval_str(config, negate=True)

    def _as_string(self, config, negate=False):
        if negate:
            return self._negation_str(config)
        else:
            if self.description is not None:
                return self._format_description(config)
            else:
                return self._eval_str(config)


_as_predicate = Predicate.as_predicate


def _is_excluded(db, op, spec):
    return SpecPredicate(db, op, spec)(config._current)


def _server_version(engine):
    """Return a server_version_info tuple."""

    # force metadata to be retrieved
    conn = engine.connect()
    version = getattr(engine.dialect, "server_version_info", None)
    if version is None:
        version = ()
    conn.close()
    return version


def db_spec(*dbs):
    return OrPredicate([Predicate.as_predicate(db) for db in dbs])


def open():  # noqa
    return skip_if(BooleanPredicate(False, "mark as execute"))


def closed():
    return skip_if(BooleanPredicate(True, "marked as skip"))


def fails(reason=None):
    return fails_if(BooleanPredicate(True, reason or "expected to fail"))


def future():
    return fails_if(BooleanPredicate(True, "Future feature"))


def fails_on(db, reason=None):
    return fails_if(db, reason)


def fails_on_everything_except(*dbs):
    return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))


def skip(db, reason=None):
    return skip_if(db, reason)


def only_on(dbs, reason=None):
    return only_if(
        OrPredicate(
            [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
        )
    )


def exclude(db, op, spec, reason=None):
    return skip_if(SpecPredicate(db, op, spec), reason)


def against(config, *queries):
    assert queries, "no queries sent!"
    return OrPredicate([Predicate.as_predicate(query) for query in queries])(
        config
    )
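
How the string shorthand for predicates parses (standalone; ``Predicate.as_predicate`` is the same entry point used by ``skip_if()`` and friends):

    from sqlalchemy.testing.exclusions import Predicate

    pred = Predicate.as_predicate("postgresql+asyncpg")
    print(pred.db, pred.op, pred.spec)  # postgresql+asyncpg None None

    pred = Predicate.as_predicate("sqlite>=3.35")
    print(pred.db, pred.op, pred.spec)  # sqlite >= (3, 35)
    print(pred(None))                   # False: no configuration active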
venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# testing/fixtures/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .base import FutureEngineMixin as FutureEngineMixin
from .base import TestBase as TestBase
from .mypy import MypyTest as MypyTest
from .orm import after_test as after_test
from .orm import close_all_sessions as close_all_sessions
from .orm import DeclarativeMappedTest as DeclarativeMappedTest
from .orm import fixture_session as fixture_session
from .orm import MappedTest as MappedTest
from .orm import ORMTest as ORMTest
from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally
from .orm import (
    stop_test_class_inside_fixtures as stop_test_class_inside_fixtures,
)
from .sql import CacheKeyFixture as CacheKeyFixture
from .sql import (
    ComputedReflectionFixtureTest as ComputedReflectionFixtureTest,
)
from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture
from .sql import NoCache as NoCache
from .sql import RemovesEvents as RemovesEvents
from .sql import TablesTest as TablesTest
(5 binary files not shown)
venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py (new file, 366 lines)
@@ -0,0 +1,366 @@
# testing/fixtures/base.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import sqlalchemy as sa
from .. import assertions
from .. import config
from ..assertions import eq_
from ..util import drop_all_tables_from_metadata
from ... import Column
from ... import func
from ... import Integer
from ... import select
from ... import Table
from ...orm import DeclarativeBase
from ...orm import MappedAsDataclass
from ...orm import registry


@config.mark_base_test_class()
class TestBase:
    # A sequence of requirement names matching testing.requires decorators
    __requires__ = ()

    # A sequence of dialect names to exclude from the test class.
    __unsupported_on__ = ()

    # If present, test class is only runnable for the *single* specified
    # dialect. If you need multiple, use __unsupported_on__ and invert.
    __only_on__ = None

    # A sequence of no-arg callables. If any are True, the entire testcase is
    # skipped.
    __skip_if__ = None

    # if True, the testing reaper will not attempt to touch connection
    # state after a test is completed and before the outer teardown
    # starts
    __leave_connections_for_teardown__ = False

    def assert_(self, val, msg=None):
        assert val, msg

    @config.fixture()
    def nocache(self):
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache

    @config.fixture()
    def connection_no_trans(self):
        eng = getattr(self, "bind", None) or config.db

        with eng.connect() as conn:
            yield conn

    @config.fixture()
    def connection(self):
        global _connection_fixture_connection

        eng = getattr(self, "bind", None) or config.db

        conn = eng.connect()
        trans = conn.begin()

        _connection_fixture_connection = conn
        yield conn

        _connection_fixture_connection = None

        if trans.is_active:
            trans.rollback()
        # trans would not be active here if the test is using
        # the legacy @provide_metadata decorator still, as it will
        # run a close all connections.
        conn.close()

    @config.fixture()
    def close_result_when_finished(self):
        to_close = []
        to_consume = []

        def go(result, consume=False):
            to_close.append(result)
            if consume:
                to_consume.append(result)

        yield go
        for r in to_consume:
            try:
                r.all()
            except:
                pass
        for r in to_close:
            try:
                r.close()
            except:
                pass

    @config.fixture()
    def registry(self, metadata):
        reg = registry(
            metadata=metadata,
            type_annotation_map={
                str: sa.String().with_variant(
                    sa.String(50), "mysql", "mariadb", "oracle"
                )
            },
        )
        yield reg
        reg.dispose()

    @config.fixture
    def decl_base(self, metadata):
        _md = metadata

        class Base(DeclarativeBase):
            metadata = _md
            type_annotation_map = {
                str: sa.String().with_variant(
                    sa.String(50), "mysql", "mariadb", "oracle"
                )
            }

        yield Base
        Base.registry.dispose()

    @config.fixture
    def dc_decl_base(self, metadata):
        _md = metadata

        class Base(MappedAsDataclass, DeclarativeBase):
            metadata = _md
            type_annotation_map = {
                str: sa.String().with_variant(
                    sa.String(50), "mysql", "mariadb"
                )
            }

        yield Base
        Base.registry.dispose()

    @config.fixture()
    def future_connection(self, future_engine, connection):
        # integrate the future_engine and connection fixtures so
        # that users of the "connection" fixture will get at the
        # "future" connection
        yield connection

    @config.fixture()
    def future_engine(self):
        yield

    @config.fixture()
    def testing_engine(self):
        from .. import engines

        def gen_testing_engine(
            url=None,
            options=None,
            future=None,
            asyncio=False,
            transfer_staticpool=False,
            share_pool=False,
        ):
            if options is None:
                options = {}
            options["scope"] = "fixture"
            return engines.testing_engine(
                url=url,
                options=options,
                asyncio=asyncio,
                transfer_staticpool=transfer_staticpool,
                share_pool=share_pool,
            )

        yield gen_testing_engine

        engines.testing_reaper._drop_testing_engines("fixture")

    @config.fixture()
    def async_testing_engine(self, testing_engine):
        def go(**kw):
            kw["asyncio"] = True
            return testing_engine(**kw)

        return go

    @config.fixture()
    def metadata(self, request):
        """Provide bound MetaData for a single test, dropping afterwards."""

        from ...sql import schema

        metadata = schema.MetaData()
        request.instance.metadata = metadata
        yield metadata
        del request.instance.metadata

        if (
            _connection_fixture_connection
            and _connection_fixture_connection.in_transaction()
        ):
            trans = _connection_fixture_connection.get_transaction()
            trans.rollback()
            with _connection_fixture_connection.begin():
                drop_all_tables_from_metadata(
                    metadata, _connection_fixture_connection
                )
        else:
            drop_all_tables_from_metadata(metadata, config.db)

    @config.fixture(
        params=[
            (rollback, second_operation, begin_nested)
            for rollback in (True, False)
            for second_operation in ("none", "execute", "begin")
            for begin_nested in (
                True,
                False,
            )
        ]
    )
    def trans_ctx_manager_fixture(self, request, metadata):
        rollback, second_operation, begin_nested = request.param

        t = Table("test", metadata, Column("data", Integer))
        eng = getattr(self, "bind", None) or config.db

        t.create(eng)

        def run_test(subject, trans_on_subject, execute_on_subject):
            with subject.begin() as trans:
                if begin_nested:
                    if not config.requirements.savepoints.enabled:
                        config.skip_test("savepoints not enabled")
                    if execute_on_subject:
                        nested_trans = subject.begin_nested()
                    else:
                        nested_trans = trans.begin_nested()

                    with nested_trans:
                        if execute_on_subject:
                            subject.execute(t.insert(), {"data": 10})
                        else:
                            trans.execute(t.insert(), {"data": 10})

                        # for nested trans, we always commit/rollback on the
                        # "nested trans" object itself.
                        # only Session(future=False) will affect savepoint
                        # transaction for session.commit/rollback

                        if rollback:
                            nested_trans.rollback()
                        else:
                            nested_trans.commit()

                        if second_operation != "none":
                            with assertions.expect_raises_message(
                                sa.exc.InvalidRequestError,
                                "Can't operate on closed transaction "
                                "inside context "
                                "manager. Please complete the context "
                                "manager "
                                "before emitting further commands.",
                            ):
                                if second_operation == "execute":
                                    if execute_on_subject:
                                        subject.execute(
                                            t.insert(), {"data": 12}
                                        )
                                    else:
                                        trans.execute(t.insert(), {"data": 12})
                                elif second_operation == "begin":
                                    if execute_on_subject:
                                        subject.begin_nested()
                                    else:
                                        trans.begin_nested()

                    # outside the nested trans block, but still inside the
                    # transaction block, we can run SQL, and it will be
                    # committed
                    if execute_on_subject:
                        subject.execute(t.insert(), {"data": 14})
                    else:
                        trans.execute(t.insert(), {"data": 14})

                else:
                    if execute_on_subject:
                        subject.execute(t.insert(), {"data": 10})
                    else:
                        trans.execute(t.insert(), {"data": 10})

                    if trans_on_subject:
                        if rollback:
                            subject.rollback()
                        else:
                            subject.commit()
                    else:
                        if rollback:
                            trans.rollback()
                        else:
                            trans.commit()

                    if second_operation != "none":
                        with assertions.expect_raises_message(
                            sa.exc.InvalidRequestError,
                            "Can't operate on closed transaction inside "
                            "context "
                            "manager. Please complete the context manager "
                            "before emitting further commands.",
                        ):
                            if second_operation == "execute":
                                if execute_on_subject:
                                    subject.execute(t.insert(), {"data": 12})
                                else:
                                    trans.execute(t.insert(), {"data": 12})
                            elif second_operation == "begin":
                                if hasattr(trans, "begin"):
                                    trans.begin()
                                else:
                                    subject.begin()
                            elif second_operation == "begin_nested":
                                if execute_on_subject:
                                    subject.begin_nested()
                                else:
                                    trans.begin_nested()

            expected_committed = 0
            if begin_nested:
                # begin_nested variant, we inserted a row after the nested
                # block
                expected_committed += 1
            if not rollback:
                # not rollback variant, our row inserted in the target
                # block itself would be committed
                expected_committed += 1

            if execute_on_subject:
                eq_(
                    subject.scalar(select(func.count()).select_from(t)),
                    expected_committed,
                )
            else:
                with subject.connect() as conn:
                    eq_(
                        conn.scalar(select(func.count()).select_from(t)),
                        expected_committed,
                    )

        return run_test


_connection_fixture_connection = None


class FutureEngineMixin:
    """alembic's suite still using this"""
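
A hypothetical test module wired to these fixtures (only meaningful under SQLAlchemy's pytest plugin, which supplies ``connection`` and ``metadata``; shown purely to illustrate the intended flow):

    from sqlalchemy import Column, Integer, Table
    from sqlalchemy.testing import fixtures

    class SomeRoundTripTest(fixtures.TestBase):
        def test_insert(self, connection, metadata):
            # `metadata` is dropped for us after the test; `connection` is
            # wrapped in a transaction that the fixture rolls back
            t = Table("t", metadata, Column("id", Integer, primary_key=True))
            t.create(connection)
            connection.execute(t.insert(), {"id": 1})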
venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py (new file, 312 lines)
@@ -0,0 +1,312 @@
# testing/fixtures/mypy.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import inspect
import os
from pathlib import Path
import re
import shutil
import sys
import tempfile

from .base import TestBase
from .. import config
from ..assertions import eq_
from ... import util


@config.add_to_marker.mypy
class MypyTest(TestBase):
    __requires__ = ("no_sqlalchemy2_stubs",)

    @config.fixture(scope="function")
    def per_func_cachedir(self):
        yield from self._cachedir()

    @config.fixture(scope="class")
    def cachedir(self):
        yield from self._cachedir()

    def _cachedir(self):
        # as of mypy 0.971 i think we need to keep mypy_path empty
        mypy_path = ""

        with tempfile.TemporaryDirectory() as cachedir:
            with open(
                Path(cachedir) / "sqla_mypy_config.cfg", "w"
            ) as config_file:
                config_file.write(
                    f"""
                    [mypy]\n
                    plugins = sqlalchemy.ext.mypy.plugin\n
                    show_error_codes = True\n
                    {mypy_path}
                    disable_error_code = no-untyped-call

                    [mypy-sqlalchemy.*]
                    ignore_errors = True

                    """
                )
            with open(
                Path(cachedir) / "plain_mypy_config.cfg", "w"
            ) as config_file:
                config_file.write(
                    f"""
                    [mypy]\n
                    show_error_codes = True\n
                    {mypy_path}
                    disable_error_code = var-annotated,no-untyped-call
                    [mypy-sqlalchemy.*]
                    ignore_errors = True

                    """
                )
            yield cachedir

    @config.fixture()
    def mypy_runner(self, cachedir):
        from mypy import api

        def run(path, use_plugin=False, use_cachedir=None):
            if use_cachedir is None:
                use_cachedir = cachedir
            args = [
                "--strict",
                "--raise-exceptions",
                "--cache-dir",
                use_cachedir,
                "--config-file",
                os.path.join(
                    use_cachedir,
                    (
                        "sqla_mypy_config.cfg"
                        if use_plugin
                        else "plain_mypy_config.cfg"
                    ),
                ),
            ]

            # mypy as of 0.990 is more aggressively blocking messaging
            # for paths that are in sys.path, and as pytest puts currdir,
            # test/ etc in sys.path, just copy the source file to the
            # tempdir we are working in so that we don't have to try to
            # manipulate sys.path and/or guess what mypy is doing
            filename = os.path.basename(path)
            test_program = os.path.join(use_cachedir, filename)
            if path != test_program:
                shutil.copyfile(path, test_program)
            args.append(test_program)

            # I set this locally but for the suite here needs to be
            # disabled
            os.environ.pop("MYPY_FORCE_COLOR", None)

            stdout, stderr, exitcode = api.run(args)
            return stdout, stderr, exitcode

        return run

    @config.fixture
    def mypy_typecheck_file(self, mypy_runner):
        def run(path, use_plugin=False):
            expected_messages = self._collect_messages(path)
            stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
            self._check_output(
                path, expected_messages, stdout, stderr, exitcode
            )

        return run

    @staticmethod
    def file_combinations(dirname):
        if os.path.isabs(dirname):
            path = dirname
        else:
            caller_path = inspect.stack()[1].filename
            path = os.path.join(os.path.dirname(caller_path), dirname)
        files = list(Path(path).glob("**/*.py"))

        for extra_dir in config.options.mypy_extra_test_paths:
            if extra_dir and os.path.isdir(extra_dir):
                files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
        return files

    def _collect_messages(self, path):
        from sqlalchemy.ext.mypy.util import mypy_14

        expected_messages = []
        expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
        py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
        with open(path) as file_:
            current_assert_messages = []
            for num, line in enumerate(file_, 1):
                m = py_ver_re.match(line)
                if m:
                    major, _, minor = m.group(1).partition(".")
                    if sys.version_info < (int(major), int(minor)):
                        config.skip_test(
                            "Requires python >= %s" % (m.group(1))
                        )
                    continue

                m = expected_re.match(line)
                if m:
                    is_mypy = bool(m.group(1))
                    is_re = bool(m.group(2))
                    is_type = bool(m.group(3))

                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
                    if is_type:
                        if not is_re:
                            # the goal here is that we can cut-and-paste
                            # from vscode -> pylance into the
                            # EXPECTED_TYPE: line, then the test suite will
                            # validate that line against what mypy produces
                            expected_msg = re.sub(
                                r"([\[\]])",
                                lambda m: rf"\{m.group(0)}",
                                expected_msg,
                            )

                            # note making sure preceding text matches
                            # with a dot, so that an expect for "Select"
                            # does not match "TypedSelect"
                            expected_msg = re.sub(
                                r"([\w_]+)",
                                lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
                                expected_msg,
                            )

                            expected_msg = re.sub(
                                "List", "builtins.list", expected_msg
                            )

                            expected_msg = re.sub(
                                r"\b(int|str|float|bool)\b",
                                lambda m: rf"builtins.{m.group(0)}\*?",
                                expected_msg,
                            )
                            # expected_msg = re.sub(
                            #     r"(Sequence|Tuple|List|Union)",
                            #     lambda m: fr"typing.{m.group(0)}\*?",
                            #     expected_msg,
                            # )

                        is_mypy = is_re = True
                        expected_msg = f'Revealed type is "{expected_msg}"'

                    if mypy_14 and util.py39:
                        # use_lowercase_names, py39 and above
                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363  # noqa: E501

                        # skip first character which could be capitalized
                        # "List item x not found" type of message
                        expected_msg = expected_msg[0] + re.sub(
                            (
                                r"\b(List|Tuple|Dict|Set)\b"
                                if is_type
                                else r"\b(List|Tuple|Dict|Set|Type)\b"
                            ),
                            lambda m: m.group(1).lower(),
                            expected_msg[1:],
                        )
||||
if mypy_14 and util.py310:
|
||||
# use_or_syntax, py310 and above
|
||||
# https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501
|
||||
expected_msg = re.sub(
|
||||
r"Optional\[(.*?)\]",
|
||||
lambda m: f"{m.group(1)} | None",
|
||||
expected_msg,
|
||||
)
|
||||
current_assert_messages.append(
|
||||
(is_mypy, is_re, expected_msg.strip())
|
||||
)
|
||||
elif current_assert_messages:
|
||||
expected_messages.extend(
|
||||
(num, is_mypy, is_re, expected_msg)
|
||||
for (
|
||||
is_mypy,
|
||||
is_re,
|
||||
expected_msg,
|
||||
) in current_assert_messages
|
||||
)
|
||||
current_assert_messages[:] = []
|
||||
|
||||
return expected_messages
|
||||
|
||||
def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
|
||||
not_located = []
|
||||
filename = os.path.basename(path)
|
||||
if expected_messages:
|
||||
# mypy 0.990 changed how return codes work, so don't assume a
|
||||
# 1 or a 0 return code here, could be either depending on if
|
||||
# errors were generated or not
|
||||
|
||||
output = []
|
||||
|
||||
raw_lines = stdout.split("\n")
|
||||
while raw_lines:
|
||||
e = raw_lines.pop(0)
|
||||
if re.match(r".+\.py:\d+: error: .*", e):
|
||||
output.append(("error", e))
|
||||
elif re.match(
|
||||
r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
|
||||
):
|
||||
while raw_lines:
|
||||
ol = raw_lines.pop(0)
|
||||
if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
|
||||
break
|
||||
elif re.match(
|
||||
r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
|
||||
):
|
||||
pass
|
||||
elif re.match(r".+\.py:\d+: note: .*", e):
|
||||
output.append(("note", e))
|
||||
|
||||
for num, is_mypy, is_re, msg in expected_messages:
|
||||
msg = msg.replace("'", '"')
|
||||
prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
|
||||
for idx, (typ, errmsg) in enumerate(output):
|
||||
if is_re:
|
||||
if re.match(
|
||||
rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
|
||||
errmsg,
|
||||
):
|
||||
break
|
||||
elif (
|
||||
f"{filename}:{num}: {typ}: {prefix}{msg}"
|
||||
in errmsg.replace("'", '"')
|
||||
):
|
||||
break
|
||||
else:
|
||||
not_located.append(msg)
|
||||
continue
|
||||
del output[idx]
|
||||
|
||||
if not_located:
|
||||
missing = "\n".join(not_located)
|
||||
print("Couldn't locate expected messages:", missing, sep="\n")
|
||||
if output:
|
||||
extra = "\n".join(msg for _, msg in output)
|
||||
print("Remaining messages:", extra, sep="\n")
|
||||
assert False, "expected messages not found, see stdout"
|
||||
|
||||
if output:
|
||||
print(f"{len(output)} messages from mypy were not consumed:")
|
||||
print("\n".join(msg for _, msg in output))
|
||||
assert False, "errors and/or notes remain, see stdout"
|
||||
|
||||
else:
|
||||
if exitcode != 0:
|
||||
print(stdout, stderr, sep="\n")
|
||||
|
||||
eq_(exitcode, 0, msg=stdout)
|
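# A hedged illustration (not part of this commit): the shape of a test file
# consumed by MypyTest above.  "# EXPECTED_MYPY:" pins an exact mypy error
# for the statement that follows; "# EXPECTED_TYPE:" takes a reveal_type()
# rendering as copied from an IDE and is rewritten by _collect_messages()
# into a regex against mypy's "Revealed type is ..." note.  All names below
# are hypothetical.
#
#     from sqlalchemy import Column, Integer
#
#     # EXPECTED_TYPE: Column[int]
#     reveal_type(Column(Integer))
#
#     # EXPECTED_MYPY: Incompatible types in assignment (expression has type "str", variable has type "int")  # noqa: E501
#     x: int = "not an int"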
@@ -0,0 +1,227 @@
# testing/fixtures/orm.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

from typing import Any

import sqlalchemy as sa
from .base import TestBase
from .sql import TablesTest
from .. import assertions
from .. import config
from .. import schema
from ..entities import BasicEntity
from ..entities import ComparableEntity
from ..util import adict
from ... import orm
from ...orm import DeclarativeBase
from ...orm import events as orm_events
from ...orm import registry


class ORMTest(TestBase):
    @config.fixture
    def fixture_session(self):
        return fixture_session()


class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults):
    # 'once', 'each', None
    run_setup_classes = "once"

    # 'once', 'each', None
    run_setup_mappers = "each"

    classes: Any = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        cls = self.__class__
        cls._init_class()

        if cls.classes is None:
            cls.classes = adict()

        cls._setup_once_tables()
        cls._setup_once_classes()
        cls._setup_once_mappers()
        cls._setup_once_inserts()

        yield

        cls._teardown_once_class()
        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        self._setup_each_tables()
        self._setup_each_classes()
        self._setup_each_mappers()
        self._setup_each_inserts()

        yield

        orm.session.close_all_sessions()
        self._teardown_each_mappers()
        self._teardown_each_classes()
        self._teardown_each_tables()

    @classmethod
    def _teardown_once_class(cls):
        cls.classes.clear()

    @classmethod
    def _setup_once_classes(cls):
        if cls.run_setup_classes == "once":
            cls._with_register_classes(cls.setup_classes)

    @classmethod
    def _setup_once_mappers(cls):
        if cls.run_setup_mappers == "once":
            cls.mapper_registry, cls.mapper = cls._generate_registry()
            cls._with_register_classes(cls.setup_mappers)

    def _setup_each_mappers(self):
        if self.run_setup_mappers != "once":
            (
                self.__class__.mapper_registry,
                self.__class__.mapper,
            ) = self._generate_registry()

        if self.run_setup_mappers == "each":
            self._with_register_classes(self.setup_mappers)

    def _setup_each_classes(self):
        if self.run_setup_classes == "each":
            self._with_register_classes(self.setup_classes)

    @classmethod
    def _generate_registry(cls):
        decl = registry(metadata=cls._tables_metadata)
        return decl, decl.map_imperatively

    @classmethod
    def _with_register_classes(cls, fn):
        """Run a setup method, framing the operation with a Base class
        that will catch new subclasses to be established within
        the "classes" registry.

        """
        cls_registry = cls.classes

        class _Base:
            def __init_subclass__(cls) -> None:
                assert cls_registry is not None
                cls_registry[cls.__name__] = cls
                super().__init_subclass__()

        class Basic(BasicEntity, _Base):
            pass

        class Comparable(ComparableEntity, _Base):
            pass

        cls.Basic = Basic
        cls.Comparable = Comparable
        fn()

    def _teardown_each_mappers(self):
        # some tests create mappers in the test bodies
        # and will define setup_mappers as None -
        # clear mappers in any case
        if self.run_setup_mappers != "once":
            orm.clear_mappers()

    def _teardown_each_classes(self):
        if self.run_setup_classes != "once":
            self.classes.clear()

    @classmethod
    def setup_classes(cls):
        pass

    @classmethod
    def setup_mappers(cls):
        pass


class DeclarativeMappedTest(MappedTest):
    run_setup_classes = "once"
    run_setup_mappers = "once"

    @classmethod
    def _setup_once_tables(cls):
        pass

    @classmethod
    def _with_register_classes(cls, fn):
        cls_registry = cls.classes

        class _DeclBase(DeclarativeBase):
            __table_cls__ = schema.Table
            metadata = cls._tables_metadata
            type_annotation_map = {
                str: sa.String().with_variant(
                    sa.String(50), "mysql", "mariadb", "oracle"
                )
            }

            def __init_subclass__(cls, **kw) -> None:
                assert cls_registry is not None
                cls_registry[cls.__name__] = cls
                super().__init_subclass__(**kw)

        cls.DeclarativeBasic = _DeclBase

        # sets up cls.Basic which is helpful for things like composite
        # classes
        super()._with_register_classes(fn)

        if cls._tables_metadata.tables and cls.run_create_tables:
            cls._tables_metadata.create_all(config.db)


class RemoveORMEventsGlobally:
    @config.fixture(autouse=True)
    def _remove_listeners(self):
        yield
        orm_events.MapperEvents._clear()
        orm_events.InstanceEvents._clear()
        orm_events.SessionEvents._clear()
        orm_events.InstrumentationEvents._clear()
        orm_events.QueryEvents._clear()


_fixture_sessions = set()


def fixture_session(**kw):
    kw.setdefault("autoflush", True)
    kw.setdefault("expire_on_commit", True)

    bind = kw.pop("bind", config.db)

    sess = orm.Session(bind, **kw)
    _fixture_sessions.add(sess)
    return sess


def close_all_sessions():
    # will close all still-referenced sessions
    orm.close_all_sessions()
    _fixture_sessions.clear()


def stop_test_class_inside_fixtures(cls):
    close_all_sessions()
    orm.clear_mappers()


def after_test():
    if _fixture_sessions:
        close_all_sessions()
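# Hedged usage sketch (not part of this commit): a typical MappedTest
# subclass wires tables, classes and mappers through the hooks above;
# setup_classes() subclasses cls.Comparable so _with_register_classes()
# catches the new class into cls.classes.  Table, class and column names
# here are hypothetical.
#
#     class UserTest(MappedTest):
#         @classmethod
#         def define_tables(cls, metadata):
#             Table(
#                 "users",
#                 metadata,
#                 Column("id", Integer, primary_key=True),
#                 Column("name", String(30)),
#             )
#
#         @classmethod
#         def setup_classes(cls):
#             class User(cls.Comparable):
#                 pass
#
#         @classmethod
#         def setup_mappers(cls):
#             cls.mapper_registry.map_imperatively(
#                 cls.classes.User, cls.tables.users
#             )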
@@ -0,0 +1,503 @@
# testing/fixtures/sql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import itertools
import random
import re
import sys

import sqlalchemy as sa
from .base import TestBase
from .. import config
from .. import mock
from ..assertions import eq_
from ..assertions import ne_
from ..util import adict
from ..util import drop_all_tables_from_metadata
from ... import event
from ... import util
from ...schema import sort_tables_and_constraints
from ...sql import visitors
from ...sql.elements import ClauseElement


class TablesTest(TestBase):
    # 'once', None
    run_setup_bind = "once"

    # 'once', 'each', None
    run_define_tables = "once"

    # 'once', 'each', None
    run_create_tables = "once"

    # 'once', 'each', None
    run_inserts = "each"

    # 'each', None
    run_deletes = "each"

    # 'once', None
    run_dispose_bind = None

    bind = None
    _tables_metadata = None
    tables = None
    other = None
    sequences = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        cls = self.__class__
        cls._init_class()

        cls._setup_once_tables()

        cls._setup_once_inserts()

        yield

        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        self._setup_each_tables()
        self._setup_each_inserts()

        yield

        self._teardown_each_tables()

    @property
    def tables_test_metadata(self):
        return self._tables_metadata

    @classmethod
    def _init_class(cls):
        if cls.run_define_tables == "each":
            if cls.run_create_tables == "once":
                cls.run_create_tables = "each"
            assert cls.run_inserts in ("each", None)

        cls.other = adict()
        cls.tables = adict()
        cls.sequences = adict()

        cls.bind = cls.setup_bind()
        cls._tables_metadata = sa.MetaData()

    @classmethod
    def _setup_once_inserts(cls):
        if cls.run_inserts == "once":
            cls._load_fixtures()
            with cls.bind.begin() as conn:
                cls.insert_data(conn)

    @classmethod
    def _setup_once_tables(cls):
        if cls.run_define_tables == "once":
            cls.define_tables(cls._tables_metadata)
            if cls.run_create_tables == "once":
                cls._tables_metadata.create_all(cls.bind)
            cls.tables.update(cls._tables_metadata.tables)
            cls.sequences.update(cls._tables_metadata._sequences)

    def _setup_each_tables(self):
        if self.run_define_tables == "each":
            self.define_tables(self._tables_metadata)
            if self.run_create_tables == "each":
                self._tables_metadata.create_all(self.bind)
            self.tables.update(self._tables_metadata.tables)
            self.sequences.update(self._tables_metadata._sequences)
        elif self.run_create_tables == "each":
            self._tables_metadata.create_all(self.bind)

    def _setup_each_inserts(self):
        if self.run_inserts == "each":
            self._load_fixtures()
            with self.bind.begin() as conn:
                self.insert_data(conn)

    def _teardown_each_tables(self):
        if self.run_define_tables == "each":
            self.tables.clear()
            if self.run_create_tables == "each":
                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
            self._tables_metadata.clear()
        elif self.run_create_tables == "each":
            drop_all_tables_from_metadata(self._tables_metadata, self.bind)

        savepoints = getattr(config.requirements, "savepoints", False)
        if savepoints:
            savepoints = savepoints.enabled

        # no need to run deletes if tables are recreated on setup
        if (
            self.run_define_tables != "each"
            and self.run_create_tables != "each"
            and self.run_deletes == "each"
        ):
            with self.bind.begin() as conn:
                for table in reversed(
                    [
                        t
                        for (t, fks) in sort_tables_and_constraints(
                            self._tables_metadata.tables.values()
                        )
                        if t is not None
                    ]
                ):
                    try:
                        if savepoints:
                            with conn.begin_nested():
                                conn.execute(table.delete())
                        else:
                            conn.execute(table.delete())
                    except sa.exc.DBAPIError as ex:
                        print(
                            ("Error emptying table %s: %r" % (table, ex)),
                            file=sys.stderr,
                        )

    @classmethod
    def _teardown_once_metadata_bind(cls):
        if cls.run_create_tables:
            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)

        if cls.run_dispose_bind == "once":
            cls.dispose_bind(cls.bind)

        cls._tables_metadata.bind = None

        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def setup_bind(cls):
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        if hasattr(bind, "dispose"):
            bind.dispose()
        elif hasattr(bind, "close"):
            bind.close()

    @classmethod
    def define_tables(cls, metadata):
        pass

    @classmethod
    def fixtures(cls):
        return {}

    @classmethod
    def insert_data(cls, connection):
        pass

    def sql_count_(self, count, fn):
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements):
        self.assert_sql(self.bind, callable_, statements)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method."""
        headers, rows = {}, {}
        for table, data in cls.fixtures().items():
            if len(data) < 2:
                continue
            if isinstance(table, str):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        for table, fks in sort_tables_and_constraints(
            cls._tables_metadata.tables.values()
        ):
            if table is None:
                continue
            if table not in headers:
                continue
            with cls.bind.begin() as conn:
                conn.execute(
                    table.insert(),
                    [
                        dict(zip(headers[table], column_values))
                        for column_values in rows[table]
                    ],
                )
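# Hedged usage sketch (not part of this commit): a minimal TablesTest
# subclass.  define_tables() populates the class-level metadata, and
# fixtures() supplies header-plus-rows tuples that _load_fixtures() turns
# into executemany-style inserts.  Table and column names are hypothetical.
#
#     class UserDataTest(TablesTest):
#         @classmethod
#         def define_tables(cls, metadata):
#             Table(
#                 "users",
#                 metadata,
#                 Column("id", Integer, primary_key=True),
#                 Column("name", String(30)),
#             )
#
#         @classmethod
#         def fixtures(cls):
#             return {
#                 "users": (
#                     ("id", "name"),
#                     (1, "spongebob"),
#                     (2, "sandy"),
#                 )
#             }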
class NoCache:
    @config.fixture(autouse=True, scope="function")
    def _disable_cache(self):
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache


class RemovesEvents:
    @util.memoized_property
    def _event_fns(self):
        return set()

    def event_listen(self, target, name, fn, **kw):
        self._event_fns.add((target, name, fn))
        event.listen(target, name, fn, **kw)

    @config.fixture(autouse=True, scope="function")
    def _remove_events(self):
        yield
        for key in self._event_fns:
            event.remove(*key)
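# Hedged usage sketch (not part of this commit): a test registers listeners
# through event_listen() instead of event.listen() so the autouse fixture
# above can detach them afterwards.  Target and handler names are
# hypothetical.
#
#     class MyTest(TestBase, RemovesEvents):
#         def test_checkout_listener(self, connection):
#             def on_checkout(*arg, **kw):
#                 ...
#
#             self.event_listen(
#                 connection.engine.pool, "checkout", on_checkout
#             )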
class ComputedReflectionFixtureTest(TablesTest):
    run_inserts = run_deletes = None

    __backend__ = True
    __requires__ = ("computed_columns", "table_reflection")

    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")

    def normalize(self, text):
        return self.regexp.sub("", text).lower()

    @classmethod
    def define_tables(cls, metadata):
        from ... import Integer
        from ... import testing
        from ...schema import Column
        from ...schema import Computed
        from ...schema import Table

        Table(
            "computed_default_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_col", Integer, Computed("normal + 42")),
            Column("with_default", Integer, server_default="42"),
        )

        t = Table(
            "computed_column_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_no_flag", Integer, Computed("normal + 42")),
        )

        if testing.requires.schemas.enabled:
            t2 = Table(
                "computed_column_table",
                metadata,
                Column("id", Integer, primary_key=True),
                Column("normal", Integer),
                Column("computed_no_flag", Integer, Computed("normal / 42")),
                schema=config.test_schema,
            )

        if testing.requires.computed_columns_virtual.enabled:
            t.append_column(
                Column(
                    "computed_virtual",
                    Integer,
                    Computed("normal + 2", persisted=False),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_virtual",
                        Integer,
                        Computed("normal / 2", persisted=False),
                    )
                )
        if testing.requires.computed_columns_stored.enabled:
            t.append_column(
                Column(
                    "computed_stored",
                    Integer,
                    Computed("normal - 42", persisted=True),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_stored",
                        Integer,
                        Computed("normal * 42", persisted=True),
                    )
                )


class CacheKeyFixture:
    def _compare_equal(self, a, b, compare_values):
        a_key = a._generate_cache_key()
        b_key = b._generate_cache_key()

        if a_key is None:
            assert a._annotations.get("nocache")

            assert b_key is None
        else:
            eq_(a_key.key, b_key.key)
            eq_(hash(a_key.key), hash(b_key.key))

            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
                assert a_param.compare(b_param, compare_values=compare_values)
        return a_key, b_key

    def _run_cache_key_fixture(self, fixture, compare_values):
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            if a == b:
                a_key, b_key = self._compare_equal(
                    case_a[a], case_b[b], compare_values
                )
                if a_key is None:
                    continue
            else:
                a_key = case_a[a]._generate_cache_key()
                b_key = case_b[b]._generate_cache_key()

                if a_key is None or b_key is None:
                    if a_key is None:
                        assert case_a[a]._annotations.get("nocache")
                    if b_key is None:
                        assert case_b[b]._annotations.get("nocache")
                    continue

                if a_key.key == b_key.key:
                    for a_param, b_param in zip(
                        a_key.bindparams, b_key.bindparams
                    ):
                        if not a_param.compare(
                            b_param, compare_values=compare_values
                        ):
                            break
                    else:
                        # this fails unconditionally since we could not
                        # find bound parameter values that differed.
                        # Usually we intended to get two distinct keys here
                        # so the failure will be more descriptive using the
                        # ne_() assertion.
                        ne_(a_key.key, b_key.key)
                else:
                    ne_(a_key.key, b_key.key)

            # ClauseElement-specific test to ensure the cache key
            # collected all the bound parameters that aren't marked
            # as "literal execute"
            if isinstance(case_a[a], ClauseElement) and isinstance(
                case_b[b], ClauseElement
            ):
                assert_a_params = []
                assert_b_params = []

                for elem in visitors.iterate(case_a[a]):
                    if elem.__visit_name__ == "bindparam":
                        assert_a_params.append(elem)

                for elem in visitors.iterate(case_b[b]):
                    if elem.__visit_name__ == "bindparam":
                        assert_b_params.append(elem)

                # note we're asserting the order of the params as well as
                # whether there are dupes or not.  ordering has to be
                # deterministic and matches what a traversal would provide.
                eq_(
                    sorted(a_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_a_params), key=lambda b: b.key
                    ),
                )
                eq_(
                    sorted(b_key.bindparams, key=lambda b: b.key),
                    sorted(
                        util.unique_list(assert_b_params), key=lambda b: b.key
                    ),
                )

    def _run_cache_key_equal_fixture(self, fixture, compare_values):
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            self._compare_equal(case_a[a], case_b[b], compare_values)
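# Hedged usage sketch (not part of this commit): a fixture callable returns
# the same list of constructs on each call; elements at equal positions must
# produce equal cache keys, and elements at differing positions must not
# (or must differ in bound parameter values).  Table and column names are
# hypothetical.
#
#     def fixture():
#         return [
#             select(table.c.a),
#             select(table.c.b),
#             select(table.c.a).where(table.c.b == 5),
#         ]
#
#     self._run_cache_key_fixture(fixture, compare_values=True)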
def insertmanyvalues_fixture(
    connection, randomize_rows=False, warn_on_downgraded=False
):
    dialect = connection.dialect
    orig_dialect = dialect._deliver_insertmanyvalues_batches
    orig_conn = connection._exec_insertmany_context

    class RandomCursor:
        __slots__ = ("cursor",)

        def __init__(self, cursor):
            self.cursor = cursor

        # only this method is called by the deliver method; by not having
        # the other methods we assert that those aren't being used

        @property
        def description(self):
            return self.cursor.description

        def fetchall(self):
            rows = self.cursor.fetchall()
            rows = list(rows)
            random.shuffle(rows)
            return rows

    def _deliver_insertmanyvalues_batches(
        connection,
        cursor,
        statement,
        parameters,
        generic_setinputsizes,
        context,
    ):
        if randomize_rows:
            cursor = RandomCursor(cursor)
        for batch in orig_dialect(
            connection,
            cursor,
            statement,
            parameters,
            generic_setinputsizes,
            context,
        ):
            if warn_on_downgraded and batch.is_downgraded:
                util.warn("Batches were downgraded for sorted INSERT")

            yield batch

    def _exec_insertmany_context(dialect, context):
        with mock.patch.object(
            dialect,
            "_deliver_insertmanyvalues_batches",
            new=_deliver_insertmanyvalues_batches,
        ):
            return orig_conn(dialect, context)

    connection._exec_insertmany_context = _exec_insertmany_context
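# Hedged usage sketch (not part of this commit): once installed on a
# connection, the fixture shuffles the rows returned by each RETURNING
# batch, so a test can prove that row ordering comes from the
# insertmanyvalues logic itself rather than from cursor order.  The table
# and data names are hypothetical.
#
#     with engine.connect() as conn:
#         insertmanyvalues_fixture(conn, randomize_rows=True)
#         result = conn.execute(
#             table.insert().returning(table.c.id), data
#         )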
@@ -0,0 +1,155 @@
# testing/pickleable.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


"""Classes used in pickling tests; they need to be at the module level for
unpickling.
"""

from __future__ import annotations

from .entities import ComparableEntity
from ..schema import Column
from ..types import String


class User(ComparableEntity):
    pass


class Order(ComparableEntity):
    pass


class Dingaling(ComparableEntity):
    pass


class EmailUser(User):
    pass


class Address(ComparableEntity):
    pass


# TODO: these are kind of arbitrary....
class Child1(ComparableEntity):
    pass


class Child2(ComparableEntity):
    pass


class Parent(ComparableEntity):
    pass


class Screen:
    def __init__(self, obj, parent=None):
        self.obj = obj
        self.parent = parent


class Mixin:
    email_address = Column(String)


class AddressWMixin(Mixin, ComparableEntity):
    pass


class Foo:
    def __init__(self, moredata, stuff="im stuff"):
        self.data = "im data"
        self.stuff = stuff
        self.moredata = moredata

    __hash__ = object.__hash__

    def __eq__(self, other):
        return (
            other.data == self.data
            and other.stuff == self.stuff
            and other.moredata == self.moredata
        )


class Bar:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    __hash__ = object.__hash__

    def __eq__(self, other):
        return (
            other.__class__ is self.__class__
            and other.x == self.x
            and other.y == self.y
        )

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)


class OldSchool:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return (
            other.__class__ is self.__class__
            and other.x == self.x
            and other.y == self.y
        )


class OldSchoolWithoutCompare:
    def __init__(self, x, y):
        self.x = x
        self.y = y


class BarWithoutCompare:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)


class NotComparable:
    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return NotImplemented

    def __ne__(self, other):
        return NotImplemented


class BrokenComparable:
    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        raise NotImplementedError

    def __ne__(self, other):
        raise NotImplementedError
@@ -0,0 +1,6 @@
# testing/plugin/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
@@ -0,0 +1,51 @@
# testing/plugin/bootstrap.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

"""
Bootstrapper for test framework plugins.

The entire rationale for this system is to get the modules in plugin/
imported without importing all of the supporting library, so that we can
set up things for testing before coverage starts.

The rationale for all of plugin/ being *in* the supporting library in the
first place is so that the testing and plugin suite is available to other
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
of the same test environment and standard suites available to
SQLAlchemy/Alembic themselves without the need to ship/install a separate
package outside of SQLAlchemy.

"""

import importlib.util
import os
import sys


bootstrap_file = locals()["bootstrap_file"]
to_bootstrap = locals()["to_bootstrap"]


def load_file_as_module(name):
    path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)

    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None
    assert spec.loader is not None
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


if to_bootstrap == "pytest":
    sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
    sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
    sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
    raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa
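# Hedged sketch (not part of this commit): how a conftest-style loader is
# expected to drive this bootstrapper -- it execs the file's source with
# ``bootstrap_file`` and ``to_bootstrap`` pre-seeded into the namespace,
# which is what the locals() lookups above rely on.  The path variable is
# hypothetical.
#
#     with open(bootstrap_path) as f:
#         code = f.read()
#     exec(
#         code,
#         {"bootstrap_file": bootstrap_path, "to_bootstrap": "pytest"},
#     )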
@@ -0,0 +1,779 @@
# testing/plugin/plugin_base.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import abc
from argparse import Namespace
import configparser
import logging
import os
from pathlib import Path
import re
import sys
from typing import Any

from sqlalchemy.testing import asyncio

"""Testing extensions.

this module is designed to work as a testing-framework-agnostic library,
created so that multiple test frameworks can be supported at once
(mostly so that we can migrate to new ones).  The current target
is pytest.

"""

# flag which indicates we are in the SQLAlchemy testing suite,
# and not that of Alembic or a third party dialect.
bootstrapped_as_sqlalchemy = False

log = logging.getLogger("sqlalchemy.testing.plugin_base")

# late imports
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
provision = None
assertions = None
requirements = None
config = None
testing = None
util = None
file_config = None

logging = None
include_tags = set()
exclude_tags = set()
options: Namespace = None  # type: ignore


def setup_options(make_option):
    make_option(
        "--log-info",
        action="callback",
        type=str,
        callback=_log,
        help="turn on info logging for <LOG> (multiple OK)",
    )
    make_option(
        "--log-debug",
        action="callback",
        type=str,
        callback=_log,
        help="turn on debug logging for <LOG> (multiple OK)",
    )
    make_option(
        "--db",
        action="append",
        type=str,
        dest="db",
        help="Use prefab database uri. Multiple OK, "
        "first one is run by default.",
    )
    make_option(
        "--dbs",
        action="callback",
        zeroarg_callback=_list_dbs,
        help="List available prefab dbs",
    )
    make_option(
        "--dburi",
        action="append",
        type=str,
        dest="dburi",
        help="Database uri. Multiple OK, first one is run by default.",
    )
    make_option(
        "--dbdriver",
        action="append",
        type=str,
        dest="dbdriver",
        help="Additional database drivers to include in tests. "
        "These are linked to the existing database URLs by the "
        "provisioning system.",
    )
    make_option(
        "--dropfirst",
        action="store_true",
        dest="dropfirst",
        help="Drop all tables in the target database first",
    )
    make_option(
        "--disable-asyncio",
        action="store_true",
        help="disable test / fixtures / provisioning running in asyncio",
    )
    make_option(
        "--backend-only",
        action="callback",
        zeroarg_callback=_set_tag_include("backend"),
        help=(
            "Run only tests marked with __backend__ or __sparse_backend__; "
            "this is now equivalent to the pytest -m backend mark expression"
        ),
    )
    make_option(
        "--nomemory",
        action="callback",
        zeroarg_callback=_set_tag_exclude("memory_intensive"),
        help="Don't run memory profiling tests; "
        "this is now equivalent to the pytest -m 'not memory_intensive' "
        "mark expression",
    )
    make_option(
        "--notimingintensive",
        action="callback",
        zeroarg_callback=_set_tag_exclude("timing_intensive"),
        help="Don't run timing intensive tests; "
        "this is now equivalent to the pytest -m 'not timing_intensive' "
        "mark expression",
    )
    make_option(
        "--nomypy",
        action="callback",
        zeroarg_callback=_set_tag_exclude("mypy"),
        help="Don't run mypy typing tests; "
        "this is now equivalent to the pytest -m 'not mypy' mark expression",
    )
    make_option(
        "--profile-sort",
        type=str,
        default="cumulative",
        dest="profilesort",
        help="Type of sort for profiling standard output",
    )
    make_option(
        "--profile-dump",
        type=str,
        dest="profiledump",
        help="Filename where a single profile run will be dumped",
    )
    make_option(
        "--low-connections",
        action="store_true",
        dest="low_connections",
        help="Use a low number of distinct connections - "
        "i.e. for Oracle TNS",
    )
    make_option(
        "--write-idents",
        type=str,
        dest="write_idents",
        help="write out generated follower idents to <file>, "
        "when -n<num> is used",
    )
    make_option(
        "--requirements",
        action="callback",
        type=str,
        callback=_requirements_opt,
        help="requirements class for testing, overrides setup.cfg",
    )
    make_option(
        "--include-tag",
        action="callback",
        callback=_include_tag,
        type=str,
        help="Include tests with tag <tag>; "
        "legacy, use pytest -m 'tag' instead",
    )
    make_option(
        "--exclude-tag",
        action="callback",
        callback=_exclude_tag,
        type=str,
        help="Exclude tests with tag <tag>; "
        "legacy, use pytest -m 'not tag' instead",
    )
    make_option(
        "--write-profiles",
        action="store_true",
        dest="write_profiles",
        default=False,
        help="Write/update failing profiling data.",
    )
    make_option(
        "--force-write-profiles",
        action="store_true",
        dest="force_write_profiles",
        default=False,
        help="Unconditionally write/update profiling data.",
    )
    make_option(
        "--dump-pyannotate",
        type=str,
        dest="dump_pyannotate",
        help="Run pyannotate and dump json info to given file",
    )
    make_option(
        "--mypy-extra-test-path",
        type=str,
        action="append",
        default=[],
        dest="mypy_extra_test_paths",
        help="Additional test directories to add to the mypy tests. "
        "This is used only when running mypy tests. Multiple OK",
    )
    # db specific options
    make_option(
        "--postgresql-templatedb",
        type=str,
        help="name of template database to use for PostgreSQL "
        "CREATE DATABASE (defaults to current database)",
    )
    make_option(
        "--oracledb-thick-mode",
        action="store_true",
        help="enables the 'thick mode' when testing with oracle+oracledb",
    )
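# Hedged usage note (not part of this commit): typical invocations wiring
# these options through the pytest plugin below; the URI values here are
# hypothetical.
#
#     pytest --db sqlite
#     pytest --dburi postgresql+psycopg2://scott:tiger@localhost/test
#     pytest --backend-only --nomemory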
def configure_follower(follower_ident):
    """Configure required state for a follower.

    This is invoked in the parent process and typically includes
    database creation.

    """
    from sqlalchemy.testing import provision

    provision.FOLLOWER_IDENT = follower_ident


def memoize_important_follower_config(dict_):
    """Store important configuration we will need to send to a follower.

    This is invoked in the parent process after normal config is set up.

    Hook is currently not used.

    """


def restore_important_follower_config(dict_):
    """Restore important configuration needed by a follower.

    This is invoked in the follower process.

    Hook is currently not used.

    """


def read_config(root_path):
    global file_config
    file_config = configparser.ConfigParser()
    file_config.read(
        [str(root_path / "setup.cfg"), str(root_path / "test.cfg")]
    )


def pre_begin(opt):
    """things to set up early, before coverage might be set up."""
    global options
    options = opt
    for fn in pre_configure:
        fn(options, file_config)


def set_coverage_flag(value):
    options.has_coverage = value


def post_begin():
    """things to set up later, once we know coverage is running."""
    # Lazy setup of other options (post coverage)
    for fn in post_configure:
        fn(options, file_config)

    # late imports, has to happen after config.
    global util, fixtures, engines, exclusions, assertions, provision
    global warnings, profiling, config, testing
    from sqlalchemy import testing  # noqa
    from sqlalchemy.testing import fixtures, engines, exclusions  # noqa
    from sqlalchemy.testing import assertions, warnings, profiling  # noqa
    from sqlalchemy.testing import config, provision  # noqa
    from sqlalchemy import util  # noqa

    warnings.setup_filters()


def _log(opt_str, value, parser):
    global logging
    if not logging:
        import logging

        logging.basicConfig()

    if opt_str.endswith("-info"):
        logging.getLogger(value).setLevel(logging.INFO)
    elif opt_str.endswith("-debug"):
        logging.getLogger(value).setLevel(logging.DEBUG)


def _list_dbs(*args):
    if file_config is None:
        # assume the current working directory is the one containing the
        # setup file
        read_config(Path.cwd())
    print("Available --db options (use --dburi to override)")
    for macro in sorted(file_config.options("db")):
        print("%20s\t%s" % (macro, file_config.get("db", macro)))
    sys.exit(0)


def _requirements_opt(opt_str, value, parser):
    _setup_requirements(value)


def _set_tag_include(tag):
    def _do_include_tag(opt_str, value, parser):
        _include_tag(opt_str, tag, parser)

    return _do_include_tag


def _set_tag_exclude(tag):
    def _do_exclude_tag(opt_str, value, parser):
        _exclude_tag(opt_str, tag, parser)

    return _do_exclude_tag


def _exclude_tag(opt_str, value, parser):
    exclude_tags.add(value.replace("-", "_"))


def _include_tag(opt_str, value, parser):
    include_tags.add(value.replace("-", "_"))


pre_configure = []
post_configure = []


def pre(fn):
    pre_configure.append(fn)
    return fn


def post(fn):
    post_configure.append(fn)
    return fn


@pre
def _setup_options(opt, file_config):
    global options
    options = opt


@pre
def _register_sqlite_numeric_dialect(opt, file_config):
    from sqlalchemy.dialects import registry

    registry.register(
        "sqlite.pysqlite_numeric",
        "sqlalchemy.dialects.sqlite.pysqlite",
        "_SQLiteDialect_pysqlite_numeric",
    )
    registry.register(
        "sqlite.pysqlite_dollar",
        "sqlalchemy.dialects.sqlite.pysqlite",
        "_SQLiteDialect_pysqlite_dollar",
    )


@post
def __ensure_cext(opt, file_config):
    if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":
        from sqlalchemy.util import has_compiled_ext

        try:
            has_compiled_ext(raise_=True)
        except ImportError as err:
            raise AssertionError(
                "REQUIRE_SQLALCHEMY_CEXT is set but can't import the "
                "cython extensions"
            ) from err


@post
def _init_symbols(options, file_config):
    from sqlalchemy.testing import config

    config._fixture_functions = _fixture_fn_class()


@pre
def _set_disable_asyncio(opt, file_config):
    if opt.disable_asyncio:
        asyncio.ENABLE_ASYNCIO = False


@post
def _engine_uri(options, file_config):
    from sqlalchemy import testing
    from sqlalchemy.testing import config
    from sqlalchemy.testing import provision
    from sqlalchemy.engine import url as sa_url

    if options.dburi:
        db_urls = list(options.dburi)
    else:
        db_urls = []

    extra_drivers = options.dbdriver or []

    if options.db:
        for db_token in options.db:
            for db in re.split(r"[,\s]+", db_token):
                if db not in file_config.options("db"):
                    raise RuntimeError(
                        "Unknown URI specifier '%s'. "
                        "Specify --dbs for known uris." % db
                    )
                else:
                    db_urls.append(file_config.get("db", db))

    if not db_urls:
        db_urls.append(file_config.get("db", "default"))

    config._current = None

    if options.write_idents and provision.FOLLOWER_IDENT:
        for db_url in [sa_url.make_url(db_url) for db_url in db_urls]:
            with open(options.write_idents, "a") as file_:
                file_.write(
                    f"{provision.FOLLOWER_IDENT} "
                    f"{db_url.render_as_string(hide_password=False)}\n"
                )

    expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))

    for db_url in expanded_urls:
        log.info("Adding database URL: %s", db_url)

        cfg = provision.setup_config(
            db_url, options, file_config, provision.FOLLOWER_IDENT
        )
        if not config._current:
            cfg.set_as_current(cfg, testing)


@post
def _requirements(options, file_config):
    requirement_cls = file_config.get("sqla_testing", "requirement_cls")
    _setup_requirements(requirement_cls)


def _setup_requirements(argument):
    from sqlalchemy.testing import config
    from sqlalchemy import testing

    modname, clsname = argument.split(":")

    # importlib.import_module() was only introduced in 2.7, a little
    # late
    mod = __import__(modname)
    for component in modname.split(".")[1:]:
        mod = getattr(mod, component)
    req_cls = getattr(mod, clsname)

    config.requirements = testing.requires = req_cls()

    config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy


@post
def _prep_testing_database(options, file_config):
    from sqlalchemy.testing import config

    if options.dropfirst:
        from sqlalchemy.testing import provision

        for cfg in config.Config.all_configs():
            provision.drop_all_schema_objects(cfg, cfg.db)


@post
def _post_setup_options(opt, file_config):
    from sqlalchemy.testing import config

    config.options = options
    config.file_config = file_config


@post
def _setup_profiling(options, file_config):
    from sqlalchemy.testing import profiling

    profiling._profile_stats = profiling.ProfileStatsFile(
        file_config.get("sqla_testing", "profile_file"),
        sort=options.profilesort,
        dump=options.profiledump,
    )


def want_class(name, cls):
    if not issubclass(cls, fixtures.TestBase):
        return False
    elif name.startswith("_"):
        return False
    else:
        return True


def want_method(cls, fn):
    if not fn.__name__.startswith("test_"):
        return False
    elif fn.__module__ is None:
        return False
    else:
        return True


def generate_sub_tests(cls, module, markers):
    if "backend" in markers or "sparse_backend" in markers:
        sparse = "sparse_backend" in markers
        for cfg in _possible_configs_for_cls(cls, sparse=sparse):
            orig_name = cls.__name__

            # we can have special chars in these names except for the
            # pytest junit plugin, which is tripped up by the brackets
            # and periods, so sanitize

            alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
            alpha_name = re.sub(r"_+$", "", alpha_name)
            name = "%s_%s" % (cls.__name__, alpha_name)
            subcls = type(
                name,
                (cls,),
                {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
            )
            setattr(module, name, subcls)
            yield subcls
    else:
        yield cls
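# Hedged illustration (not part of this commit): for a class marked
# "backend", generate_sub_tests() emits one subclass per database
# configuration, with the config name sanitized for the junit plugin; e.g.
# a hypothetical config named "postgresql+psycopg2_[default]" applied to a
# class "SelectTest" would yield "SelectTest_postgresql_psycopg2_default",
# pinned to that config via __only_on_config__.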
def start_test_class_outside_fixtures(cls):
    _do_skips(cls)
    _setup_engine(cls)


def stop_test_class(cls):
    # close sessions, immediate connections, etc.
    fixtures.stop_test_class_inside_fixtures(cls)

    # close outstanding connection pool connections, dispose of
    # additional engines
    engines.testing_reaper.stop_test_class_inside_fixtures()


def stop_test_class_outside_fixtures(cls):
    engines.testing_reaper.stop_test_class_outside_fixtures()
    provision.stop_test_class_outside_fixtures(config, config.db, cls)
    try:
        if not options.low_connections:
            assertions.global_cleanup_assertions()
    finally:
        _restore_engine()


def _restore_engine():
    if config._current:
        config._current.reset(testing)


def final_process_cleanup():
    engines.testing_reaper.final_cleanup()
    assertions.global_cleanup_assertions()
    _restore_engine()


def _setup_engine(cls):
    if getattr(cls, "__engine_options__", None):
        opts = dict(cls.__engine_options__)
        opts["scope"] = "class"
        eng = engines.testing_engine(options=opts)
        config._current.push_engine(eng, testing)


def before_test(test, test_module_name, test_class, test_name):
    # format looks like:
    # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"

    name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)

    id_ = "%s.%s.%s" % (test_module_name, name, test_name)

    profiling._start_current_test(id_)


def after_test(test):
    fixtures.after_test()
    engines.testing_reaper.after_test()


def after_test_fixtures(test):
    engines.testing_reaper.after_test_outside_fixtures(test)


def _possible_configs_for_cls(cls, reasons=None, sparse=False):
    all_configs = set(config.Config.all_configs())

    if cls.__unsupported_on__:
        spec = exclusions.db_spec(*cls.__unsupported_on__)
        for config_obj in list(all_configs):
            if spec(config_obj):
                all_configs.remove(config_obj)

    if getattr(cls, "__only_on__", None):
        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
        for config_obj in list(all_configs):
            if not spec(config_obj):
                all_configs.remove(config_obj)

    if getattr(cls, "__only_on_config__", None):
        all_configs.intersection_update([cls.__only_on_config__])

    if hasattr(cls, "__requires__"):
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__requires__:
                check = getattr(requirements, requirement)

                skip_reasons = check.matching_config_reasons(config_obj)
                if skip_reasons:
                    all_configs.remove(config_obj)
                    if reasons is not None:
                        reasons.extend(skip_reasons)
                    break

    if hasattr(cls, "__prefer_requires__"):
        non_preferred = set()
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__prefer_requires__:
                check = getattr(requirements, requirement)

                if not check.enabled_for_config(config_obj):
                    non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if sparse:
        # pick only one config from each base dialect
        # sorted so we get the same backend each time selecting the highest
        # server version info.
        per_dialect = {}
        for cfg in reversed(
            sorted(
                all_configs,
                key=lambda cfg: (
                    cfg.db.name,
                    cfg.db.driver,
                    cfg.db.dialect.server_version_info,
                ),
            )
        ):
            db = cfg.db.name
            if db not in per_dialect:
                per_dialect[db] = cfg
        return per_dialect.values()

    return all_configs


def _do_skips(cls):
    reasons = []
    all_configs = _possible_configs_for_cls(cls, reasons)

    if getattr(cls, "__skip_if__", False):
        for c in getattr(cls, "__skip_if__"):
            if c():
                config.skip_test(
                    "'%s' skipped by %s" % (cls.__name__, c.__name__)
                )

    if not all_configs:
        msg = "'%s.%s' unsupported on any DB implementation %s%s" % (
            cls.__module__,
            cls.__name__,
            ", ".join(
                "'%s(%s)+%s'"
                % (
                    config_obj.db.name,
                    ".".join(
                        str(dig)
                        for dig in exclusions._server_version(config_obj.db)
                    ),
                    config_obj.db.driver,
                )
                for config_obj in config.Config.all_configs()
            ),
            ", ".join(reasons),
        )
        config.skip_test(msg)
    elif hasattr(cls, "__prefer_backends__"):
        non_preferred = set()
        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
        for config_obj in all_configs:
            if not spec(config_obj):
                non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if config._current not in all_configs:
        _setup_config(all_configs.pop(), cls)


def _setup_config(config_obj, ctx):
    config._current.push(config_obj, testing)


class FixtureFunctions(abc.ABC):
    @abc.abstractmethod
    def skip_test_exception(self, *arg, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def combinations(self, *args, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def param_ident(self, *args, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def fixture(self, *arg, **kw):
        raise NotImplementedError()

    def get_current_test_name(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def mark_base_test_class(self) -> Any:
        raise NotImplementedError()

    @abc.abstractproperty
    def add_to_marker(self):
        raise NotImplementedError()


_fixture_fn_class = None


def set_fixture_functions(fixture_fn_class):
    global _fixture_fn_class
    _fixture_fn_class = fixture_fn_class
@ -0,0 +1,868 @@
# testing/plugin/pytestplugin.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import argparse
import collections
from functools import update_wrapper
import inspect
import itertools
import operator
import os
import re
import sys
from typing import TYPE_CHECKING
import uuid

import pytest

try:
    # installed by bootstrap.py
    if not TYPE_CHECKING:
        import sqla_plugin_base as plugin_base
except ImportError:
    # assume we're a package, use traditional import
    from . import plugin_base


def pytest_addoption(parser):
    group = parser.getgroup("sqlalchemy")

    def make_option(name, **kw):
        callback_ = kw.pop("callback", None)
        if callback_:

            class CallableAction(argparse.Action):
                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    callback_(option_string, values, parser)

            kw["action"] = CallableAction

        zeroarg_callback = kw.pop("zeroarg_callback", None)
        if zeroarg_callback:

            class CallableAction(argparse.Action):
                def __init__(
                    self,
                    option_strings,
                    dest,
                    default=False,
                    required=False,
                    help=None,  # noqa
                ):
                    super().__init__(
                        option_strings=option_strings,
                        dest=dest,
                        nargs=0,
                        const=True,
                        default=default,
                        required=required,
                        help=help,
                    )

                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    zeroarg_callback(option_string, values, parser)

            kw["action"] = CallableAction

        group.addoption(name, **kw)

    plugin_base.setup_options(make_option)


def pytest_configure(config: pytest.Config):
    plugin_base.read_config(config.rootpath)
    if plugin_base.exclude_tags or plugin_base.include_tags:
        new_expr = " and ".join(
            list(plugin_base.include_tags)
            + [f"not {tag}" for tag in plugin_base.exclude_tags]
        )

        if config.option.markexpr:
            config.option.markexpr += f" and {new_expr}"
        else:
            config.option.markexpr = new_expr

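    # An illustration (hedged, with hypothetical tag names): with
    # include_tags={"memory_intensive"} and exclude_tags={"timing_intensive"},
    # new_expr becomes "memory_intensive and not timing_intensive", which is
    # ANDed onto any marker expression already supplied via pytest -m.
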
    if config.pluginmanager.hasplugin("xdist"):
        config.pluginmanager.register(XDistHooks())

    if hasattr(config, "workerinput"):
        plugin_base.restore_important_follower_config(config.workerinput)
        plugin_base.configure_follower(config.workerinput["follower_ident"])
    else:
        if config.option.write_idents and os.path.exists(
            config.option.write_idents
        ):
            os.remove(config.option.write_idents)

    plugin_base.pre_begin(config.option)

    plugin_base.set_coverage_flag(
        bool(getattr(config.option, "cov_source", False))
    )

    plugin_base.set_fixture_functions(PytestFixtureFunctions)

    if config.option.dump_pyannotate:
        global DUMP_PYANNOTATE
        DUMP_PYANNOTATE = True


DUMP_PYANNOTATE = False


@pytest.fixture(autouse=True)
def collect_types_fixture():
    if DUMP_PYANNOTATE:
        from pyannotate_runtime import collect_types

        collect_types.start()
    yield
    if DUMP_PYANNOTATE:
        collect_types.stop()


def _log_sqlalchemy_info(session):
    import sqlalchemy
    from sqlalchemy import __version__
    from sqlalchemy.util import has_compiled_ext
    from sqlalchemy.util._has_cy import _CYEXTENSION_MSG

    greet = "sqlalchemy installation"
    site = "no user site" if sys.flags.no_user_site else "user site loaded"
    msgs = [
        f"SQLAlchemy {__version__} ({site})",
        f"Path: {sqlalchemy.__file__}",
    ]

    if has_compiled_ext():
        from sqlalchemy.cyextension import util

        msgs.append(f"compiled extension enabled, e.g. {util.__file__} ")
    else:
        msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}")

    pm = session.config.pluginmanager.get_plugin("terminalreporter")
    if pm:
        pm.write_sep("=", greet)
        for m in msgs:
            pm.write_line(m)
    else:
        # fancy pants reporter not found, fall back to plain print
        print("=" * 25, greet, "=" * 25)
        for m in msgs:
            print(m)


def pytest_sessionstart(session):
    from sqlalchemy.testing import asyncio

    _log_sqlalchemy_info(session)
    asyncio._assume_async(plugin_base.post_begin)


def pytest_sessionfinish(session):
    from sqlalchemy.testing import asyncio

    asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)

    if session.config.option.dump_pyannotate:
        from pyannotate_runtime import collect_types

        collect_types.dump_stats(session.config.option.dump_pyannotate)


def pytest_unconfigure(config):
    from sqlalchemy.testing import asyncio

    asyncio._shutdown()


def pytest_collection_finish(session):
    if session.config.option.dump_pyannotate:
        from pyannotate_runtime import collect_types

        lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")

        def _filter(filename):
            filename = os.path.normpath(os.path.abspath(filename))
            if "lib/sqlalchemy" not in os.path.commonpath(
                [filename, lib_sqlalchemy]
            ):
                return None
            if "testing" in filename:
                return None

            return filename

        collect_types.init_types_collection(filter_filename=_filter)


class XDistHooks:
    def pytest_configure_node(self, node):
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio

        # the master for each node fills the workerinput dictionary
        # which pytest-xdist will transfer to the subprocess

        plugin_base.memoize_important_follower_config(node.workerinput)

        node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]

        asyncio._maybe_async_provisioning(
            provision.create_follower_db, node.workerinput["follower_ident"]
        )

    def pytest_testnodedown(self, node, error):
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio

        asyncio._maybe_async_provisioning(
            provision.drop_follower_db, node.workerinput["follower_ident"]
        )


def pytest_collection_modifyitems(session, config, items):
    # look for all those classes that specify __backend__ and
    # expand them out into per-database test cases.

    # this is much easier to do within pytest_pycollect_makeitem, however
    # pytest is iterating through cls.__dict__ as makeitem is
    # called which causes a "dictionary changed size" error on py3k.
    # I'd submit a pullreq for them to turn it into a list first, but
    # it's to suit the rather odd use case here which is that we are adding
    # new classes to a module on the fly.

    from sqlalchemy.testing import asyncio

    rebuilt_items = collections.defaultdict(
        lambda: collections.defaultdict(list)
    )

    items[:] = [
        item
        for item in items
        if item.getparent(pytest.Class) is not None
        and not item.getparent(pytest.Class).name.startswith("_")
    ]

    test_classes = {item.getparent(pytest.Class) for item in items}

    def collect(element):
        for inst_or_fn in element.collect():
            if isinstance(inst_or_fn, pytest.Collector):
                yield from collect(inst_or_fn)
            else:
                yield inst_or_fn

    def setup_test_classes():
        for test_class in test_classes:
            # transfer legacy __backend__ and __sparse_backend__ symbols
            # to be markers
            add_markers = set()
            if getattr(test_class.cls, "__backend__", False) or getattr(
                test_class.cls, "__only_on__", False
            ):
                add_markers = {"backend"}
            elif getattr(test_class.cls, "__sparse_backend__", False):
                add_markers = {"sparse_backend"}
            else:
                add_markers = frozenset()

            existing_markers = {
                mark.name for mark in test_class.iter_markers()
            }
            add_markers = add_markers - existing_markers
            all_markers = existing_markers.union(add_markers)

            for marker in add_markers:
                test_class.add_marker(marker)

            for sub_cls in plugin_base.generate_sub_tests(
                test_class.cls, test_class.module, all_markers
            ):
                if sub_cls is not test_class.cls:
                    per_cls_dict = rebuilt_items[test_class.cls]

                    module = test_class.getparent(pytest.Module)

                    new_cls = pytest.Class.from_parent(
                        name=sub_cls.__name__, parent=module
                    )
                    for marker in add_markers:
                        new_cls.add_marker(marker)

                    for fn in collect(new_cls):
                        per_cls_dict[fn.name].append(fn)

    # class requirements will sometimes need to access the DB to check
    # capabilities, so need to do this for async
    asyncio._maybe_async_provisioning(setup_test_classes)

    newitems = []
    for item in items:
        cls_ = item.cls
        if cls_ in rebuilt_items:
            newitems.extend(rebuilt_items[cls_][item.name])
        else:
            newitems.append(item)

    # seems like the functions attached to a test class aren't sorted already?
    # is that true and why's that? (when using unittest, they're sorted)
    items[:] = sorted(
        newitems,
        key=lambda item: (
            item.getparent(pytest.Module).name,
            item.getparent(pytest.Class).name,
            item.name,
        ),
    )


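# For illustration only (hedged, hypothetical names): with two databases
# configured, e.g. --db sqlite --db postgresql, a class declaring
#
#     class SomeTest(fixtures.TestBase):
#         __backend__ = True
#
# is expanded by generate_sub_tests() into per-database subclasses with
# names along the lines of SomeTest_sqlite and SomeTest_postgresql, and the
# hook above swaps the originally collected items for the expanded ones.

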
def pytest_pycollect_makeitem(collector, name, obj):
    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
        from sqlalchemy.testing import config

        if config.any_async:
            obj = _apply_maybe_async(obj)

        return [
            pytest.Class.from_parent(
                name=parametrize_cls.__name__, parent=collector
            )
            for parametrize_cls in _parametrize_cls(collector.module, obj)
        ]
    elif (
        inspect.isfunction(obj)
        and collector.cls is not None
        and plugin_base.want_method(collector.cls, obj)
    ):
        # None means, fall back to default logic, which includes
        # method-level parametrize
        return None
    else:
        # empty list means skip this item
        return []


def _is_wrapped_coroutine_function(fn):
    while hasattr(fn, "__wrapped__"):
        fn = fn.__wrapped__

    return inspect.iscoroutinefunction(fn)


def _apply_maybe_async(obj, recurse=True):
    from sqlalchemy.testing import asyncio

    for name, value in vars(obj).items():
        if (
            (callable(value) or isinstance(value, classmethod))
            and not getattr(value, "_maybe_async_applied", False)
            and (name.startswith("test_"))
            and not _is_wrapped_coroutine_function(value)
        ):
            is_classmethod = False
            if isinstance(value, classmethod):
                value = value.__func__
                is_classmethod = True

            @_pytest_fn_decorator
            def make_async(fn, *args, **kwargs):
                return asyncio._maybe_async(fn, *args, **kwargs)

            do_async = make_async(value)
            if is_classmethod:
                do_async = classmethod(do_async)
            do_async._maybe_async_applied = True

            setattr(obj, name, do_async)
    if recurse:
        for cls in obj.mro()[1:]:
            if cls != object:
                _apply_maybe_async(cls, False)
    return obj


def _parametrize_cls(module, cls):
    """implement a class-based version of pytest parametrize."""

    if "_sa_parametrize" not in cls.__dict__:
        return [cls]

    _sa_parametrize = cls._sa_parametrize
    classes = []
    for full_param_set in itertools.product(
        *[params for argname, params in _sa_parametrize]
    ):
        cls_variables = {}

        for argname, param in zip(
            [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
        ):
            if not argname:
                raise TypeError("need argnames for class-based combinations")
            argname_split = re.split(r",\s*", argname)
            for arg, val in zip(argname_split, param.values):
                cls_variables[arg] = val
        parametrized_name = "_".join(
            re.sub(r"\W", "", token)
            for param in full_param_set
            for token in param.id.split("-")
        )
        name = "%s_%s" % (cls.__name__, parametrized_name)
        newcls = type.__new__(type, name, (cls,), cls_variables)
        setattr(module, name, newcls)
        classes.append(newcls)
    return classes


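# A hedged usage sketch (hypothetical test class): applied at class level,
# combinations() stores its parameter sets in _sa_parametrize, e.g.
#
#     @testing.combinations(
#         ("one", 1), ("two", 2), argnames="key,val", id_="sa"
#     )
#     class SomeTest(fixtures.TestBase):
#         ...
#
# and _parametrize_cls() above then generates module-level subclasses
# roughly named SomeTest_one and SomeTest_two, with "key" and "val" set as
# class attributes for each parameter set.

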
_current_class = None


def pytest_runtest_setup(item):
    from sqlalchemy.testing import asyncio

    # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
    # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
    # for the whole class and has to run things that are across all current
    # databases, so we run this outside of the pytest fixture system
    # altogether and ensure the asyncio greenlet if any engines are async

    global _current_class

    if isinstance(item, pytest.Function) and _current_class is None:
        asyncio._maybe_async_provisioning(
            plugin_base.start_test_class_outside_fixtures,
            item.cls,
        )
        _current_class = item.getparent(pytest.Class)


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item, nextitem):
    # runs inside of pytest function fixture scope
    # after test function runs

    from sqlalchemy.testing import asyncio

    asyncio._maybe_async(plugin_base.after_test, item)

    yield
    # this is now after all the fixture teardowns have run, the class can be
    # finalized.  Since pytest v7 this finalizer can no longer be added in
    # pytest_runtest_setup since the class has not yet been set up at that
    # time.
    # See https://github.com/pytest-dev/pytest/issues/9343
    global _current_class, _current_report

    if _current_class is not None and (
        # last test or a new class
        nextitem is None
        or nextitem.getparent(pytest.Class) is not _current_class
    ):
        _current_class = None

        try:
            asyncio._maybe_async_provisioning(
                plugin_base.stop_test_class_outside_fixtures, item.cls
            )
        except Exception as e:
            # in case of an exception during teardown attach the original
            # error to the exception message, otherwise it will get lost
            if _current_report.failed:
                if not e.args:
                    e.args = (
                        "__Original test failure__:\n"
                        + _current_report.longreprtext,
                    )
                elif e.args[-1] and isinstance(e.args[-1], str):
                    args = list(e.args)
                    args[-1] += (
                        "\n__Original test failure__:\n"
                        + _current_report.longreprtext
                    )
                    e.args = tuple(args)
                else:
                    e.args += (
                        "__Original test failure__",
                        _current_report.longreprtext,
                    )
            raise
        finally:
            _current_report = None


def pytest_runtest_call(item):
    # runs inside of pytest function fixture scope
    # before test function runs

    from sqlalchemy.testing import asyncio

    asyncio._maybe_async(
        plugin_base.before_test,
        item,
        item.module.__name__,
        item.cls,
        item.name,
    )


_current_report = None


def pytest_runtest_logreport(report):
    global _current_report
    if report.when == "call":
        _current_report = report


@pytest.fixture(scope="class")
def setup_class_methods(request):
    from sqlalchemy.testing import asyncio

    cls = request.cls

    if hasattr(cls, "setup_test_class"):
        asyncio._maybe_async(cls.setup_test_class)

    yield

    if hasattr(cls, "teardown_test_class"):
        asyncio._maybe_async(cls.teardown_test_class)

    asyncio._maybe_async(plugin_base.stop_test_class, cls)


@pytest.fixture(scope="function")
def setup_test_methods(request):
    from sqlalchemy.testing import asyncio

    # called for each test

    self = request.instance

    # before this fixture runs:

    # 1. function level "autouse" fixtures under py3k (examples: TablesTest
    #    define tables / data, MappedTest define tables / mappers / data)

    # 2. was for py2k. no longer applies

    # 3. run outer xdist-style setup
    if hasattr(self, "setup_test"):
        asyncio._maybe_async(self.setup_test)

    # alembic test suite is using setUp and tearDown
    # xdist methods; support these in the test suite
    # for the near term
    if hasattr(self, "setUp"):
        asyncio._maybe_async(self.setUp)

    # inside the yield:
    # 4. function level fixtures defined on test functions themselves,
    #    e.g. "connection", "metadata" run next

    # 5. pytest hook pytest_runtest_call then runs

    # 6. test itself runs

    yield

    # yield finishes:

    # 7. function level fixtures defined on test functions
    #    themselves, e.g. "connection" rolls back the transaction, "metadata"
    #    emits drop all

    # 8. pytest hook pytest_runtest_teardown hook runs, this is associated
    #    with fixtures close all sessions, provisioning.stop_test_class(),
    #    engines.testing_reaper -> ensure all connection pool connections
    #    are returned, engines created by testing_engine that aren't the
    #    config engine are disposed

    asyncio._maybe_async(plugin_base.after_test_fixtures, self)

    # 10. run xdist-style teardown
    if hasattr(self, "tearDown"):
        asyncio._maybe_async(self.tearDown)

    if hasattr(self, "teardown_test"):
        asyncio._maybe_async(self.teardown_test)

    # 11. was for py2k. no longer applies

    # 12. function level "autouse" fixtures under py3k (examples: TablesTest /
    #     MappedTest delete table data, possibly drop tables and clear mappers
    #     depending on the flags defined by the test class)


def _pytest_fn_decorator(target):
    """Port of langhelpers.decorator with pytest-specific tricks."""

    from sqlalchemy.util.langhelpers import format_argspec_plus
    from sqlalchemy.util.compat import inspect_getfullargspec

    def _exec_code_in_env(code, env, fn_name):
        # note this is affected by "from __future__ import annotations" at
        # the top; exec'ed code will use non-evaluated annotations
        # which allows us to be more flexible with code rendering
        # in format_argspec_plus()
        exec(code, env)
        return env[fn_name]

    def decorate(fn, add_positional_parameters=()):
        spec = inspect_getfullargspec(fn)
        if add_positional_parameters:
            spec.args.extend(add_positional_parameters)

        metadata = dict(
            __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
        )
        metadata.update(format_argspec_plus(spec, grouped=False))
        code = (
            """\
def %(name)s%(grouped_args)s:
    return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
"""
            % metadata
        )
        decorated = _exec_code_in_env(
            code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
        )
        if not add_positional_parameters:
            decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
            decorated.__wrapped__ = fn
            return update_wrapper(decorated, fn)
        else:
            # this is the pytest hacky part. don't do a full update wrapper
            # because pytest is really being sneaky about finding the args
            # for the wrapped function
            decorated.__module__ = fn.__module__
            decorated.__name__ = fn.__name__
            if hasattr(fn, "pytestmark"):
                decorated.pytestmark = fn.pytestmark
            return decorated

    return decorate


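# To illustrate (hedged; the exact argument rendering comes from
# format_argspec_plus()): for a method "def test_foo(self, connection):",
# decorate() above execs generated source roughly of the form
#
#     def test_foo(self, connection):
#         return __target_fn(__orig_fn, self, connection)
#
# so pytest's introspection still sees the original signature while the
# target wrapper receives the original function plus its arguments.

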
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
    def skip_test_exception(self, *arg, **kw):
        return pytest.skip.Exception(*arg, **kw)

    @property
    def add_to_marker(self):
        return pytest.mark

    def mark_base_test_class(self):
        return pytest.mark.usefixtures(
            "setup_class_methods", "setup_test_methods"
        )

    _combination_id_fns = {
        "i": lambda obj: obj,
        "r": repr,
        "s": str,
        "n": lambda obj: (
            obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
        ),
    }

    def combinations(self, *arg_sets, **kw):
        """Facade for pytest.mark.parametrize.

        Automatically derives argument names from the callable which in our
        case is always a method on a class with positional arguments.

        ids for parameter sets are derived using an optional template.

        """
        from sqlalchemy.testing import exclusions

        if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
            arg_sets = list(arg_sets[0])

        argnames = kw.pop("argnames", None)

        def _filter_exclusions(args):
            result = []
            gathered_exclusions = []
            for a in args:
                if isinstance(a, exclusions.compound):
                    gathered_exclusions.append(a)
                else:
                    result.append(a)

            return result, gathered_exclusions

        id_ = kw.pop("id_", None)

        tobuild_pytest_params = []
        has_exclusions = False
        if id_:
            _combination_id_fns = self._combination_id_fns

            # because itemgetter is not consistent for one argument vs.
            # multiple, make it multiple in all cases and use a slice
            # to omit the first argument
            _arg_getter = operator.itemgetter(
                0,
                *[
                    idx
                    for idx, char in enumerate(id_)
                    if char in ("n", "r", "s", "a")
                ],
            )
            fns = [
                (operator.itemgetter(idx), _combination_id_fns[char])
                for idx, char in enumerate(id_)
                if char in _combination_id_fns
            ]

            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                parameters = _arg_getter(fn_params)[1:]

                if param_exclusions:
                    has_exclusions = True

                tobuild_pytest_params.append(
                    (
                        parameters,
                        param_exclusions,
                        "-".join(
                            comb_fn(getter(arg)) for getter, comb_fn in fns
                        ),
                    )
                )

        else:
            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                if param_exclusions:
                    has_exclusions = True

                tobuild_pytest_params.append(
                    (fn_params, param_exclusions, None)
                )

        pytest_params = []
        for parameters, param_exclusions, id_ in tobuild_pytest_params:
            if has_exclusions:
                parameters += (param_exclusions,)

            param = pytest.param(*parameters, id=id_)
            pytest_params.append(param)

        def decorate(fn):
            if inspect.isclass(fn):
                if has_exclusions:
                    raise NotImplementedError(
                        "exclusions not supported for class level combinations"
                    )
                if "_sa_parametrize" not in fn.__dict__:
                    fn._sa_parametrize = []
                fn._sa_parametrize.append((argnames, pytest_params))
                return fn
            else:
                _fn_argnames = inspect.getfullargspec(fn).args[1:]
                if argnames is None:
                    _argnames = _fn_argnames
                else:
                    _argnames = re.split(r", *", argnames)

                if has_exclusions:
                    existing_exl = sum(
                        1 for n in _fn_argnames if n.startswith("_exclusions")
                    )
                    current_exclusion_name = f"_exclusions_{existing_exl}"
                    _argnames += [current_exclusion_name]

                    @_pytest_fn_decorator
                    def check_exclusions(fn, *args, **kw):
                        _exclusions = args[-1]
                        if _exclusions:
                            exlu = exclusions.compound().add(*_exclusions)
                            fn = exlu(fn)
                        return fn(*args[:-1], **kw)

                    fn = check_exclusions(
                        fn, add_positional_parameters=(current_exclusion_name,)
                    )

                return pytest.mark.parametrize(_argnames, pytest_params)(fn)

        return decorate

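    # A hedged usage sketch (hypothetical test method): the id_ template maps
    # one character per argument, e.g.
    #
    #     @testing.combinations((operator.eq, 1), (operator.ne, 2), id_="ns")
    #     def test_op(self, op, value):
    #         ...
    #
    # renders ids "eq-1" and "ne-2": "n" uses __name__, "s" applies str(),
    # "r" uses repr(), and "i" passes the value through unchanged (it should
    # already be a string).
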
    def param_ident(self, *parameters):
        ident = parameters[0]
        return pytest.param(*parameters[1:], id=ident)

    def fixture(self, *arg, **kw):
        from sqlalchemy.testing import config
        from sqlalchemy.testing import asyncio

        # wrapping pytest.fixture function.  determine if
        # decorator was called as @fixture or @fixture().
        if len(arg) > 0 and callable(arg[0]):
            # was called as @fixture, we have the function to wrap.
            fn = arg[0]
            arg = arg[1:]
        else:
            # was called as @fixture(), don't have the function yet.
            fn = None

        # create a pytest.fixture marker.  because the fn is not being
        # passed, this is always a pytest.FixtureFunctionMarker()
        # object (or whatever pytest is calling it when you read this)
        # that is waiting for a function.
        fixture = pytest.fixture(*arg, **kw)

        # now apply wrappers to the function, including fixture itself

        def wrap(fn):
            if config.any_async:
                fn = asyncio._maybe_async_wrapper(fn)
            # other wrappers may be added here

            # now apply FixtureFunctionMarker
            fn = fixture(fn)

            return fn

        if fn:
            return wrap(fn)
        else:
            return wrap

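    # Both call styles funnel through wrap() above; a hedged sketch of the
    # two spellings this supports:
    #
    #     @testing.fixture
    #     def connection(self): ...
    #
    #     @testing.fixture(scope="class")
    #     def metadata(self): ...
    #
    # In the first case arg[0] is the function itself; in the second,
    # pytest.fixture(*arg, **kw) returns a marker that is applied later.
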
    def get_current_test_name(self):
        return os.environ.get("PYTEST_CURRENT_TEST")

    def async_test(self, fn):
        from sqlalchemy.testing import asyncio

        @_pytest_fn_decorator
        def decorate(fn, *args, **kwargs):
            asyncio._run_coroutine_function(fn, *args, **kwargs)

        return decorate(fn)
@ -0,0 +1,324 @@
# testing/profiling.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


"""Profiling support for unit and performance tests.

These are special purpose profiling methods which operate
in a more fine-grained way than nose's profiling plugin.

"""

from __future__ import annotations

import collections
import contextlib
import os
import platform
import pstats
import re
import sys

from . import config
from .util import gc_collect
from ..util import has_compiled_ext


try:
    import cProfile
except ImportError:
    cProfile = None

_profile_stats = None
"""global ProfileStatsFile instance.

plugin_base assigns this at the start of all tests.

"""


_current_test = None
"""String id of current test.

plugin_base assigns this at the start of each test using
_start_current_test.

"""


def _start_current_test(id_):
    global _current_test
    _current_test = id_

    if _profile_stats.force_write:
        _profile_stats.reset_count()


class ProfileStatsFile:
    """Store per-platform/fn profiling results in a file.

    There was no json module available when this was written; in any case,
    the file format, being deterministically line oriented, is handy
    for diffs and merges.

    """

    def __init__(self, filename, sort="cumulative", dump=None):
        self.force_write = (
            config.options is not None and config.options.force_write_profiles
        )
        self.write = self.force_write or (
            config.options is not None and config.options.write_profiles
        )
        self.fname = os.path.abspath(filename)
        self.short_fname = os.path.split(self.fname)[-1]
        self.data = collections.defaultdict(
            lambda: collections.defaultdict(dict)
        )
        self.dump = dump
        self.sort = sort
        self._read()
        if self.write:
            # rewrite for the case where features changed,
            # etc.
            self._write()

    @property
    def platform_key(self):
        dbapi_key = config.db.name + "_" + config.db.driver

        if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
            config.db.url
        ):
            dbapi_key += "_file"

        # keep it at 2.7, 3.1, 3.2, etc. for now.
        py_version = ".".join([str(v) for v in sys.version_info[0:2]])

        platform_tokens = [
            platform.machine(),
            platform.system().lower(),
            platform.python_implementation().lower(),
            py_version,
            dbapi_key,
        ]

        platform_tokens.append("dbapiunicode")
        _has_cext = has_compiled_ext()
        platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
        return "_".join(platform_tokens)

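    # For illustration (hedged; the actual tokens depend on the environment):
    # on 64-bit Linux under CPython 3.11 against sqlite+pysqlite with the
    # compiled extensions present, platform_key evaluates to something like
    # "x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions".
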
    def has_stats(self):
        test_key = _current_test
        return (
            test_key in self.data and self.platform_key in self.data[test_key]
        )

    def result(self, callcount):
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]

        if "counts" not in per_platform:
            per_platform["counts"] = counts = []
        else:
            counts = per_platform["counts"]

        if "current_count" not in per_platform:
            per_platform["current_count"] = current_count = 0
        else:
            current_count = per_platform["current_count"]

        has_count = len(counts) > current_count

        if not has_count:
            counts.append(callcount)
            if self.write:
                self._write()
            result = None
        else:
            result = per_platform["lineno"], counts[current_count]
        per_platform["current_count"] += 1
        return result

    def reset_count(self):
        test_key = _current_test
        # since self.data is a defaultdict, don't access a key
        # if we don't know it's there first.
        if test_key not in self.data:
            return
        per_fn = self.data[test_key]
        if self.platform_key not in per_fn:
            return
        per_platform = per_fn[self.platform_key]
        if "counts" in per_platform:
            per_platform["counts"][:] = []

    def replace(self, callcount):
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]
        counts = per_platform["counts"]
        current_count = per_platform["current_count"]
        if current_count < len(counts):
            counts[current_count - 1] = callcount
        else:
            counts[-1] = callcount
        if self.write:
            self._write()

    def _header(self):
        return (
            "# %s\n"
            "# This file is written out on a per-environment basis.\n"
            "# For each test in aaa_profiling, the corresponding "
            "function and \n"
            "# environment is located within this file. "
            "If it doesn't exist,\n"
            "# the test is skipped.\n"
            "# If a callcount does exist, it is compared "
            "to what we received. \n"
            "# assertions are raised if the counts do not match.\n"
            "# \n"
            "# To add a new callcount test, apply the function_call_count \n"
            "# decorator and re-run the tests using the --write-profiles \n"
            "# option - this file will be rewritten including the new count.\n"
            "# \n"
        ) % (self.fname)

    def _read(self):
        try:
            profile_f = open(self.fname)
        except OSError:
            return
        for lineno, line in enumerate(profile_f):
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            test_key, platform_key, counts = line.split()
            per_fn = self.data[test_key]
            per_platform = per_fn[platform_key]
            c = [int(count) for count in counts.split(",")]
            per_platform["counts"] = c
            per_platform["lineno"] = lineno + 1
            per_platform["current_count"] = 0
        profile_f.close()

    def _write(self):
        print("Writing profile file %s" % self.fname)
        profile_f = open(self.fname, "w")
        profile_f.write(self._header())
        for test_key in sorted(self.data):
            per_fn = self.data[test_key]
            profile_f.write("\n# TEST: %s\n\n" % test_key)
            for platform_key in sorted(per_fn):
                per_platform = per_fn[platform_key]
                c = ",".join(str(count) for count in per_platform["counts"])
                profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
        profile_f.close()


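# A hedged illustration of the line-oriented stats file parsed by _read()
# above: one record per test/platform pair, as
# "<test key> <platform key> <comma-separated callcounts>", e.g.
#
#     test.aaa_profiling.test_orm.SomeTest.test_thing x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 1005,1005
#
# (the test and platform keys here are hypothetical).

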
def function_call_count(variance=0.05, times=1, warmup=0):
    """Assert a target for a test case's function call count.

    The main purpose of this assertion is to detect changes in
    callcounts for various functions - the actual number is not as important.
    Callcounts are stored in a file keyed to Python version and OS platform
    information.  This file is generated automatically for new tests,
    and versioned so that unexpected changes in callcounts will be detected.

    """

    # use signature-rewriting decorator function so that pytest fixtures
    # still work on py27.  In Py3, update_wrapper() alone is good enough,
    # likely due to the introduction of __signature__.

    from sqlalchemy.util import decorator

    @decorator
    def wrap(fn, *args, **kw):
        for warm in range(warmup):
            fn(*args, **kw)

        timerange = range(times)
        with count_functions(variance=variance):
            for time in timerange:
                rv = fn(*args, **kw)
            return rv

    return wrap


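# A hedged usage sketch (hypothetical test): the decorator runs the test
# body under count_functions(), optionally with warmup passes that are not
# counted:
#
#     @profiling.function_call_count(variance=0.10, warmup=1)
#     def test_some_operation(self):
#         ...
#
# The recorded callcount for the current platform_key is then asserted
# within the given variance.

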
@contextlib.contextmanager
def count_functions(variance=0.05):
    if cProfile is None:
        raise config._skip_test_exception("cProfile is not installed")

    if not _profile_stats.has_stats() and not _profile_stats.write:
        config.skip_test(
            "No profiling stats available on this "
            "platform for this function.  Run tests with "
            "--write-profiles to add statistics to %s for "
            "this platform." % _profile_stats.short_fname
        )

    gc_collect()

    pr = cProfile.Profile()
    pr.enable()
    # began = time.time()
    yield
    # ended = time.time()
    pr.disable()

    # s = StringIO()
    stats = pstats.Stats(pr, stream=sys.stdout)

    # timespent = ended - began
    callcount = stats.total_calls

    expected = _profile_stats.result(callcount)

    if expected is None:
        expected_count = None
    else:
        line_no, expected_count = expected

    print("Pstats calls: %d Expected %s" % (callcount, expected_count))
    stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort))
    stats.print_stats()
    if _profile_stats.dump:
        base, ext = os.path.splitext(_profile_stats.dump)
        test_name = _current_test.split(".")[-1]
        dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile")
        stats.dump_stats(dumpfile)
        print("Dumped stats to file %s" % dumpfile)
    # stats.print_callers()
    if _profile_stats.force_write:
        _profile_stats.replace(callcount)
    elif expected_count:
        deviance = int(callcount * variance)
        failed = abs(callcount - expected_count) > deviance

        if failed:
            if _profile_stats.write:
                _profile_stats.replace(callcount)
            else:
                raise AssertionError(
                    "Adjusted function call count %s not within %s%% "
                    "of expected %s, platform %s. Rerun with "
                    "--write-profiles to "
                    "regenerate this callcount."
                    % (
                        callcount,
                        (variance * 100),
                        expected_count,
                        _profile_stats.platform_key,
                    )
                )
@ -0,0 +1,502 @@
# testing/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import collections
import logging

from . import config
from . import engines
from . import util
from .. import exc
from .. import inspect
from ..engine import url as sa_url
from ..sql import ddl
from ..sql import schema


log = logging.getLogger(__name__)

FOLLOWER_IDENT = None


class register:
    def __init__(self, decorator=None):
        self.fns = {}
        self.decorator = decorator

    @classmethod
    def init(cls, fn):
        return register().for_db("*")(fn)

    @classmethod
    def init_decorator(cls, decorator):
        return register(decorator).for_db("*")

    def for_db(self, *dbnames):
        def decorate(fn):
            if self.decorator:
                fn = self.decorator(fn)
            for dbname in dbnames:
                self.fns[dbname] = fn
            return self

        return decorate

    def __call__(self, cfg, *arg, **kw):
        if isinstance(cfg, str):
            url = sa_url.make_url(cfg)
        elif isinstance(cfg, sa_url.URL):
            url = cfg
        else:
            url = cfg.db.url
        backend = url.get_backend_name()
        if backend in self.fns:
            return self.fns[backend](cfg, *arg, **kw)
        else:
            return self.fns["*"](cfg, *arg, **kw)


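# A hedged sketch of how the registry is extended: a dialect's provision
# module registers a backend-specific override on an existing hook, and the
# registry dispatches on the URL's backend name, e.g.
#
#     from sqlalchemy.testing.provision import create_db
#
#     @create_db.for_db("postgresql")
#     def _pg_create_db(cfg, eng, ident):
#         ...  # issue CREATE DATABASE on a raw connection
#
# Unregistered backends fall through to the "*" entry installed via
# @register.init.

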
def create_follower_db(follower_ident):
    for cfg in _configs_for_db_operation():
        log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
        create_db(cfg, cfg.db, follower_ident)


def setup_config(db_url, options, file_config, follower_ident):
    # load the dialect, which should also have it set up its provision
    # hooks

    dialect = sa_url.make_url(db_url).get_dialect()

    dialect.load_provisioning()

    if follower_ident:
        db_url = follower_url_from_main(db_url, follower_ident)
    db_opts = {}
    update_db_opts(db_url, db_opts, options)
    db_opts["scope"] = "global"
    eng = engines.testing_engine(db_url, db_opts)
    post_configure_engine(db_url, eng, follower_ident)
    eng.connect().close()

    cfg = config.Config.register(eng, db_opts, options, file_config)

    # a symbolic name that tests can use if they need to disambiguate
    # names across databases
    if follower_ident:
        config.ident = follower_ident

    if follower_ident:
        configure_follower(cfg, follower_ident)
    return cfg


def drop_follower_db(follower_ident):
    for cfg in _configs_for_db_operation():
        log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
        drop_db(cfg, cfg.db, follower_ident)


def generate_db_urls(db_urls, extra_drivers):
    """Generate a set of URLs to test given configured URLs plus additional
    driver names.

    Given:

    .. sourcecode:: text

        --dburi postgresql://db1 \
        --dburi postgresql://db2 \
        --dburi postgresql://db3 \
        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true

    Noting that the default postgresql driver is psycopg2, the output
    would be:

    .. sourcecode:: text

        postgresql+psycopg2://db1
        postgresql+asyncpg://db1
        postgresql+psycopg2://db2
        postgresql+psycopg2://db3

    That is, for the driver in a --dburi, we want to keep that and use that
    driver for each URL it's part of.  For a driver that is only in
    --dbdrivers, we want to use it just once for one of the URLs.  For a
    driver that is both coming from --dburi as well as --dbdrivers, we want
    to keep it in that dburi.

    Driver-specific query options can be specified by adding them to the
    driver name.  For example, to enable the async fallback option for
    asyncpg:

    .. sourcecode:: text

        --dburi postgresql://db1 \
        --dbdriver=asyncpg?async_fallback=true

    """
    urls = set()

    backend_to_driver_we_already_have = collections.defaultdict(set)

    urls_plus_dialects = [
        (url_obj, url_obj.get_dialect())
        for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
    ]

    for url_obj, dialect in urls_plus_dialects:
        # use get_driver_name instead of dialect.driver to account for
        # "_async" virtual drivers like oracledb and psycopg
        driver_name = url_obj.get_driver_name()
        backend_to_driver_we_already_have[dialect.name].add(driver_name)

    backend_to_driver_we_need = {}

    for url_obj, dialect in urls_plus_dialects:
        backend = dialect.name
        dialect.load_provisioning()

        if backend not in backend_to_driver_we_need:
            backend_to_driver_we_need[backend] = extra_per_backend = set(
                extra_drivers
            ).difference(backend_to_driver_we_already_have[backend])
        else:
            extra_per_backend = backend_to_driver_we_need[backend]

        for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
            if driver_url in urls:
                continue
            urls.add(driver_url)
            yield driver_url


def _generate_driver_urls(url, extra_drivers):
    main_driver = url.get_driver_name()
    extra_drivers.discard(main_driver)

    url = generate_driver_url(url, main_driver, "")
    yield url

    for drv in list(extra_drivers):
        if "?" in drv:
            driver_only, query_str = drv.split("?", 1)

        else:
            driver_only = drv
            query_str = None

        new_url = generate_driver_url(url, driver_only, query_str)
        if new_url:
            extra_drivers.remove(drv)

            yield new_url


@register.init
def generate_driver_url(url, driver, query_str):
    backend = url.get_backend_name()

    new_url = url.set(
        drivername="%s+%s" % (backend, driver),
    )
    if query_str:
        new_url = new_url.update_query_string(query_str)

    try:
        new_url.get_dialect()
    except exc.NoSuchModuleError:
        return None
    else:
        return new_url


def _configs_for_db_operation():
    hosts = set()

    for cfg in config.Config.all_configs():
        cfg.db.dispose()

    for cfg in config.Config.all_configs():
        url = cfg.db.url
        backend = url.get_backend_name()
        host_conf = (backend, url.username, url.host, url.database)

        if host_conf not in hosts:
            yield cfg
            hosts.add(host_conf)

    for cfg in config.Config.all_configs():
        cfg.db.dispose()


@register.init
def drop_all_schema_objects_pre_tables(cfg, eng):
    pass


@register.init
def drop_all_schema_objects_post_tables(cfg, eng):
    pass


def drop_all_schema_objects(cfg, eng):
    drop_all_schema_objects_pre_tables(cfg, eng)

    drop_views(cfg, eng)

    if config.requirements.materialized_views.enabled:
        drop_materialized_views(cfg, eng)

    inspector = inspect(eng)

    consider_schemas = (None,)
    if config.requirements.schemas.enabled_for_config(cfg):
        consider_schemas += (cfg.test_schema, cfg.test_schema_2)
    util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas)

    drop_all_schema_objects_post_tables(cfg, eng)

    if config.requirements.sequences.enabled_for_config(cfg):
        with eng.begin() as conn:
            for seq in inspector.get_sequence_names():
                conn.execute(ddl.DropSequence(schema.Sequence(seq)))
            if config.requirements.schemas.enabled_for_config(cfg):
                for schema_name in [cfg.test_schema, cfg.test_schema_2]:
                    for seq in inspector.get_sequence_names(
                        schema=schema_name
                    ):
                        conn.execute(
                            ddl.DropSequence(
                                schema.Sequence(seq, schema=schema_name)
                            )
                        )


def drop_views(cfg, eng):
    inspector = inspect(eng)

    try:
        view_names = inspector.get_view_names()
    except NotImplementedError:
        pass
    else:
        with eng.begin() as conn:
            for vname in view_names:
                conn.execute(
                    ddl._DropView(schema.Table(vname, schema.MetaData()))
                )

    if config.requirements.schemas.enabled_for_config(cfg):
        try:
            view_names = inspector.get_view_names(schema=cfg.test_schema)
        except NotImplementedError:
            pass
        else:
            with eng.begin() as conn:
                for vname in view_names:
                    conn.execute(
                        ddl._DropView(
                            schema.Table(
                                vname,
                                schema.MetaData(),
                                schema=cfg.test_schema,
                            )
                        )
                    )


def drop_materialized_views(cfg, eng):
    inspector = inspect(eng)

    mview_names = inspector.get_materialized_view_names()

    with eng.begin() as conn:
        for vname in mview_names:
            conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}")

    if config.requirements.schemas.enabled_for_config(cfg):
        mview_names = inspector.get_materialized_view_names(
            schema=cfg.test_schema
        )
        with eng.begin() as conn:
            for vname in mview_names:
                conn.exec_driver_sql(
                    f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}"
                )


@register.init
def create_db(cfg, eng, ident):
    """Dynamically create a database for testing.

    Used when a test run will employ multiple processes, e.g., when run
    via `tox` or `pytest -n4`.
    """
    raise NotImplementedError(
        "no DB creation routine for cfg: %s" % (eng.url,)
    )


@register.init
def drop_db(cfg, eng, ident):
    """Drop a database that we dynamically created for testing."""
    raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))


def _adapt_update_db_opts(fn):
    insp = util.inspect_getfullargspec(fn)
    if len(insp.args) == 3:
        return fn
    else:
        return lambda db_url, db_opts, _options: fn(db_url, db_opts)


@register.init_decorator(_adapt_update_db_opts)
def update_db_opts(db_url, db_opts, options):
    """Set database options (db_opts) for a test database that we created."""


@register.init
def post_configure_engine(url, engine, follower_ident):
    """Perform extra steps after configuring an engine for testing.

    (For the internal dialects, currently only used by sqlite, oracle, mssql)
    """


@register.init
def follower_url_from_main(url, ident):
    """Create a connection URL for a dynamically-created test database.

    :param url: the connection URL specified when the test run was invoked
    :param ident: the pytest-xdist "worker identifier" to be used as the
     database name
    """
    url = sa_url.make_url(url)
    return url.set(database=ident)


@register.init
def configure_follower(cfg, ident):
    """Create dialect-specific config settings for a follower database."""
    pass


@register.init
def run_reap_dbs(url, ident):
    """Remove databases that were created during the test process, after the
    process has ended.

    This is an optional step that is invoked for certain backends that do not
    reliably release locks on the database as long as a process is still in
    use.  For the internal dialects, this is currently only necessary for
    mssql and oracle.
    """


def reap_dbs(idents_file):
    log.info("Reaping databases...")

    urls = collections.defaultdict(set)
    idents = collections.defaultdict(set)
    dialects = {}

    with open(idents_file) as file_:
        for line in file_:
            line = line.strip()
            db_name, db_url = line.split(" ")
            url_obj = sa_url.make_url(db_url)
            if db_name not in dialects:
                dialects[db_name] = url_obj.get_dialect()
                dialects[db_name].load_provisioning()
            url_key = (url_obj.get_backend_name(), url_obj.host)
            urls[url_key].add(db_url)
            idents[url_key].add(db_name)

    for url_key in urls:
        url = list(urls[url_key])[0]
        ident = idents[url_key]
        run_reap_dbs(url, ident)


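# A hedged illustration of the idents file consumed above: one line per
# follower database created during the run, as "<db_name> <db_url>", e.g.
#
#     test_ab12cd34ef56 postgresql://scott:tiger@localhost/test_ab12cd34ef56
#
# (hypothetical values). reap_dbs() groups these per backend/host and hands
# them to the dialect-specific run_reap_dbs() hook.

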
@register.init
def temp_table_keyword_args(cfg, eng):
    """Specify keyword arguments for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    kwargs that are passed to the Table method when creating a temporary
    table for testing, e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py
    """
    raise NotImplementedError(
        "no temp table keyword args routine for cfg: %s" % (eng.url,)
    )


@register.init
def prepare_for_drop_tables(config, connection):
    pass


@register.init
def stop_test_class_outside_fixtures(config, db, testcls):
    pass


@register.init
def get_temp_table_name(cfg, eng, base_name):
    """Specify table name for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    name to use when creating a temporary table for testing,
    e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py

    Default to just the base name since that's what most dialects will
    use.  The mssql dialect's implementation will need a "#" prepended.
    """
    return base_name


@register.init
def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
    raise NotImplementedError(
        "backend does not implement a schema name set function: %s"
        % (cfg.db.url,)
    )


@register.init
def upsert(
    cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
):
    """Return the backend's insert..on conflict / on dupe etc. construct.

    While we should add a backend-neutral upsert construct as well, such as
    insert().upsert(), it's important that we continue to test the
    backend-specific insert() constructs, since if we do implement
    insert().upsert(), that would be using a different codepath from the
    things we need to test, like insertmanyvalues, etc.

    """
    raise NotImplementedError(
        f"backend does not include an upsert implementation: {cfg.db.url}"
    )


@register.init
def normalize_sequence(cfg, sequence):
    """Normalize sequence parameters for dialects that don't start with 1
    by default.

    The default implementation does nothing.
    """
    return sequence
venv/lib/python3.11/site-packages/sqlalchemy/testing/requirements.py (1847 lines, new file)
File diff suppressed because it is too large
venv/lib/python3.11/site-packages/sqlalchemy/testing/schema.py (224 lines, new file)
@ -0,0 +1,224 @@
# testing/schema.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import sys

from . import config
from . import exclusions
from .. import event
from .. import schema
from .. import types as sqltypes
from ..orm import mapped_column as _orm_mapped_column
from ..util import OrderedDict

__all__ = ["Table", "Column"]

table_options = {}


def Table(*args, **kw) -> schema.Table:
    """A schema.Table wrapper/hook for dialect-specific tweaks."""

    test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}

    kw.update(table_options)

    if exclusions.against(config._current, "mysql"):
        if (
            "mysql_engine" not in kw
            and "mysql_type" not in kw
            and "autoload_with" not in kw
        ):
            if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
                kw["mysql_engine"] = "InnoDB"
            else:
                # there are in fact test fixtures that rely upon MyISAM,
                # due to MySQL / MariaDB having poor FK behavior under innodb,
                # e.g. a self-referential table can't be deleted from all at
                # once without attending to per-row dependencies.  We'd need
                # to add special steps to some fixtures if we want to not
                # explicitly state MyISAM here
                kw["mysql_engine"] = "MyISAM"
    elif exclusions.against(config._current, "mariadb"):
        if (
            "mariadb_engine" not in kw
            and "mariadb_type" not in kw
            and "autoload_with" not in kw
        ):
            if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
                kw["mariadb_engine"] = "InnoDB"
            else:
                kw["mariadb_engine"] = "MyISAM"

    return schema.Table(*args, **kw)


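# A hedged usage sketch (assuming the usual sqlalchemy imports): test
# fixtures pass "test_*" options that this wrapper consumes rather than
# forwarding to schema.Table, e.g.
#
#     users = Table(
#         "users",
#         metadata,
#         Column("id", Integer, primary_key=True),
#         test_needs_fk=True,
#     )
#
# which, against MySQL/MariaDB, also selects the InnoDB storage engine.

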
def mapped_column(*args, **kw):
|
||||
"""An orm.mapped_column wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
return _schema_column(_orm_mapped_column, args, kw)
|
||||
|
||||
|
||||
def Column(*args, **kw):
|
||||
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
return _schema_column(schema.Column, args, kw)
|
||||
|
||||
|
||||
def _schema_column(factory, args, kw):
|
||||
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
|
||||
|
||||
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
|
||||
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
|
||||
|
||||
construct = factory(*args, **kw)
|
||||
|
||||
if factory is schema.Column:
|
||||
col = construct
|
||||
else:
|
||||
col = construct.column
|
||||
|
||||
if test_opts.get("test_needs_autoincrement", False) and kw.get(
|
||||
"primary_key", False
|
||||
):
|
||||
if col.default is None and col.server_default is None:
|
||||
col.autoincrement = True
|
||||
|
||||
# allow any test suite to pick up on this
|
||||
col.info["test_needs_autoincrement"] = True
|
||||
|
||||
# hardcoded rule for oracle; this should
|
||||
# be moved out
|
||||
if exclusions.against(config._current, "oracle"):
|
||||
|
||||
def add_seq(c, tbl):
|
||||
c._init_items(
|
||||
schema.Sequence(
|
||||
_truncate_name(
|
||||
config.db.dialect, tbl.name + "_" + c.name + "_seq"
|
||||
),
|
||||
optional=True,
|
||||
)
|
||||
)
|
||||
|
||||
event.listen(col, "after_parent_attach", add_seq, propagate=True)
|
||||
return construct
|
||||
|
||||
|
||||
class eq_type_affinity:
|
||||
"""Helper to compare types inside of datastructures based on affinity.
|
||||
|
||||
E.g.::
|
||||
|
||||
eq_(
|
||||
inspect(connection).get_columns("foo"),
|
||||
[
|
||||
{
|
||||
"name": "id",
|
||||
"type": testing.eq_type_affinity(sqltypes.INTEGER),
|
||||
"nullable": False,
|
||||
"default": None,
|
||||
"autoincrement": False,
|
||||
},
|
||||
{
|
||||
"name": "data",
|
||||
"type": testing.eq_type_affinity(sqltypes.NullType),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = sqltypes.to_instance(target)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target._type_affinity is other._type_affinity
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.target._type_affinity is not other._type_affinity
|
||||
|
||||
|
||||
class eq_compile_type:
|
||||
"""similar to eq_type_affinity but uses compile"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = target
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target == other.compile()
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.target != other.compile()
|
||||
|
||||
|
||||
class eq_clause_element:
|
||||
"""Helper to compare SQL structures based on compare()"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = target
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target.compare(other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.target.compare(other)
|
||||
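A minimal sketch of how these comparator helpers behave (illustrative, not part of the commit):

from sqlalchemy import column

x = column("x")

# eq_clause_element defers to ClauseElement.compare(), so two separately
# constructed but structurally identical expressions compare equal:
assert eq_clause_element(x > 5) == (x > 5)

# eq_type_affinity ignores length/precision details, comparing only the
# type's affinity; eq_compile_type instead compares the compiled string.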
def _truncate_name(dialect, name):
    if len(name) > dialect.max_identifier_length:
        return (
            name[0 : max(dialect.max_identifier_length - 6, 0)]
            + "_"
            + hex(hash(name) % 64)[2:]
        )
    else:
        return name
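A worked example of the truncation arithmetic, assuming a hypothetical 30-character identifier limit: the name keeps its first max-6 characters, then "_" plus at most two hex digits of hash(name) % 64.

name = "x" * 40                       # longer than the 30-character limit
truncated = name[0 : max(30 - 6, 0)] + "_" + hex(hash(name) % 64)[2:]
assert len(truncated) <= 27           # 24 chars + "_" + at most 2 hex digits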
def pep435_enum(name):
    # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
    __members__ = OrderedDict()

    def __init__(self, name, value, alias=None):
        self.name = name
        self.value = value
        self.__members__[name] = self
        value_to_member[value] = self
        setattr(self.__class__, name, self)
        if alias:
            self.__members__[alias] = self
            setattr(self.__class__, alias, self)

    value_to_member = {}

    @classmethod
    def get(cls, value):
        return value_to_member[value]

    someenum = type(
        name,
        (object,),
        {"__members__": __members__, "__init__": __init__, "get": get},
    )

    # set __module__ via the sys._getframe() trick used by Python's
    # namedtuple(), so that instances of the generated class can be
    # pickled
    try:
        module = sys._getframe(1).f_globals.get("__name__", "__main__")
    except (AttributeError, ValueError):
        module = None
    if module is not None:
        someenum.__module__ = module

    return someenum
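Illustrative usage of the minimal PEP 435 stand-in (a sketch, not part of the commit): members register themselves on construction, and get() performs lookup by value the way MyEnum(value) would on a real Enum.

MyEnum = pep435_enum("MyEnum")
one = MyEnum("one", 1)
two = MyEnum("two", 2, alias="deux")

assert MyEnum.one is one
assert MyEnum.deux is two                # the alias maps to the same member
assert MyEnum.get(2) is two              # lookup by value
assert list(MyEnum.__members__) == ["one", "two", "deux"]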
@ -0,0 +1,19 @@
# testing/suite/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .test_cte import *  # noqa
from .test_ddl import *  # noqa
from .test_deprecations import *  # noqa
from .test_dialect import *  # noqa
from .test_insert import *  # noqa
from .test_reflection import *  # noqa
from .test_results import *  # noqa
from .test_rowcount import *  # noqa
from .test_select import *  # noqa
from .test_sequence import *  # noqa
from .test_types import *  # noqa
from .test_unicode_ddl import *  # noqa
from .test_update_delete import *  # noqa
@ -0,0 +1,211 @@
# testing/suite/test_cte.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import ForeignKey
from ... import Integer
from ... import select
from ... import String
from ... import testing


class CTETest(fixtures.TablesTest):
    __backend__ = True
    __requires__ = ("ctes",)

    run_inserts = "each"
    run_deletes = "each"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", ForeignKey("some_table.id")),
        )

        Table(
            "some_other_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", Integer),
        )

    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "data": "d1", "parent_id": None},
                {"id": 2, "data": "d2", "parent_id": 1},
                {"id": 3, "data": "d3", "parent_id": 1},
                {"id": 4, "data": "d4", "parent_id": 3},
                {"id": 5, "data": "d5", "parent_id": 3},
            ],
        )

    def test_select_nonrecursive_round_trip(self, connection):
        some_table = self.tables.some_table

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        result = connection.execute(
            select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
        )
        eq_(result.fetchall(), [("d4",)])

    def test_select_recursive_round_trip(self, connection):
        some_table = self.tables.some_table

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte", recursive=True)
        )

        cte_alias = cte.alias("c1")
        st1 = some_table.alias()
        # note that SQL Server requires this to be UNION ALL,
        # can't be UNION
        cte = cte.union_all(
            select(st1).where(st1.c.id == cte_alias.c.parent_id)
        )
        result = connection.execute(
            select(cte.c.data)
            .where(cte.c.data != "d2")
            .order_by(cte.c.data.desc())
        )
        eq_(
            result.fetchall(),
            [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
        )
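    # For orientation, the recursive statement assembled above corresponds
    # roughly to the following SQL (a sketch; exact rendering varies by
    # dialect):
    #
    #   WITH RECURSIVE some_cte(id, data, parent_id) AS (
    #       SELECT id, data, parent_id FROM some_table
    #       WHERE data IN ('d2', 'd3', 'd4')
    #       UNION ALL
    #       SELECT st1.id, st1.data, st1.parent_id
    #       FROM some_table AS st1, some_cte AS c1
    #       WHERE st1.id = c1.parent_id
    #   )
    #   SELECT data FROM some_cte WHERE data != 'd2' ORDER BY data DESC
    #
    # each seed row also pulls in its chain of ancestors, which is why d3
    # and d1 appear multiple times in the expected result above.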
    def test_insert_from_select_round_trip(self, connection):
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(cte)
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
        )

    @testing.requires.ctes_with_update_delete
    @testing.requires.update_from
    def test_update_from_round_trip(self, connection):
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.update()
            .values(parent_id=5)
            .where(some_other_table.c.data == cte.c.data)
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [
                (1, "d1", None),
                (2, "d2", 5),
                (3, "d3", 5),
                (4, "d4", 5),
                (5, "d5", 3),
            ],
        )

    @testing.requires.ctes_with_update_delete
    @testing.requires.delete_from
    def test_delete_from_round_trip(self, connection):
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.delete().where(
                some_other_table.c.data == cte.c.data
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(1, "d1", None), (5, "d5", 3)],
        )

    @testing.requires.ctes_with_update_delete
    def test_delete_scalar_subq_round_trip(self, connection):
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.delete().where(
                some_other_table.c.data
                == select(cte.c.data)
                .where(cte.c.id == some_other_table.c.id)
                .scalar_subquery()
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(1, "d1", None), (5, "d5", 3)],
        )
@ -0,0 +1,389 @@
# testing/suite/test_ddl.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import random

from . import testing
from .. import config
from .. import fixtures
from .. import util
from ..assertions import eq_
from ..assertions import is_false
from ..assertions import is_true
from ..config import requirements
from ..schema import Table
from ... import CheckConstraint
from ... import Column
from ... import ForeignKeyConstraint
from ... import Index
from ... import inspect
from ... import Integer
from ... import schema
from ... import String
from ... import UniqueConstraint


class TableDDLTest(fixtures.TestBase):
    __backend__ = True

    def _simple_fixture(self, schema=None):
        return Table(
            "test_table",
            self.metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            schema=schema,
        )

    def _underscore_fixture(self):
        return Table(
            "_test_table",
            self.metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("_data", String(50)),
        )

    def _table_index_fixture(self, schema=None):
        table = self._simple_fixture(schema=schema)
        idx = Index("test_index", table.c.data)
        return table, idx

    def _simple_roundtrip(self, table):
        with config.db.begin() as conn:
            conn.execute(table.insert().values((1, "some data")))
            result = conn.execute(table.select())
            eq_(result.first(), (1, "some data"))

    @requirements.create_table
    @util.provide_metadata
    def test_create_table(self):
        table = self._simple_fixture()
        table.create(config.db, checkfirst=False)
        self._simple_roundtrip(table)

    @requirements.create_table
    @requirements.schemas
    @util.provide_metadata
    def test_create_table_schema(self):
        table = self._simple_fixture(schema=config.test_schema)
        table.create(config.db, checkfirst=False)
        self._simple_roundtrip(table)

    @requirements.drop_table
    @util.provide_metadata
    def test_drop_table(self):
        table = self._simple_fixture()
        table.create(config.db, checkfirst=False)
        table.drop(config.db, checkfirst=False)

    @requirements.create_table
    @util.provide_metadata
    def test_underscore_names(self):
        table = self._underscore_fixture()
        table.create(config.db, checkfirst=False)
        self._simple_roundtrip(table)

    @requirements.comment_reflection
    @util.provide_metadata
    def test_add_table_comment(self, connection):
        table = self._simple_fixture()
        table.create(connection, checkfirst=False)
        table.comment = "a comment"
        connection.execute(schema.SetTableComment(table))
        eq_(
            inspect(connection).get_table_comment("test_table"),
            {"text": "a comment"},
        )

    @requirements.comment_reflection
    @util.provide_metadata
    def test_drop_table_comment(self, connection):
        table = self._simple_fixture()
        table.create(connection, checkfirst=False)
        table.comment = "a comment"
        connection.execute(schema.SetTableComment(table))
        connection.execute(schema.DropTableComment(table))
        eq_(
            inspect(connection).get_table_comment("test_table"), {"text": None}
        )

    @requirements.table_ddl_if_exists
    @util.provide_metadata
    def test_create_table_if_not_exists(self, connection):
        table = self._simple_fixture()

        connection.execute(schema.CreateTable(table, if_not_exists=True))

        is_true(inspect(connection).has_table("test_table"))
        connection.execute(schema.CreateTable(table, if_not_exists=True))

    @requirements.index_ddl_if_exists
    @util.provide_metadata
    def test_create_index_if_not_exists(self, connection):
        table, idx = self._table_index_fixture()

        connection.execute(schema.CreateTable(table, if_not_exists=True))
        is_true(inspect(connection).has_table("test_table"))
        is_false(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )

        connection.execute(schema.CreateIndex(idx, if_not_exists=True))

        is_true(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )

        connection.execute(schema.CreateIndex(idx, if_not_exists=True))

    @requirements.table_ddl_if_exists
    @util.provide_metadata
    def test_drop_table_if_exists(self, connection):
        table = self._simple_fixture()

        table.create(connection)

        is_true(inspect(connection).has_table("test_table"))

        connection.execute(schema.DropTable(table, if_exists=True))

        is_false(inspect(connection).has_table("test_table"))

        connection.execute(schema.DropTable(table, if_exists=True))

    @requirements.index_ddl_if_exists
    @util.provide_metadata
    def test_drop_index_if_exists(self, connection):
        table, idx = self._table_index_fixture()

        table.create(connection)

        is_true(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )

        connection.execute(schema.DropIndex(idx, if_exists=True))

        is_false(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )

        connection.execute(schema.DropIndex(idx, if_exists=True))

class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
    pass


class LongNameBlowoutTest(fixtures.TestBase):
    """Test the creation of a variety of DDL structures and ensure
    that generated identifier length limits are honored on each backend.

    """

    __backend__ = True

    def fk(self, metadata, connection):
        convention = {
            "fk": "foreign_key_%(table_name)s_"
            "%(column_0_N_name)s_"
            "%(referred_table_name)s_"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(20))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention

        Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            test_needs_fk=True,
        )

        cons = ForeignKeyConstraint(
            ["aid"], ["a_things_with_stuff.id_long_column_name"]
        )
        Table(
            "b_related_things_of_value",
            metadata,
            Column(
                "aid",
            ),
            cons,
            test_needs_fk=True,
        )
        actual_name = cons.name

        metadata.create_all(connection)

        if testing.requires.foreign_key_constraint_name_reflection.enabled:
            insp = inspect(connection)
            fks = insp.get_foreign_keys("b_related_things_of_value")
            reflected_name = fks[0]["name"]

            return actual_name, reflected_name
        else:
            return actual_name, None

    def pk(self, metadata, connection):
        convention = {
            "pk": "primary_key_%(table_name)s_"
            "%(column_0_N_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention

        a = Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("id_another_long_name", Integer, primary_key=True),
        )
        cons = a.primary_key
        actual_name = cons.name

        metadata.create_all(connection)
        insp = inspect(connection)
        pk = insp.get_pk_constraint("a_things_with_stuff")
        reflected_name = pk["name"]
        return actual_name, reflected_name

    def ix(self, metadata, connection):
        convention = {
            "ix": "index_%(table_name)s_"
            "%(column_0_N_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention

        a = Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("id_another_long_name", Integer),
        )
        cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
        actual_name = cons.name

        metadata.create_all(connection)
        insp = inspect(connection)
        ix = insp.get_indexes("a_things_with_stuff")
        reflected_name = ix[0]["name"]
        return actual_name, reflected_name

    def uq(self, metadata, connection):
        convention = {
            "uq": "unique_constraint_%(table_name)s_"
            "%(column_0_N_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention

        cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
        Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("id_another_long_name", Integer),
            cons,
        )
        actual_name = cons.name

        metadata.create_all(connection)
        insp = inspect(connection)
        uq = insp.get_unique_constraints("a_things_with_stuff")
        reflected_name = uq[0]["name"]
        return actual_name, reflected_name

    def ck(self, metadata, connection):
        convention = {
            "ck": "check_constraint_%(table_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention

        cons = CheckConstraint("some_long_column_name > 5")
        Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("some_long_column_name", Integer),
            cons,
        )
        actual_name = cons.name

        metadata.create_all(connection)
        insp = inspect(connection)
        ck = insp.get_check_constraints("a_things_with_stuff")
        reflected_name = ck[0]["name"]
        return actual_name, reflected_name

    @testing.combinations(
        ("fk",),
        ("pk",),
        ("ix",),
        ("ck", testing.requires.check_constraint_reflection.as_skips()),
        ("uq", testing.requires.unique_constraint_reflection.as_skips()),
        argnames="type_",
    )
    def test_long_convention_name(self, type_, metadata, connection):
        actual_name, reflected_name = getattr(self, type_)(
            metadata, connection
        )

        assert len(actual_name) > 255

        if reflected_name is not None:
            overlap = actual_name[0 : len(reflected_name)]
            if len(overlap) < len(actual_name):
                eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
            else:
                eq_(overlap, reflected_name)


__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")
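A sketch of the mechanism exercised above (illustrative names, not part of the commit): a naming convention can deterministically generate constraint names far beyond any backend's identifier cap; SQLAlchemy truncates them with a hash suffix at DDL compile time, and test_long_convention_name compares that against what the backend reflects back.

from sqlalchemy import Column, Integer, MetaData, Table, UniqueConstraint

metadata = MetaData(
    naming_convention={
        "uq": "uq_%(table_name)s_%(column_0_N_name)s_" + "x" * 300
    }
)
t = Table(
    "a_table",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("data", Integer),
    UniqueConstraint("data"),
)
(uq,) = [c for c in t.constraints if isinstance(c, UniqueConstraint)]
assert len(str(uq.name)) > 255  # far past e.g. PostgreSQL's 63-char cap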
@ -0,0 +1,153 @@
# testing/suite/test_deprecations.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import Integer
from ... import select
from ... import testing
from ... import union


class DeprecatedCompoundSelectTest(fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("x", Integer),
            Column("y", Integer),
        )

    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "x": 1, "y": 2},
                {"id": 2, "x": 2, "y": 3},
                {"id": 3, "x": 3, "y": 4},
                {"id": 4, "x": 4, "y": 5},
            ],
        )

    def _assert_result(self, conn, select, result, params=()):
        eq_(conn.execute(select, params).fetchall(), result)

    def test_plain_union(self, connection):
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2)
        s2 = select(table).where(table.c.id == 3)

        u1 = union(s1, s2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )

    # note we've had to remove one use case entirely, which is this
    # one.  the Select derives its FROM list from the WHERE clause and
    # the columns clause, but not from the ORDER BY, which means the old
    # ".c" system allowed "order_by(s.c.foo)" to place an unnamed column
    # in the ORDER BY without adding the SELECT into the FROM and
    # breaking the query.  Users who relied on this use case will have
    # to adjust.
    def _dont_test_select_from_plain_union(self, connection):
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2)
        s2 = select(table).where(table.c.id == 3)

        u1 = union(s1, s2).alias().select()
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )
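    # The non-deprecated spelling of the pattern above replaces
    # SelectBase.c with .selected_columns (for the statement's own column
    # list) or with .subquery().c when the statement is used as a FROM
    # clause.  A sketch, not part of this suite:
    #
    #   u1 = union(s1, s2)
    #   stmt = u1.order_by(u1.selected_columns.id)
    #
    #   subq = union(s1, s2).subquery()
    #   stmt = select(subq).order_by(subq.c.id)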
    @testing.requires.order_by_col_from_union
    @testing.requires.parens_in_union_contained_select_w_limit_offset
    def test_limit_offset_selectable_in_unions(self, connection):
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
        s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)

        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )

    @testing.requires.parens_in_union_contained_select_wo_limit_offset
    def test_order_by_selectable_in_unions(self, connection):
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
        s2 = select(table).where(table.c.id == 3).order_by(table.c.id)

        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )

    def test_distinct_selectable_in_unions(self, connection):
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2).distinct()
        s2 = select(table).where(table.c.id == 3).distinct()

        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )

    def test_limit_offset_aliased_selectable_in_unions(self, connection):
        table = self.tables.some_table
        s1 = (
            select(table)
            .where(table.c.id == 2)
            .limit(1)
            .order_by(table.c.id)
            .alias()
            .select()
        )
        s2 = (
            select(table)
            .where(table.c.id == 3)
            .limit(1)
            .order_by(table.c.id)
            .alias()
            .select()
        )

        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )
@ -0,0 +1,740 @@
# testing/suite/test_dialect.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


import importlib

from . import testing
from .. import assert_raises
from .. import config
from .. import engines
from .. import eq_
from .. import fixtures
from .. import is_not_none
from .. import is_true
from .. import ne_
from .. import provide_metadata
from ..assertions import expect_raises
from ..assertions import expect_raises_message
from ..config import requirements
from ..provision import set_default_schema_on_connection
from ..schema import Column
from ..schema import Table
from ... import bindparam
from ... import dialects
from ... import event
from ... import exc
from ... import Integer
from ... import literal_column
from ... import select
from ... import String
from ...sql.compiler import Compiled
from ...util import inspect_getfullargspec


class PingTest(fixtures.TestBase):
    __backend__ = True

    def test_do_ping(self):
        with testing.db.connect() as conn:
            is_true(
                testing.db.dialect.do_ping(conn.connection.dbapi_connection)
            )


class ArgSignatureTest(fixtures.TestBase):
    """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have
    ``**kw``, for #8988.

    This test uses runtime code inspection.  Does not need to be a
    ``__backend__`` test as it only needs to run once provided all target
    dialects have been imported.

    For third party dialects, the suite would be run with that third
    party as a "--dburi", which means its compiler classes will have been
    imported by the time this test runs.

    """

    def _all_subclasses():  # type: ignore  # noqa
        for d in dialects.__all__:
            if not d.startswith("_"):
                importlib.import_module("sqlalchemy.dialects.%s" % d)

        stack = [Compiled]

        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
            yield cls

    @testing.fixture(params=list(_all_subclasses()))
    def all_subclasses(self, request):
        yield request.param

    def test_all_visit_methods_accept_kw(self, all_subclasses):
        cls = all_subclasses

        for k in cls.__dict__:
            if k.startswith("visit_"):
                meth = getattr(cls, k)

                insp = inspect_getfullargspec(meth)
                is_not_none(
                    insp.varkw,
                    f"Compiler visit method {cls.__name__}.{k}() does "
                    "not accommodate for **kw in its argument signature",
                )


class ExceptionTest(fixtures.TablesTest):
    """Test basic exception wrapping.

    DBAPIs vary a lot in exception behavior, so to actually anticipate
    specific exceptions from real round trips we need to be conservative.

    """

    run_deletes = "each"

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "manual_pk",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )

    @requirements.duplicate_key_raises_integrity_error
    def test_integrity_error(self):
        with config.db.connect() as conn:
            trans = conn.begin()
            conn.execute(
                self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
            )

            assert_raises(
                exc.IntegrityError,
                conn.execute,
                self.tables.manual_pk.insert(),
                {"id": 1, "data": "d1"},
            )

            trans.rollback()

    def test_exception_with_non_ascii(self):
        with config.db.connect() as conn:
            try:
                # try to create an error message that likely has non-ascii
                # characters in the DBAPI's message string.  unfortunately
                # there's no way to make this happen with some drivers like
                # mysqlclient, pymysql.  this at least does produce a non-
                # ascii error message for cx_oracle, psycopg2
                conn.execute(select(literal_column("méil")))
                assert False
            except exc.DBAPIError as err:
                err_str = str(err)

                assert str(err.orig) in str(err)

                assert isinstance(err_str, str)


class IsolationLevelTest(fixtures.TestBase):
    __backend__ = True

    __requires__ = ("isolation_level",)

    def _get_non_default_isolation_level(self):
        levels = requirements.get_isolation_levels(config)

        default = levels["default"]
        supported = levels["supported"]

        s = set(supported).difference(["AUTOCOMMIT", default])
        if s:
            return s.pop()
        else:
            config.skip_test("no non-default isolation level available")
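    # For reference, get_isolation_levels() returns a dict per backend of
    # the form (values illustrative, e.g. for PostgreSQL):
    #
    #   {
    #       "default": "READ COMMITTED",
    #       "supported": [
    #           "SERIALIZABLE",
    #           "READ UNCOMMITTED",
    #           "READ COMMITTED",
    #           "REPEATABLE READ",
    #           "AUTOCOMMIT",
    #       ],
    #   }
    #
    # so the helper above picks any supported level that is neither the
    # default nor AUTOCOMMIT.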
    def test_default_isolation_level(self):
        eq_(
            config.db.dialect.default_isolation_level,
            requirements.get_isolation_levels(config)["default"],
        )

    def test_non_default_isolation_level(self):
        non_default = self._get_non_default_isolation_level()

        with config.db.connect() as conn:
            existing = conn.get_isolation_level()

            ne_(existing, non_default)

            conn.execution_options(isolation_level=non_default)

            eq_(conn.get_isolation_level(), non_default)

            conn.dialect.reset_isolation_level(
                conn.connection.dbapi_connection
            )

            eq_(conn.get_isolation_level(), existing)

    def test_all_levels(self):
        levels = requirements.get_isolation_levels(config)

        all_levels = levels["supported"]

        for level in set(all_levels).difference(["AUTOCOMMIT"]):
            with config.db.connect() as conn:
                conn.execution_options(isolation_level=level)

                eq_(conn.get_isolation_level(), level)

                trans = conn.begin()
                trans.rollback()

                eq_(conn.get_isolation_level(), level)

            with config.db.connect() as conn:
                eq_(
                    conn.get_isolation_level(),
                    levels["default"],
                )

    @testing.requires.get_isolation_level_values
    def test_invalid_level_execution_option(self, connection_no_trans):
        """test for the new get_isolation_level_values() method"""

        connection = connection_no_trans
        with expect_raises_message(
            exc.ArgumentError,
            "Invalid value '%s' for isolation_level. "
            "Valid isolation levels for '%s' are %s"
            % (
                "FOO",
                connection.dialect.name,
                ", ".join(
                    requirements.get_isolation_levels(config)["supported"]
                ),
            ),
        ):
            connection.execution_options(isolation_level="FOO")

    @testing.requires.get_isolation_level_values
    @testing.requires.dialect_level_isolation_level_param
    def test_invalid_level_engine_param(self, testing_engine):
        """test for the new get_isolation_level_values() method
        and support for the dialect-level 'isolation_level' parameter.

        """

        eng = testing_engine(options=dict(isolation_level="FOO"))
        with expect_raises_message(
            exc.ArgumentError,
            "Invalid value '%s' for isolation_level. "
            "Valid isolation levels for '%s' are %s"
            % (
                "FOO",
                eng.dialect.name,
                ", ".join(
                    requirements.get_isolation_levels(config)["supported"]
                ),
            ),
        ):
            eng.connect()

    @testing.requires.independent_readonly_connections
    def test_dialect_user_setting_is_restored(self, testing_engine):
        levels = requirements.get_isolation_levels(config)
        default = levels["default"]
        supported = (
            sorted(
                set(levels["supported"]).difference([default, "AUTOCOMMIT"])
            )
        )[0]

        e = testing_engine(options={"isolation_level": supported})

        with e.connect() as conn:
            eq_(conn.get_isolation_level(), supported)

        with e.connect() as conn:
            conn.execution_options(isolation_level=default)
            eq_(conn.get_isolation_level(), default)

        with e.connect() as conn:
            eq_(conn.get_isolation_level(), supported)


class AutocommitIsolationTest(fixtures.TablesTest):
    run_deletes = "each"

    __requires__ = ("autocommit",)

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            test_needs_acid=True,
        )

    def _test_conn_autocommits(self, conn, autocommit):
        trans = conn.begin()
        conn.execute(
            self.tables.some_table.insert(), {"id": 1, "data": "some data"}
        )
        trans.rollback()

        eq_(
            conn.scalar(select(self.tables.some_table.c.id)),
            1 if autocommit else None,
        )
        conn.rollback()

        with conn.begin():
            conn.execute(self.tables.some_table.delete())

    def test_autocommit_on(self, connection_no_trans):
        conn = connection_no_trans
        c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(c2, True)

        c2.dialect.reset_isolation_level(c2.connection.dbapi_connection)

        self._test_conn_autocommits(conn, False)

    def test_autocommit_off(self, connection_no_trans):
        conn = connection_no_trans
        self._test_conn_autocommits(conn, False)

    def test_turn_autocommit_off_via_default_iso_level(
        self, connection_no_trans
    ):
        conn = connection_no_trans
        conn = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(conn, True)

        conn.execution_options(
            isolation_level=requirements.get_isolation_levels(config)[
                "default"
            ]
        )
        self._test_conn_autocommits(conn, False)

    @testing.requires.independent_readonly_connections
    @testing.variation("use_dialect_setting", [True, False])
    def test_dialect_autocommit_is_restored(
        self, testing_engine, use_dialect_setting
    ):
        """test #10147"""

        if use_dialect_setting:
            e = testing_engine(options={"isolation_level": "AUTOCOMMIT"})
        else:
            e = testing_engine().execution_options(
                isolation_level="AUTOCOMMIT"
            )

        levels = requirements.get_isolation_levels(config)

        default = levels["default"]

        with e.connect() as conn:
            self._test_conn_autocommits(conn, True)

        with e.connect() as conn:
            conn.execution_options(isolation_level=default)
            self._test_conn_autocommits(conn, False)

        with e.connect() as conn:
            self._test_conn_autocommits(conn, True)


class EscapingTest(fixtures.TestBase):
    @provide_metadata
    def test_percent_sign_round_trip(self):
        """test that the DBAPI accommodates for escaped / nonescaped
        percent signs in a way that matches the compiler

        """
        m = self.metadata
        t = Table("t", m, Column("data", String(50)))
        t.create(config.db)
        with config.db.begin() as conn:
            conn.execute(t.insert(), dict(data="some % value"))
            conn.execute(t.insert(), dict(data="some %% other value"))

            eq_(
                conn.scalar(
                    select(t.c.data).where(
                        t.c.data == literal_column("'some % value'")
                    )
                ),
                "some % value",
            )

            eq_(
                conn.scalar(
                    select(t.c.data).where(
                        t.c.data == literal_column("'some %% other value'")
                    )
                ),
                "some %% other value",
            )
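Context for the test above, as a sketch (reusing `t` and `conn` from the test body; driver behavior is illustrative): on DBAPIs with `format`/`pyformat` paramstyles a literal `%` in driver-level SQL must be doubled, while bound parameter values need no escaping because the compiler and dialect handle them.

# bound values are never %-escaped by the user:
conn.execute(t.insert(), {"data": "100% organic"})

# raw driver-level SQL on a pyformat DBAPI (e.g. psycopg2) must double
# the percent sign itself:
conn.exec_driver_sql("SELECT 'some %% value'")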
class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
    __backend__ = True

    __requires__ = ("default_schema_name_switch",)

    def test_control_case(self):
        default_schema_name = config.db.dialect.default_schema_name

        eng = engines.testing_engine()
        with eng.connect():
            pass

        eq_(eng.dialect.default_schema_name, default_schema_name)

    def test_wont_work_wo_insert(self):
        default_schema_name = config.db.dialect.default_schema_name

        eng = engines.testing_engine()

        @event.listens_for(eng, "connect")
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, default_schema_name)

    def test_schema_change_on_connect(self):
        eng = engines.testing_engine()

        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, config.test_schema)

    def test_schema_change_works_w_transactions(self):
        eng = engines.testing_engine()

        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, *arg):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            trans = conn.begin()
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)
            trans.rollback()

            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, config.test_schema)


class FutureWeCanSetDefaultSchemaWEventsTest(
    fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
):
    pass


class DifficultParametersTest(fixtures.TestBase):
    __backend__ = True

    tough_parameters = testing.combinations(
        ("boring",),
        ("per cent",),
        ("per % cent",),
        ("%percent",),
        ("par(ens)",),
        ("percent%(ens)yah",),
        ("col:ons",),
        ("_starts_with_underscore",),
        ("dot.s",),
        ("more :: %colons%",),
        ("_name",),
        ("___name",),
        ("[BracketsAndCase]",),
        ("42numbers",),
        ("percent%signs",),
        ("has spaces",),
        ("/slashes/",),
        ("more/slashes",),
        ("q?marks",),
        ("1param",),
        ("1col:on",),
        argnames="paramname",
    )

    @tough_parameters
    @config.requirements.unusual_column_name_characters
    def test_round_trip_same_named_column(
        self, paramname, connection, metadata
    ):
        name = paramname

        t = Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True),
            Column(name, String(50), nullable=False),
        )

        # table is created
        t.create(connection)

        # automatic param generated by insert
        connection.execute(t.insert().values({"id": 1, name: "some name"}))

        # automatic param generated by criteria, plus selecting the column
        stmt = select(t.c[name]).where(t.c[name] == "some name")

        eq_(connection.scalar(stmt), "some name")

        # use the name in a param explicitly
        stmt = select(t.c[name]).where(t.c[name] == bindparam(name))

        row = connection.execute(stmt, {name: "some name"}).first()

        # name works as the key from cursor.description
        eq_(row._mapping[name], "some name")

        # use expanding IN
        stmt = select(t.c[name]).where(
            t.c[name].in_(["some name", "some other_name"])
        )

        row = connection.execute(stmt).first()

    @testing.fixture
    def multirow_fixture(self, metadata, connection):
        mytable = Table(
            "mytable",
            metadata,
            Column("myid", Integer),
            Column("name", String(50)),
            Column("desc", String(50)),
        )

        mytable.create(connection)

        connection.execute(
            mytable.insert(),
            [
                {"myid": 1, "name": "a", "desc": "a_desc"},
                {"myid": 2, "name": "b", "desc": "b_desc"},
                {"myid": 3, "name": "c", "desc": "c_desc"},
                {"myid": 4, "name": "d", "desc": "d_desc"},
            ],
        )
        yield mytable

    @tough_parameters
    def test_standalone_bindparam_escape(
        self, paramname, connection, multirow_fixture
    ):
        tbl1 = multirow_fixture
        stmt = select(tbl1.c.myid).where(
            tbl1.c.name == bindparam(paramname, value="x")
        )
        res = connection.scalar(stmt, {paramname: "c"})
        eq_(res, 3)

    @tough_parameters
    def test_standalone_bindparam_escape_expanding(
        self, paramname, connection, multirow_fixture
    ):
        tbl1 = multirow_fixture
        stmt = (
            select(tbl1.c.myid)
            .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"])))
            .order_by(tbl1.c.myid)
        )

        res = connection.scalars(stmt, {paramname: ["d", "a"]}).all()
        eq_(res, [1, 4])
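A brief sketch of the "expanding" parameter exercised above (illustrative, not part of the suite): passing a bindparam() to in_() produces a single placeholder that expands into one bound value per list element at execution time, so even awkward parameter names never reach the driver verbatim.

from sqlalchemy import bindparam, column, select, table

t = table("mytable", column("myid"), column("name"))
stmt = (
    select(t.c.myid)
    .where(t.c.name.in_(bindparam("per % cent", value=["a", "b"])))
)
# pre-execution the IN renders as a single expanding placeholder; at
# execute time it becomes e.g. "IN (?, ?)" with an escaped param name.
print(stmt)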
class ReturningGuardsTest(fixtures.TablesTest):
    """test that the various 'returning' flags are set appropriately"""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )

    @testing.fixture
    def run_stmt(self, connection):
        t = self.tables.t

        def go(stmt, executemany, id_param_name, expect_success):
            stmt = stmt.returning(t.c.id)

            if executemany:
                if not expect_success:
                    # for RETURNING executemany(), we raise our own
                    # error as this is independent of general RETURNING
                    # support
                    with expect_raises_message(
                        exc.StatementError,
                        rf"Dialect {connection.dialect.name}\+"
                        f"{connection.dialect.driver} with "
                        f"current server capabilities does not support "
                        f".*RETURNING when executemany is used",
                    ):
                        result = connection.execute(
                            stmt,
                            [
                                {id_param_name: 1, "data": "d1"},
                                {id_param_name: 2, "data": "d2"},
                                {id_param_name: 3, "data": "d3"},
                            ],
                        )
                else:
                    result = connection.execute(
                        stmt,
                        [
                            {id_param_name: 1, "data": "d1"},
                            {id_param_name: 2, "data": "d2"},
                            {id_param_name: 3, "data": "d3"},
                        ],
                    )
                    eq_(result.all(), [(1,), (2,), (3,)])
            else:
                if not expect_success:
                    # for RETURNING execute(), we pass all the way to the DB
                    # and let it fail
                    with expect_raises(exc.DBAPIError):
                        connection.execute(
                            stmt, {id_param_name: 1, "data": "d1"}
                        )
                else:
                    result = connection.execute(
                        stmt, {id_param_name: 1, "data": "d1"}
                    )
                    eq_(result.all(), [(1,)])

        return go

    def test_insert_single(self, connection, run_stmt):
        t = self.tables.t

        stmt = t.insert()

        run_stmt(stmt, False, "id", connection.dialect.insert_returning)

    def test_insert_many(self, connection, run_stmt):
        t = self.tables.t

        stmt = t.insert()

        run_stmt(
            stmt, True, "id", connection.dialect.insert_executemany_returning
        )

    def test_update_single(self, connection, run_stmt):
        t = self.tables.t

        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.update().where(t.c.id == bindparam("b_id"))

        run_stmt(stmt, False, "b_id", connection.dialect.update_returning)

    def test_update_many(self, connection, run_stmt):
        t = self.tables.t

        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.update().where(t.c.id == bindparam("b_id"))

        run_stmt(
            stmt, True, "b_id", connection.dialect.update_executemany_returning
        )

    def test_delete_single(self, connection, run_stmt):
        t = self.tables.t

        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.delete().where(t.c.id == bindparam("b_id"))

        run_stmt(stmt, False, "b_id", connection.dialect.delete_returning)

    def test_delete_many(self, connection, run_stmt):
        t = self.tables.t

        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.delete().where(t.c.id == bindparam("b_id"))

        run_stmt(
            stmt, True, "b_id", connection.dialect.delete_executemany_returning
        )
@ -0,0 +1,630 @@
# testing/suite/test_insert.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from decimal import Decimal
import uuid

from . import testing
from .. import fixtures
from ..assertions import eq_
from ..config import requirements
from ..schema import Column
from ..schema import Table
from ... import Double
from ... import Float
from ... import Identity
from ... import Integer
from ... import literal
from ... import literal_column
from ... import Numeric
from ... import select
from ... import String
from ...types import LargeBinary
from ...types import UUID
from ...types import Uuid


class LastrowidTest(fixtures.TablesTest):
    run_deletes = "each"

    __backend__ = True

    __requires__ = "implements_get_lastrowid", "autoincrement_insert"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "autoinc_pk",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
            implicit_returning=False,
        )

        Table(
            "manual_pk",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            implicit_returning=False,
        )

    def _assert_round_trip(self, table, conn):
        row = conn.execute(table.select()).first()
        eq_(
            row,
            (
                conn.dialect.default_sequence_base,
                "some data",
            ),
        )

    def test_autoincrement_on_insert(self, connection):
        connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        self._assert_round_trip(self.tables.autoinc_pk, connection)

    def test_last_inserted_id(self, connection):
        r = connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
        eq_(r.inserted_primary_key, (pk,))

    @requirements.dbapi_lastrowid
    def test_native_lastrowid_autoinc(self, connection):
        r = connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        lastrowid = r.lastrowid
        pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
        eq_(lastrowid, pk)
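For contrast with the cursor.lastrowid path tested above, a sketch of the two ways the new primary key can surface (illustrative; `engine` and the `autoinc_pk` table are assumed to exist):

with engine.begin() as conn:
    r = conn.execute(autoinc_pk.insert(), {"data": "some data"})

    # portable: filled from lastrowid, RETURNING, or a sequence
    # pre-fetch, depending on what the dialect supports
    print(r.inserted_primary_key)

    # DBAPI-specific: only meaningful where the driver implements it
    print(r.lastrowid)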
|
||||
|
||||
class InsertBehaviorTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"autoinc_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"manual_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"no_implicit_returning",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
)
|
||||
Table(
|
||||
"includes_defaults",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
Column("x", Integer, default=5),
|
||||
Column(
|
||||
"y",
|
||||
Integer,
|
||||
default=literal_column("2", type_=Integer) + literal(2),
|
||||
),
|
||||
)
|
||||
|
||||
@testing.variation("style", ["plain", "return_defaults"])
|
||||
@testing.variation("executemany", [True, False])
|
||||
def test_no_results_for_non_returning_insert(
|
||||
self, connection, style, executemany
|
||||
):
|
||||
"""test another INSERT issue found during #10453"""
|
||||
|
||||
table = self.tables.no_implicit_returning
|
||||
|
||||
stmt = table.insert()
|
||||
if style.return_defaults:
|
||||
stmt = stmt.return_defaults()
|
||||
|
||||
if executemany:
|
||||
data = [
|
||||
{"data": "d1"},
|
||||
{"data": "d2"},
|
||||
{"data": "d3"},
|
||||
{"data": "d4"},
|
||||
{"data": "d5"},
|
||||
]
|
||||
else:
|
||||
data = {"data": "d1"}
|
||||
|
||||
r = connection.execute(stmt, data)
|
||||
assert not r.returns_rows
|
||||

    @requirements.autoincrement_insert
    def test_autoclose_on_insert(self, connection):
        r = connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        assert r._soft_closed
        assert not r.closed
        assert r.is_insert

        # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
        # an insert where the PK was taken from a row that the dialect
        # selected, as is the case for mssql/pyodbc, will still report
        # returns_rows as true because there's a cursor description. in that
        # case, the row had to have been consumed at least.
        assert not r.returns_rows or r.fetchone() is None

    @requirements.insert_returning
    def test_autoclose_on_insert_implicit_returning(self, connection):
        r = connection.execute(
            # return_defaults() ensures RETURNING will be used,
            # new in 2.0 as sqlite/mariadb offer both RETURNING and
            # cursor.lastrowid
            self.tables.autoinc_pk.insert().return_defaults(),
            dict(data="some data"),
        )
        assert r._soft_closed
        assert not r.closed
        assert r.is_insert

        # note we are experimenting with having this be True
        # as of I8091919d45421e3f53029b8660427f844fee0228 .
        # implicit returning has fetched the row, but it still is a
        # "returns rows"
        assert r.returns_rows

        # and we should be able to fetchone() on it, we just get no row
        eq_(r.fetchone(), None)

        # and the keys, etc.
        eq_(r.keys(), ["id"])

        # but the dialect took in the row already. not really sure
        # what the best behavior is.

    @requirements.empty_inserts
    def test_empty_insert(self, connection):
        r = connection.execute(self.tables.autoinc_pk.insert())
        assert r._soft_closed
        assert not r.closed

        r = connection.execute(
            self.tables.autoinc_pk.select().where(
                self.tables.autoinc_pk.c.id != None
            )
        )
        eq_(len(r.all()), 1)

    @requirements.empty_inserts_executemany
    def test_empty_insert_multiple(self, connection):
        r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
        assert r._soft_closed
        assert not r.closed

        r = connection.execute(
            self.tables.autoinc_pk.select().where(
                self.tables.autoinc_pk.c.id != None
            )
        )

        eq_(len(r.all()), 3)

    @requirements.insert_from_select
    def test_insert_from_select_autoinc(self, connection):
        src_table = self.tables.manual_pk
        dest_table = self.tables.autoinc_pk
        connection.execute(
            src_table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ],
        )

        result = connection.execute(
            dest_table.insert().from_select(
                ("data",),
                select(src_table.c.data).where(
                    src_table.c.data.in_(["data2", "data3"])
                ),
            )
        )

        eq_(result.inserted_primary_key, (None,))

        result = connection.execute(
            select(dest_table.c.data).order_by(dest_table.c.data)
        )
        eq_(result.fetchall(), [("data2",), ("data3",)])

    @requirements.insert_from_select
    def test_insert_from_select_autoinc_no_rows(self, connection):
        src_table = self.tables.manual_pk
        dest_table = self.tables.autoinc_pk

        result = connection.execute(
            dest_table.insert().from_select(
                ("data",),
                select(src_table.c.data).where(
                    src_table.c.data.in_(["data2", "data3"])
                ),
            )
        )
        eq_(result.inserted_primary_key, (None,))

        result = connection.execute(
            select(dest_table.c.data).order_by(dest_table.c.data)
        )

        eq_(result.fetchall(), [])

    @requirements.insert_from_select
    def test_insert_from_select(self, connection):
        table = self.tables.manual_pk
        connection.execute(
            table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ],
        )

        connection.execute(
            table.insert()
            .inline()
            .from_select(
                ("id", "data"),
                select(table.c.id + 5, table.c.data).where(
                    table.c.data.in_(["data2", "data3"])
                ),
            )
        )

        eq_(
            connection.execute(
                select(table.c.data).order_by(table.c.data)
            ).fetchall(),
            [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
        )

    @requirements.insert_from_select
    def test_insert_from_select_with_defaults(self, connection):
        table = self.tables.includes_defaults
        connection.execute(
            table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ],
        )

        connection.execute(
            table.insert()
            .inline()
            .from_select(
                ("id", "data"),
                select(table.c.id + 5, table.c.data).where(
                    table.c.data.in_(["data2", "data3"])
                ),
            )
        )

        eq_(
            connection.execute(
                select(table).order_by(table.c.data, table.c.id)
            ).fetchall(),
            [
                (1, "data1", 5, 4),
                (2, "data2", 5, 4),
                (7, "data2", 5, 4),
                (3, "data3", 5, 4),
                (8, "data3", 5, 4),
            ],
        )

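# --- illustrative sketch, not part of the original suite ---
# .inline() on INSERT..FROM SELECT marks the statement so that no
# RETURNING / post-fetch of defaults is attempted; column defaults such
# as x=5 and the y expression above render into the statement itself,
# since there are no per-row parameter sets to apply defaults to.
def _example_insert_from_select_inline(connection, table):
    connection.execute(
        table.insert()
        .inline()
        .from_select(
            ("id", "data"),
            select(table.c.id + 100, table.c.data),
        )
    )

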
class ReturningTest(fixtures.TablesTest):
    run_create_tables = "each"
    __requires__ = "insert_returning", "autoincrement_insert"
    __backend__ = True

    def _assert_round_trip(self, table, conn):
        row = conn.execute(table.select()).first()
        eq_(
            row,
            (
                conn.dialect.default_sequence_base,
                "some data",
            ),
        )

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "autoinc_pk",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

    @requirements.fetch_rows_post_commit
    def test_explicit_returning_pk_autocommit(self, connection):
        table = self.tables.autoinc_pk
        r = connection.execute(
            table.insert().returning(table.c.id), dict(data="some data")
        )
        pk = r.first()[0]
        fetched_pk = connection.scalar(select(table.c.id))
        eq_(fetched_pk, pk)

    def test_explicit_returning_pk_no_autocommit(self, connection):
        table = self.tables.autoinc_pk
        r = connection.execute(
            table.insert().returning(table.c.id), dict(data="some data")
        )

        pk = r.first()[0]
        fetched_pk = connection.scalar(select(table.c.id))
        eq_(fetched_pk, pk)

    def test_autoincrement_on_insert_implicit_returning(self, connection):
        connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        self._assert_round_trip(self.tables.autoinc_pk, connection)

    def test_last_inserted_id_implicit_returning(self, connection):
        r = connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
        eq_(r.inserted_primary_key, (pk,))

    @requirements.insert_executemany_returning
    def test_insertmanyvalues_returning(self, connection):
        r = connection.execute(
            self.tables.autoinc_pk.insert().returning(
                self.tables.autoinc_pk.c.id
            ),
            [
                {"data": "d1"},
                {"data": "d2"},
                {"data": "d3"},
                {"data": "d4"},
                {"data": "d5"},
            ],
        )
        rall = r.all()

        pks = connection.execute(select(self.tables.autoinc_pk.c.id))

        eq_(rall, pks.all())

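    # --- illustrative sketch, not part of the original suite ---
    # with sort_by_parameter_order=True, "insertmanyvalues" guarantees
    # that RETURNING rows come back in the same order as the parameter
    # dictionaries, which plain executemany RETURNING does not promise.
    def _example_sorted_insertmanyvalues(self, connection):
        table = self.tables.autoinc_pk
        result = connection.execute(
            table.insert().returning(
                table.c.id, table.c.data, sort_by_parameter_order=True
            ),
            [{"data": "d%d" % i} for i in range(5)],
        )
        eq_([row.data for row in result], ["d0", "d1", "d2", "d3", "d4"])
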
    @testing.combinations(
        (Double(), 8.5514716, True),
        (
            Double(53),
            8.5514716,
            True,
            testing.requires.float_or_double_precision_behaves_generically,
        ),
        (Float(), 8.5514, True),
        (
            Float(8),
            8.5514,
            True,
            testing.requires.float_or_double_precision_behaves_generically,
        ),
        (
            Numeric(precision=15, scale=12, asdecimal=False),
            8.5514716,
            True,
            testing.requires.literal_float_coercion,
        ),
        (
            Numeric(precision=15, scale=12, asdecimal=True),
            Decimal("8.5514716"),
            False,
        ),
        argnames="type_,value,do_rounding",
    )
    @testing.variation("sort_by_parameter_order", [True, False])
    @testing.variation("multiple_rows", [True, False])
    def test_insert_w_floats(
        self,
        connection,
        metadata,
        sort_by_parameter_order,
        type_,
        value,
        do_rounding,
        multiple_rows,
    ):
        """test #9701.

        this tests insertmanyvalues as well as decimal / floating point
        RETURNING types

        """

        t = Table(
            # Oracle backends seem to get confused if
            # this table is named the same as the one
            # in test_imv_returning_datatypes. use a different name
            "f_t",
            metadata,
            Column("id", Integer, Identity(), primary_key=True),
            Column("value", type_),
        )

        t.create(connection)

        result = connection.execute(
            t.insert().returning(
                t.c.id,
                t.c.value,
                sort_by_parameter_order=bool(sort_by_parameter_order),
            ),
            (
                [{"value": value} for i in range(10)]
                if multiple_rows
                else {"value": value}
            ),
        )

        if multiple_rows:
            i_range = range(1, 11)
        else:
            i_range = range(1, 2)

        # we want to test only that we are getting floating points back
        # with some degree of the original value maintained, that it is not
        # being truncated to an integer. there's too much variation in how
        # drivers return floats, which should not be relied upon to be
        # exact, for us to just compare as is (works for PG drivers but not
        # others) so we use rounding here. There's precedent for this
        # in suite/test_types.py::NumericTest as well

        if do_rounding:
            eq_(
                {(id_, round(val_, 5)) for id_, val_ in result},
                {(id_, round(value, 5)) for id_ in i_range},
            )

            eq_(
                {
                    round(val_, 5)
                    for val_ in connection.scalars(select(t.c.value))
                },
                {round(value, 5)},
            )
        else:
            eq_(
                set(result),
                {(id_, value) for id_ in i_range},
            )

            eq_(
                set(connection.scalars(select(t.c.value))),
                {value},
            )

    @testing.combinations(
        (
            "non_native_uuid",
            Uuid(native_uuid=False),
            uuid.uuid4(),
        ),
        (
            "non_native_uuid_str",
            Uuid(as_uuid=False, native_uuid=False),
            str(uuid.uuid4()),
        ),
        (
            "generic_native_uuid",
            Uuid(native_uuid=True),
            uuid.uuid4(),
            testing.requires.uuid_data_type,
        ),
        (
            "generic_native_uuid_str",
            Uuid(as_uuid=False, native_uuid=True),
            str(uuid.uuid4()),
            testing.requires.uuid_data_type,
        ),
        ("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type),
        (
            "LargeBinary1",
            LargeBinary(),
            b"this is binary",
        ),
        ("LargeBinary2", LargeBinary(), b"7\xe7\x9f"),
        argnames="type_,value",
        id_="iaa",
    )
    @testing.variation("sort_by_parameter_order", [True, False])
    @testing.variation("multiple_rows", [True, False])
    @testing.requires.insert_returning
    def test_imv_returning_datatypes(
        self,
        connection,
        metadata,
        sort_by_parameter_order,
        type_,
        value,
        multiple_rows,
    ):
        """test #9739, #9808 (similar to #9701).

        this tests insertmanyvalues in conjunction with various datatypes.

        These tests are particularly for the asyncpg driver which needs
        most types to be explicitly cast for the new IMV format

        """
        t = Table(
            "d_t",
            metadata,
            Column("id", Integer, Identity(), primary_key=True),
            Column("value", type_),
        )

        t.create(connection)

        result = connection.execute(
            t.insert().returning(
                t.c.id,
                t.c.value,
                sort_by_parameter_order=bool(sort_by_parameter_order),
            ),
            (
                [{"value": value} for i in range(10)]
                if multiple_rows
                else {"value": value}
            ),
        )

        if multiple_rows:
            i_range = range(1, 11)
        else:
            i_range = range(1, 2)

        eq_(
            set(result),
            {(id_, value) for id_ in i_range},
        )

        eq_(
            set(connection.scalars(select(t.c.value))),
            {value},
        )


__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
@@ -0,0 +1,504 @@
# testing/suite/test_results.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import datetime
import re

from .. import engines
from .. import fixtures
from ..assertions import eq_
from ..config import requirements
from ..schema import Column
from ..schema import Table
from ... import DateTime
from ... import func
from ... import Integer
from ... import select
from ... import sql
from ... import String
from ... import testing
from ... import text


class RowFetchTest(fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "plain_pk",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )
        Table(
            "has_dates",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("today", DateTime),
        )

    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.plain_pk.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        connection.execute(
            cls.tables.has_dates.insert(),
            [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
        )

    def test_via_attr(self, connection):
        row = connection.execute(
            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
        ).first()

        eq_(row.id, 1)
        eq_(row.data, "d1")

    def test_via_string(self, connection):
        row = connection.execute(
            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
        ).first()

        eq_(row._mapping["id"], 1)
        eq_(row._mapping["data"], "d1")

    def test_via_int(self, connection):
        row = connection.execute(
            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
        ).first()

        eq_(row[0], 1)
        eq_(row[1], "d1")

    def test_via_col_object(self, connection):
        row = connection.execute(
            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
        ).first()

        eq_(row._mapping[self.tables.plain_pk.c.id], 1)
        eq_(row._mapping[self.tables.plain_pk.c.data], "d1")

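    # --- illustrative sketch, not part of the original suite ---
    # the four tests above exercise the same Row via attribute,
    # mapping-by-string, positional and mapping-by-Column access; all
    # four spellings resolve to the same value.
    def _example_row_access_styles(self, connection):
        t = self.tables.plain_pk
        row = connection.execute(t.select().order_by(t.c.id)).first()
        assert row.id == row[0] == row._mapping["id"] == row._mapping[t.c.id]
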
    @requirements.duplicate_names_in_cursor_description
    def test_row_with_dupe_names(self, connection):
        result = connection.execute(
            select(
                self.tables.plain_pk.c.data,
                self.tables.plain_pk.c.data.label("data"),
            ).order_by(self.tables.plain_pk.c.id)
        )
        row = result.first()
        eq_(result.keys(), ["data", "data"])
        eq_(row, ("d1", "d1"))

    def test_row_w_scalar_select(self, connection):
        """test that a scalar select as a column is returned as such
        and that type conversion works OK.

        (this is half a SQLAlchemy Core test and half to catch database
        backends that may have unusual behavior with scalar selects.)

        """
        datetable = self.tables.has_dates
        s = select(datetable.alias("x").c.today).scalar_subquery()
        s2 = select(datetable.c.id, s.label("somelabel"))
        row = connection.execute(s2).first()

        eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))


class PercentSchemaNamesTest(fixtures.TablesTest):
    """tests using percent signs and spaces in table and column names.

    This didn't work for PostgreSQL / MySQL drivers for a long time
    but is now supported.

    """

    __requires__ = ("percent_schema_names",)

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        cls.tables.percent_table = Table(
            "percent%table",
            metadata,
            Column("percent%", Integer),
            Column("spaces % more spaces", Integer),
        )
        cls.tables.lightweight_percent_table = sql.table(
            "percent%table",
            sql.column("percent%"),
            sql.column("spaces % more spaces"),
        )

    def test_single_roundtrip(self, connection):
        percent_table = self.tables.percent_table
        for params in [
            {"percent%": 5, "spaces % more spaces": 12},
            {"percent%": 7, "spaces % more spaces": 11},
            {"percent%": 9, "spaces % more spaces": 10},
            {"percent%": 11, "spaces % more spaces": 9},
        ]:
            connection.execute(percent_table.insert(), params)
        self._assert_table(connection)

    def test_executemany_roundtrip(self, connection):
        percent_table = self.tables.percent_table
        connection.execute(
            percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
        )
        connection.execute(
            percent_table.insert(),
            [
                {"percent%": 7, "spaces % more spaces": 11},
                {"percent%": 9, "spaces % more spaces": 10},
                {"percent%": 11, "spaces % more spaces": 9},
            ],
        )
        self._assert_table(connection)

    @requirements.insert_executemany_returning
    def test_executemany_returning_roundtrip(self, connection):
        percent_table = self.tables.percent_table
        connection.execute(
            percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
        )
        result = connection.execute(
            percent_table.insert().returning(
                percent_table.c["percent%"],
                percent_table.c["spaces % more spaces"],
            ),
            [
                {"percent%": 7, "spaces % more spaces": 11},
                {"percent%": 9, "spaces % more spaces": 10},
                {"percent%": 11, "spaces % more spaces": 9},
            ],
        )
        eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
        self._assert_table(connection)

    def _assert_table(self, conn):
        percent_table = self.tables.percent_table
        lightweight_percent_table = self.tables.lightweight_percent_table

        for table in (
            percent_table,
            percent_table.alias(),
            lightweight_percent_table,
            lightweight_percent_table.alias(),
        ):
            eq_(
                list(
                    conn.execute(table.select().order_by(table.c["percent%"]))
                ),
                [(5, 12), (7, 11), (9, 10), (11, 9)],
            )

            eq_(
                list(
                    conn.execute(
                        table.select()
                        .where(table.c["spaces % more spaces"].in_([9, 10]))
                        .order_by(table.c["percent%"])
                    )
                ),
                [(9, 10), (11, 9)],
            )

            row = conn.execute(
                table.select().order_by(table.c["percent%"])
            ).first()
            eq_(row._mapping["percent%"], 5)
            eq_(row._mapping["spaces % more spaces"], 12)

            eq_(row._mapping[table.c["percent%"]], 5)
            eq_(row._mapping[table.c["spaces % more spaces"]], 12)

        conn.execute(
            percent_table.update().values(
                {percent_table.c["spaces % more spaces"]: 15}
            )
        )

        eq_(
            list(
                conn.execute(
                    percent_table.select().order_by(
                        percent_table.c["percent%"]
                    )
                )
            ),
            [(5, 15), (7, 15), (9, 15), (11, 15)],
        )


class ServerSideCursorsTest(
    fixtures.TestBase, testing.AssertsExecutionResults
):
    __requires__ = ("server_side_cursors",)

    __backend__ = True

    def _is_server_side(self, cursor):
        # TODO: this is a huge issue as it prevents these tests from being
        # usable by third party dialects.
        if self.engine.dialect.driver == "psycopg2":
            return bool(cursor.name)
        elif self.engine.dialect.driver == "pymysql":
            sscursor = __import__("pymysql.cursors").cursors.SSCursor
            return isinstance(cursor, sscursor)
        elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"):
            return cursor.server_side
        elif self.engine.dialect.driver == "mysqldb":
            sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
            return isinstance(cursor, sscursor)
        elif self.engine.dialect.driver == "mariadbconnector":
            return not cursor.buffered
        elif self.engine.dialect.driver == "mysqlconnector":
            return "buffered" not in type(cursor).__name__.lower()
        elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
            return cursor.server_side
        elif self.engine.dialect.driver == "pg8000":
            return getattr(cursor, "server_side", False)
        elif self.engine.dialect.driver == "psycopg":
            return bool(getattr(cursor, "name", False))
        elif self.engine.dialect.driver == "oracledb":
            return getattr(cursor, "server_side", False)
        else:
            return False

    def _fixture(self, server_side_cursors):
        if server_side_cursors:
            with testing.expect_deprecated(
                "The create_engine.server_side_cursors parameter is "
                "deprecated and will be removed in a future release. "
                "Please use the Connection.execution_options.stream_results "
                "parameter."
            ):
                self.engine = engines.testing_engine(
                    options={"server_side_cursors": server_side_cursors}
                )
        else:
            self.engine = engines.testing_engine(
                options={"server_side_cursors": server_side_cursors}
            )
        return self.engine

    def stringify(self, str_):
        return re.compile(r"SELECT (\d+)", re.I).sub(
            lambda m: str(select(int(m.group(1))).compile(testing.db)), str_
        )

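    # --- illustrative note, not part of the original suite ---
    # stringify() rewrites a raw "SELECT <n>" into the dialect-specific
    # form; "SELECT 1" stays as-is on most backends but would become
    # something like "SELECT 1 FROM DUAL" where a FROM clause is
    # mandatory, keeping the string-based test cases below portable.
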
    @testing.combinations(
        ("global_string", True, lambda stringify: stringify("select 1"), True),
        (
            "global_text",
            True,
            lambda stringify: text(stringify("select 1")),
            True,
        ),
        ("global_expr", True, select(1), True),
        (
            "global_off_explicit",
            False,
            lambda stringify: text(stringify("select 1")),
            False,
        ),
        (
            "stmt_option",
            False,
            select(1).execution_options(stream_results=True),
            True,
        ),
        (
            "stmt_option_disabled",
            True,
            select(1).execution_options(stream_results=False),
            False,
        ),
        ("for_update_expr", True, select(1).with_for_update(), True),
        # TODO: need a real requirement for this, or don't use this test
        (
            "for_update_string",
            True,
            lambda stringify: stringify("SELECT 1 FOR UPDATE"),
            True,
            testing.skip_if(["sqlite", "mssql"]),
        ),
        (
            "text_no_ss",
            False,
            lambda stringify: text(stringify("select 42")),
            False,
        ),
        (
            "text_ss_option",
            False,
            lambda stringify: text(stringify("select 42")).execution_options(
                stream_results=True
            ),
            True,
        ),
        id_="iaaa",
        argnames="engine_ss_arg, statement, cursor_ss_status",
    )
    def test_ss_cursor_status(
        self, engine_ss_arg, statement, cursor_ss_status
    ):
        engine = self._fixture(engine_ss_arg)
        with engine.begin() as conn:
            if callable(statement):
                statement = testing.resolve_lambda(
                    statement, stringify=self.stringify
                )

            if isinstance(statement, str):
                result = conn.exec_driver_sql(statement)
            else:
                result = conn.execute(statement)
            eq_(self._is_server_side(result.cursor), cursor_ss_status)
            result.close()

    def test_conn_option(self):
        engine = self._fixture(False)

        with engine.connect() as conn:
            # should be enabled for this one
            result = conn.execution_options(
                stream_results=True
            ).exec_driver_sql(self.stringify("select 1"))
            assert self._is_server_side(result.cursor)

            # the connection has autobegun, which means at the end of the
            # block, we will roll back, which on MySQL at least will fail
            # with "Commands out of sync" if the result set
            # is not closed, so we close it first.
            #
            # fun fact! why did we not have this result.close() in this test
            # before 2.0? don't we roll back in the connection pool
            # unconditionally? yes! and in fact if you run this test in 1.4
            # with stdout shown, there is in fact an "Exception during reset
            # or similar" warning emitted with "Commands out of sync"! 2.0's
            # architecture finds and fixes what was previously an expensive
            # silent error condition.
            result.close()

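    # --- illustrative sketch, not part of the original suite ---
    # the non-deprecated spelling: ask for a streaming result per
    # statement or per connection instead of an engine-wide
    # server_side_cursors setting, and consume it before the block's
    # implicit rollback.
    def _example_stream_results(self, engine):
        with engine.connect() as conn:
            result = conn.execution_options(stream_results=True).execute(
                select(1)
            )
            rows = result.all()
            result.close()
        return rows
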
    def test_stmt_enabled_conn_option_disabled(self):
        engine = self._fixture(False)

        s = select(1).execution_options(stream_results=True)

        with engine.connect() as conn:
            # not this one
            result = conn.execution_options(stream_results=False).execute(s)
            assert not self._is_server_side(result.cursor)

    def test_aliases_and_ss(self):
        engine = self._fixture(False)
        s1 = (
            select(sql.literal_column("1").label("x"))
            .execution_options(stream_results=True)
            .subquery()
        )

        # options don't propagate out when subquery is used as a FROM clause
        with engine.begin() as conn:
            result = conn.execute(s1.select())
            assert not self._is_server_side(result.cursor)
            result.close()

        s2 = select(1).select_from(s1)
        with engine.begin() as conn:
            result = conn.execute(s2)
            assert not self._is_server_side(result.cursor)
            result.close()

    def test_roundtrip_fetchall(self, metadata):
        md = self.metadata

        engine = self._fixture(True)
        test_table = Table(
            "test_table",
            md,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)
            connection.execute(test_table.insert(), dict(data="data1"))
            connection.execute(test_table.insert(), dict(data="data2"))
            eq_(
                connection.execute(
                    test_table.select().order_by(test_table.c.id)
                ).fetchall(),
                [(1, "data1"), (2, "data2")],
            )
            connection.execute(
                test_table.update()
                .where(test_table.c.id == 2)
                .values(data=test_table.c.data + " updated")
            )
            eq_(
                connection.execute(
                    test_table.select().order_by(test_table.c.id)
                ).fetchall(),
                [(1, "data1"), (2, "data2 updated")],
            )
            connection.execute(test_table.delete())
            eq_(
                connection.scalar(
                    select(func.count("*")).select_from(test_table)
                ),
                0,
            )

    def test_roundtrip_fetchmany(self, metadata):
        md = self.metadata

        engine = self._fixture(True)
        test_table = Table(
            "test_table",
            md,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)
            connection.execute(
                test_table.insert(),
                [dict(data="data%d" % i) for i in range(1, 20)],
            )

            result = connection.execute(
                test_table.select().order_by(test_table.c.id)
            )

            eq_(
                result.fetchmany(5),
                [(i, "data%d" % i) for i in range(1, 6)],
            )
            eq_(
                result.fetchmany(10),
                [(i, "data%d" % i) for i in range(6, 16)],
            )
            eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
@@ -0,0 +1,258 @@
# testing/suite/test_rowcount.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures


class RowCountTest(fixtures.TablesTest):
    """test rowcount functionality"""

    __requires__ = ("sane_rowcount",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "employees",
            metadata,
            Column(
                "employee_id",
                Integer,
                autoincrement=False,
                primary_key=True,
            ),
            Column("name", String(50)),
            Column("department", String(1)),
        )

    @classmethod
    def insert_data(cls, connection):
        cls.data = data = [
            ("Angela", "A"),
            ("Andrew", "A"),
            ("Anand", "A"),
            ("Bob", "B"),
            ("Bobette", "B"),
            ("Buffy", "B"),
            ("Charlie", "C"),
            ("Cynthia", "C"),
            ("Chris", "C"),
        ]

        employees_table = cls.tables.employees
        connection.execute(
            employees_table.insert(),
            [
                {"employee_id": i, "name": n, "department": d}
                for i, (n, d) in enumerate(data)
            ],
        )

    def test_basic(self, connection):
        employees_table = self.tables.employees
        s = select(
            employees_table.c.name, employees_table.c.department
        ).order_by(employees_table.c.employee_id)
        rows = connection.execute(s).fetchall()

        eq_(rows, self.data)

    @testing.variation("statement", ["update", "delete", "insert", "select"])
    @testing.variation("close_first", [True, False])
    def test_non_rowcount_scenarios_no_raise(
        self, connection, statement, close_first
    ):
        employees_table = self.tables.employees

        # WHERE matches 3, 3 rows changed
        department = employees_table.c.department

        if statement.update:
            r = connection.execute(
                employees_table.update().where(department == "C"),
                {"department": "Z"},
            )
        elif statement.delete:
            r = connection.execute(
                employees_table.delete().where(department == "C"),
                {"department": "Z"},
            )
        elif statement.insert:
            r = connection.execute(
                employees_table.insert(),
                [
                    {"employee_id": 25, "name": "none 1", "department": "X"},
                    {"employee_id": 26, "name": "none 2", "department": "Z"},
                    {"employee_id": 27, "name": "none 3", "department": "Z"},
                ],
            )
        elif statement.select:
            s = select(
                employees_table.c.name, employees_table.c.department
            ).where(employees_table.c.department == "C")
            r = connection.execute(s)
            r.all()
        else:
            statement.fail()

        if close_first:
            r.close()

        assert r.rowcount in (-1, 3)

    def test_update_rowcount1(self, connection):
        employees_table = self.tables.employees

        # WHERE matches 3, 3 rows changed
        department = employees_table.c.department
        r = connection.execute(
            employees_table.update().where(department == "C"),
            {"department": "Z"},
        )
        assert r.rowcount == 3

    def test_update_rowcount2(self, connection):
        employees_table = self.tables.employees

        # WHERE matches 3, 0 rows changed
        department = employees_table.c.department

        r = connection.execute(
            employees_table.update().where(department == "C"),
            {"department": "C"},
        )
        eq_(r.rowcount, 3)

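    # --- illustrative note, not part of the original suite ---
    # test_update_rowcount2 shows that "sane_rowcount" means rows
    # *matched*, not rows actually modified: the UPDATE writes the same
    # value, yet rowcount is still 3. On MySQL this relies on the
    # CLIENT.FOUND_ROWS connection flag, which SQLAlchemy's MySQL
    # dialects enable by default.
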
    @testing.variation("implicit_returning", [True, False])
    @testing.variation(
        "dml",
        [
            ("update", testing.requires.update_returning),
            ("delete", testing.requires.delete_returning),
        ],
    )
    def test_update_delete_rowcount_return_defaults(
        self, connection, implicit_returning, dml
    ):
        """note this test should succeed for all RETURNING backends
        as of 2.0. In
        Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
        len(rows) when we have implicit returning

        """

        if implicit_returning:
            employees_table = self.tables.employees
        else:
            employees_table = Table(
                "employees",
                MetaData(),
                Column(
                    "employee_id",
                    Integer,
                    autoincrement=False,
                    primary_key=True,
                ),
                Column("name", String(50)),
                Column("department", String(1)),
                implicit_returning=False,
            )

        department = employees_table.c.department

        if dml.update:
            stmt = (
                employees_table.update()
                .where(department == "C")
                .values(name=employees_table.c.department + "Z")
                .return_defaults()
            )
        elif dml.delete:
            stmt = (
                employees_table.delete()
                .where(department == "C")
                .return_defaults()
            )
        else:
            dml.fail()

        r = connection.execute(stmt)
        eq_(r.rowcount, 3)

    def test_raw_sql_rowcount(self, connection):
        # test issue #3622, make sure eager rowcount is called for text
        result = connection.exec_driver_sql(
            "update employees set department='Z' where department='C'"
        )
        eq_(result.rowcount, 3)

    def test_text_rowcount(self, connection):
        # test issue #3622, make sure eager rowcount is called for text
        result = connection.execute(
            text("update employees set department='Z' where department='C'")
        )
        eq_(result.rowcount, 3)

    def test_delete_rowcount(self, connection):
        employees_table = self.tables.employees

        # WHERE matches 3, 3 rows deleted
        department = employees_table.c.department
        r = connection.execute(
            employees_table.delete().where(department == "C")
        )
        eq_(r.rowcount, 3)

    @testing.requires.sane_multi_rowcount
    def test_multi_update_rowcount(self, connection):
        employees_table = self.tables.employees
        stmt = (
            employees_table.update()
            .where(employees_table.c.name == bindparam("emp_name"))
            .values(department="C")
        )

        r = connection.execute(
            stmt,
            [
                {"emp_name": "Bob"},
                {"emp_name": "Cynthia"},
                {"emp_name": "nonexistent"},
            ],
        )

        eq_(r.rowcount, 2)

    @testing.requires.sane_multi_rowcount
    def test_multi_delete_rowcount(self, connection):
        employees_table = self.tables.employees

        stmt = employees_table.delete().where(
            employees_table.c.name == bindparam("emp_name")
        )

        r = connection.execute(
            stmt,
            [
                {"emp_name": "Bob"},
                {"emp_name": "Cynthia"},
                {"emp_name": "nonexistent"},
            ],
        )

        eq_(r.rowcount, 2)
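
# --- illustrative note, not part of the original suite ---
# the two executemany tests above depend on "sane_multi_rowcount": the
# DBAPI must report an aggregate rowcount across all parameter sets of
# an executemany(), which not every driver does reliably.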
@@ -0,0 +1,317 @@
# testing/suite/test_sequence.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from .. import config
from .. import fixtures
from ..assertions import eq_
from ..assertions import is_true
from ..config import requirements
from ..provision import normalize_sequence
from ..schema import Column
from ..schema import Table
from ... import inspect
from ... import Integer
from ... import MetaData
from ... import Sequence
from ... import String
from ... import testing


class SequenceTest(fixtures.TablesTest):
    __requires__ = ("sequences",)
    __backend__ = True

    run_create_tables = "each"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "seq_pk",
            metadata,
            Column(
                "id",
                Integer,
                normalize_sequence(config, Sequence("tab_id_seq")),
                primary_key=True,
            ),
            Column("data", String(50)),
        )

        Table(
            "seq_opt_pk",
            metadata,
            Column(
                "id",
                Integer,
                normalize_sequence(
                    config,
                    Sequence("tab_id_seq", data_type=Integer, optional=True),
                ),
                primary_key=True,
            ),
            Column("data", String(50)),
        )

        Table(
            "seq_no_returning",
            metadata,
            Column(
                "id",
                Integer,
                normalize_sequence(config, Sequence("noret_id_seq")),
                primary_key=True,
            ),
            Column("data", String(50)),
            implicit_returning=False,
        )

        if testing.requires.schemas.enabled:
            Table(
                "seq_no_returning_sch",
                metadata,
                Column(
                    "id",
                    Integer,
                    normalize_sequence(
                        config,
                        Sequence(
                            "noret_sch_id_seq", schema=config.test_schema
                        ),
                    ),
                    primary_key=True,
                ),
                Column("data", String(50)),
                implicit_returning=False,
                schema=config.test_schema,
            )

    def test_insert_roundtrip(self, connection):
        connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
        self._assert_round_trip(self.tables.seq_pk, connection)

    def test_insert_lastrowid(self, connection):
        r = connection.execute(
            self.tables.seq_pk.insert(), dict(data="some data")
        )
        eq_(
            r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
        )

    def test_nextval_direct(self, connection):
        r = connection.scalar(self.tables.seq_pk.c.id.default)
        eq_(r, testing.db.dialect.default_sequence_base)

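    # --- illustrative note, not part of the original suite ---
    # a Sequence attached as a column default can also be executed
    # directly, as above: connection.scalar(seq) emits the dialect's
    # next-value SQL, e.g. "SELECT nextval('tab_id_seq')" on PostgreSQL.
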
    @requirements.sequences_optional
    def test_optional_seq(self, connection):
        r = connection.execute(
            self.tables.seq_opt_pk.insert(), dict(data="some data")
        )
        eq_(r.inserted_primary_key, (1,))

    def _assert_round_trip(self, table, conn):
        row = conn.execute(table.select()).first()
        eq_(row, (testing.db.dialect.default_sequence_base, "some data"))

    def test_insert_roundtrip_no_implicit_returning(self, connection):
        connection.execute(
            self.tables.seq_no_returning.insert(), dict(data="some data")
        )
        self._assert_round_trip(self.tables.seq_no_returning, connection)

    @testing.combinations((True,), (False,), argnames="implicit_returning")
    @testing.requires.schemas
    def test_insert_roundtrip_translate(self, connection, implicit_returning):
        seq_no_returning = Table(
            "seq_no_returning_sch",
            MetaData(),
            Column(
                "id",
                Integer,
                normalize_sequence(
                    config, Sequence("noret_sch_id_seq", schema="alt_schema")
                ),
                primary_key=True,
            ),
            Column("data", String(50)),
            implicit_returning=implicit_returning,
            schema="alt_schema",
        )

        connection = connection.execution_options(
            schema_translate_map={"alt_schema": config.test_schema}
        )
        connection.execute(seq_no_returning.insert(), dict(data="some data"))
        self._assert_round_trip(seq_no_returning, connection)

    @testing.requires.schemas
    def test_nextval_direct_schema_translate(self, connection):
        seq = normalize_sequence(
            config, Sequence("noret_sch_id_seq", schema="alt_schema")
        )
        connection = connection.execution_options(
            schema_translate_map={"alt_schema": config.test_schema}
        )

        r = connection.scalar(seq)
        eq_(r, testing.db.dialect.default_sequence_base)


class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
    __requires__ = ("sequences",)
    __backend__ = True

    def test_literal_binds_inline_compile(self, connection):
        table = Table(
            "x",
            MetaData(),
            Column(
                "y", Integer, normalize_sequence(config, Sequence("y_seq"))
            ),
            Column("q", Integer),
        )

        stmt = table.insert().values(q=5)

        seq_nextval = connection.dialect.statement_compiler(
            statement=None, dialect=connection.dialect
        ).visit_sequence(normalize_sequence(config, Sequence("y_seq")))
        self.assert_compile(
            stmt,
            "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
            literal_binds=True,
            dialect=connection.dialect,
        )


class HasSequenceTest(fixtures.TablesTest):
    run_deletes = None

    __requires__ = ("sequences",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        normalize_sequence(config, Sequence("user_id_seq", metadata=metadata))
        normalize_sequence(
            config,
            Sequence(
                "other_seq",
                metadata=metadata,
                nomaxvalue=True,
                nominvalue=True,
            ),
        )
        if testing.requires.schemas.enabled:
            normalize_sequence(
                config,
                Sequence(
                    "user_id_seq", schema=config.test_schema, metadata=metadata
                ),
            )
            normalize_sequence(
                config,
                Sequence(
                    "schema_seq", schema=config.test_schema, metadata=metadata
                ),
            )
        Table(
            "user_id_table",
            metadata,
            Column("id", Integer, primary_key=True),
        )

    def test_has_sequence(self, connection):
        eq_(inspect(connection).has_sequence("user_id_seq"), True)

    def test_has_sequence_cache(self, connection, metadata):
        insp = inspect(connection)
        eq_(insp.has_sequence("user_id_seq"), True)
        ss = normalize_sequence(config, Sequence("new_seq", metadata=metadata))
        eq_(insp.has_sequence("new_seq"), False)
        ss.create(connection)
        try:
            eq_(insp.has_sequence("new_seq"), False)
            insp.clear_cache()
            eq_(insp.has_sequence("new_seq"), True)
        finally:
            ss.drop(connection)

    def test_has_sequence_other_object(self, connection):
        eq_(inspect(connection).has_sequence("user_id_table"), False)

    @testing.requires.schemas
    def test_has_sequence_schema(self, connection):
        eq_(
            inspect(connection).has_sequence(
                "user_id_seq", schema=config.test_schema
            ),
            True,
        )

    def test_has_sequence_neg(self, connection):
        eq_(inspect(connection).has_sequence("some_sequence"), False)

    @testing.requires.schemas
    def test_has_sequence_schemas_neg(self, connection):
        eq_(
            inspect(connection).has_sequence(
                "some_sequence", schema=config.test_schema
            ),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_default_not_in_remote(self, connection):
        eq_(
            inspect(connection).has_sequence(
                "other_sequence", schema=config.test_schema
            ),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_remote_not_in_default(self, connection):
        eq_(inspect(connection).has_sequence("schema_seq"), False)

    def test_get_sequence_names(self, connection):
        exp = {"other_seq", "user_id_seq"}

        res = set(inspect(connection).get_sequence_names())
        is_true(res.intersection(exp) == exp)
        is_true("schema_seq" not in res)

    @testing.requires.schemas
    def test_get_sequence_names_no_sequence_schema(self, connection):
        eq_(
            inspect(connection).get_sequence_names(
                schema=config.test_schema_2
            ),
            [],
        )

    @testing.requires.schemas
    def test_get_sequence_names_sequences_schema(self, connection):
        eq_(
            sorted(
                inspect(connection).get_sequence_names(
                    schema=config.test_schema
                )
            ),
            ["schema_seq", "user_id_seq"],
        )


class HasSequenceTestEmpty(fixtures.TestBase):
    __requires__ = ("sequences",)
    __backend__ = True

    def test_get_sequence_names_no_sequence(self, connection):
        eq_(
            inspect(connection).get_sequence_names(),
            [],
        )
@@ -0,0 +1,189 @@
# testing/suite/test_unicode_ddl.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from sqlalchemy import desc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import testing
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table


class UnicodeSchemaTest(fixtures.TablesTest):
    __requires__ = ("unicode_ddl",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        global t1, t2, t3

        t1 = Table(
            "unitable1",
            metadata,
            Column("méil", Integer, primary_key=True),
            Column("\u6e2c\u8a66", Integer),
            test_needs_fk=True,
        )
        t2 = Table(
            "Unitéble2",
            metadata,
            Column("méil", Integer, primary_key=True, key="a"),
            Column(
                "\u6e2c\u8a66",
                Integer,
                ForeignKey("unitable1.méil"),
                key="b",
            ),
            test_needs_fk=True,
        )

        # Few DBs support Unicode foreign keys
        if testing.against("sqlite"):
            t3 = Table(
                "\u6e2c\u8a66",
                metadata,
                Column(
                    "\u6e2c\u8a66_id",
                    Integer,
                    primary_key=True,
                    autoincrement=False,
                ),
                Column(
                    "unitable1_\u6e2c\u8a66",
                    Integer,
                    ForeignKey("unitable1.\u6e2c\u8a66"),
                ),
                Column("Unitéble2_b", Integer, ForeignKey("Unitéble2.b")),
                Column(
                    "\u6e2c\u8a66_self",
                    Integer,
                    ForeignKey("\u6e2c\u8a66.\u6e2c\u8a66_id"),
                ),
                test_needs_fk=True,
            )
        else:
            t3 = Table(
                "\u6e2c\u8a66",
                metadata,
                Column(
                    "\u6e2c\u8a66_id",
                    Integer,
                    primary_key=True,
                    autoincrement=False,
                ),
                Column("unitable1_\u6e2c\u8a66", Integer),
                Column("Unitéble2_b", Integer),
                Column("\u6e2c\u8a66_self", Integer),
                test_needs_fk=True,
            )

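    # --- illustrative note, not part of the original suite ---
    # on backends other than SQLite the third table is created without
    # foreign key constraints, so the INSERT/SELECT assertions below run
    # everywhere while Unicode FK round-trips are only exercised where
    # supported.
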
    def test_insert(self, connection):
        connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
        connection.execute(t2.insert(), {"a": 1, "b": 1})
        connection.execute(
            t3.insert(),
            {
                "\u6e2c\u8a66_id": 1,
                "unitable1_\u6e2c\u8a66": 5,
                "Unitéble2_b": 1,
                "\u6e2c\u8a66_self": 1,
            },
        )

        eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
        eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
        eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])

    def test_col_targeting(self, connection):
        connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
        connection.execute(t2.insert(), {"a": 1, "b": 1})
        connection.execute(
            t3.insert(),
            {
                "\u6e2c\u8a66_id": 1,
                "unitable1_\u6e2c\u8a66": 5,
                "Unitéble2_b": 1,
                "\u6e2c\u8a66_self": 1,
            },
        )

        row = connection.execute(t1.select()).first()
        eq_(row._mapping[t1.c["méil"]], 1)
        eq_(row._mapping[t1.c["\u6e2c\u8a66"]], 5)

        row = connection.execute(t2.select()).first()
        eq_(row._mapping[t2.c["a"]], 1)
        eq_(row._mapping[t2.c["b"]], 1)

        row = connection.execute(t3.select()).first()
        eq_(row._mapping[t3.c["\u6e2c\u8a66_id"]], 1)
        eq_(row._mapping[t3.c["unitable1_\u6e2c\u8a66"]], 5)
        eq_(row._mapping[t3.c["Unitéble2_b"]], 1)
        eq_(row._mapping[t3.c["\u6e2c\u8a66_self"]], 1)

    def test_reflect(self, connection):
        connection.execute(t1.insert(), {"méil": 2, "\u6e2c\u8a66": 7})
        connection.execute(t2.insert(), {"a": 2, "b": 2})
        connection.execute(
            t3.insert(),
            {
                "\u6e2c\u8a66_id": 2,
                "unitable1_\u6e2c\u8a66": 7,
                "Unitéble2_b": 2,
                "\u6e2c\u8a66_self": 2,
            },
        )

        meta = MetaData()
        tt1 = Table(t1.name, meta, autoload_with=connection)
        tt2 = Table(t2.name, meta, autoload_with=connection)
        tt3 = Table(t3.name, meta, autoload_with=connection)

        connection.execute(tt1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
        connection.execute(tt2.insert(), {"méil": 1, "\u6e2c\u8a66": 1})
        connection.execute(
            tt3.insert(),
            {
                "\u6e2c\u8a66_id": 1,
                "unitable1_\u6e2c\u8a66": 5,
                "Unitéble2_b": 1,
                "\u6e2c\u8a66_self": 1,
            },
        )

        eq_(
            connection.execute(tt1.select().order_by(desc("méil"))).fetchall(),
            [(2, 7), (1, 5)],
        )
        eq_(
            connection.execute(tt2.select().order_by(desc("méil"))).fetchall(),
            [(2, 2), (1, 1)],
        )
        eq_(
            connection.execute(
                tt3.select().order_by(desc("\u6e2c\u8a66_id"))
            ).fetchall(),
            [(2, 7, 2, 2), (1, 5, 1, 1)],
        )

    def test_repr(self):
        meta = MetaData()
        t = Table("\u6e2c\u8a66", meta, Column("\u6e2c\u8a66_id", Integer))
        eq_(
            repr(t),
            (
                "Table('測試', MetaData(), "
                "Column('測試_id', Integer(), "
                "table=<測試>), "
                "schema=None)"
            ),
        )
@@ -0,0 +1,139 @@
# testing/suite/test_update_delete.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import Integer
from ... import String
from ... import testing


class SimpleUpdateDeleteTest(fixtures.TablesTest):
    run_deletes = "each"
    __requires__ = ("sane_rowcount",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "plain_pk",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )

    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.plain_pk.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

    def test_update(self, connection):
        t = self.tables.plain_pk
        r = connection.execute(
            t.update().where(t.c.id == 2), dict(data="d2_new")
        )
        assert not r.is_insert
        assert not r.returns_rows
        assert r.rowcount == 1

        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            [(1, "d1"), (2, "d2_new"), (3, "d3")],
        )

    def test_delete(self, connection):
        t = self.tables.plain_pk
        r = connection.execute(t.delete().where(t.c.id == 2))
        assert not r.is_insert
        assert not r.returns_rows
        assert r.rowcount == 1
        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            [(1, "d1"), (3, "d3")],
        )

    @testing.variation("criteria", ["rows", "norows", "emptyin"])
    @testing.requires.update_returning
    def test_update_returning(self, connection, criteria):
        t = self.tables.plain_pk

        stmt = t.update().returning(t.c.id, t.c.data)

        if criteria.norows:
            stmt = stmt.where(t.c.id == 10)
        elif criteria.rows:
            stmt = stmt.where(t.c.id == 2)
        elif criteria.emptyin:
            stmt = stmt.where(t.c.id.in_([]))
        else:
            criteria.fail()

        r = connection.execute(stmt, dict(data="d2_new"))
        assert not r.is_insert
        assert r.returns_rows
        eq_(r.keys(), ["id", "data"])

        if criteria.rows:
            eq_(r.all(), [(2, "d2_new")])
        else:
            eq_(r.all(), [])

        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            (
                [(1, "d1"), (2, "d2_new"), (3, "d3")]
                if criteria.rows
                else [(1, "d1"), (2, "d2"), (3, "d3")]
            ),
        )

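    # --- illustrative note, not part of the original suite ---
    # the "emptyin" variation exercises SQLAlchemy's empty-IN handling:
    # t.c.id.in_([]) renders a predicate guaranteed to match no rows,
    # so RETURNING yields an empty result rather than an error.
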
    @testing.variation("criteria", ["rows", "norows", "emptyin"])
    @testing.requires.delete_returning
    def test_delete_returning(self, connection, criteria):
        t = self.tables.plain_pk

        stmt = t.delete().returning(t.c.id, t.c.data)

        if criteria.norows:
            stmt = stmt.where(t.c.id == 10)
        elif criteria.rows:
            stmt = stmt.where(t.c.id == 2)
        elif criteria.emptyin:
            stmt = stmt.where(t.c.id.in_([]))
        else:
            criteria.fail()

        r = connection.execute(stmt)
        assert not r.is_insert
        assert r.returns_rows
        eq_(r.keys(), ["id", "data"])

        if criteria.rows:
            eq_(r.all(), [(2, "d2")])
        else:
            eq_(r.all(), [])

        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            (
                [(1, "d1"), (3, "d3")]
                if criteria.rows
                else [(1, "d1"), (2, "d2"), (3, "d3")]
            ),
        )


__all__ = ("SimpleUpdateDeleteTest",)
venv/lib/python3.11/site-packages/sqlalchemy/testing/util.py (new file, 538 lines)
@@ -0,0 +1,538 @@
|
||||
# testing/util.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
import contextlib
|
||||
import decimal
|
||||
import gc
|
||||
from itertools import chain
|
||||
import random
|
||||
import sys
|
||||
from sys import getsizeof
|
||||
import time
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
from . import config
|
||||
from . import mock
|
||||
from .. import inspect
|
||||
from ..engine import Connection
|
||||
from ..schema import Column
|
||||
from ..schema import DropConstraint
|
||||
from ..schema import DropTable
|
||||
from ..schema import ForeignKeyConstraint
|
||||
from ..schema import MetaData
|
||||
from ..schema import Table
|
||||
from ..sql import schema
|
||||
from ..sql.sqltypes import Integer
|
||||
from ..util import decorator
|
||||
from ..util import defaultdict
|
||||
from ..util import has_refcount_gc
|
||||
from ..util import inspect_getfullargspec
|
||||
|
||||
|
||||
if not has_refcount_gc:
|
||||
|
||||
def non_refcount_gc_collect(*args):
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
|
||||
gc_collect = lazy_gc = non_refcount_gc_collect
|
||||
else:
|
||||
# assume CPython - straight gc.collect, lazy_gc() is a pass
|
||||
gc_collect = gc.collect
|
||||
|
||||
def lazy_gc():
|
||||
pass
|
||||
|
||||
|
||||
def picklers():
|
||||
picklers = set()
|
||||
import pickle
|
||||
|
||||
picklers.add(pickle)
|
||||
|
||||
# yes, this thing needs this much testing
|
||||
for pickle_ in picklers:
|
||||
for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1):
|
||||
yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
|
||||
|
||||
|
||||
def random_choices(population, k=1):
|
||||
return random.choices(population, k=k)
|
||||
|
||||
|
||||
def round_decimal(value, prec):
|
||||
if isinstance(value, float):
|
||||
return round(value, prec)
|
||||
|
||||
# can also use shift() here but that is 2.6 only
|
||||
return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
|
||||
decimal.ROUND_FLOOR
|
||||
) / pow(10, prec)
|
||||
|
||||
|
||||
class RandomSet(set):
|
||||
def __iter__(self):
|
||||
l = list(set.__iter__(self))
|
||||
random.shuffle(l)
|
||||
return iter(l)
|
||||
|
||||
def pop(self):
|
||||
index = random.randint(0, len(self) - 1)
|
||||
item = list(set.__iter__(self))[index]
|
||||
self.remove(item)
|
||||
return item
|
||||
|
||||
def union(self, other):
|
||||
return RandomSet(set.union(self, other))
|
||||
|
||||
def difference(self, other):
|
||||
return RandomSet(set.difference(self, other))
|
||||
|
||||
def intersection(self, other):
|
||||
return RandomSet(set.intersection(self, other))
|
||||
|
||||
def copy(self):
|
||||
return RandomSet(self)
|
||||
|
||||
|
||||
def conforms_partial_ordering(tuples, sorted_elements):
    """True if the given sorting conforms to the given partial ordering."""

    deps = defaultdict(set)
    for parent, child in tuples:
        deps[parent].add(child)
    for i, node in enumerate(sorted_elements):
        for n in sorted_elements[i:]:
            if node in deps[n]:
                return False
    else:
        return True


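# Illustrative check for conforms_partial_ordering() above (not in the
# original source); a (parent, child) tuple means parent must sort before
# child.
#
#     conforms_partial_ordering([("a", "b"), ("a", "c")], ["a", "c", "b"])
#     # -> True
#     conforms_partial_ordering([("a", "b")], ["b", "a"])
#     # -> False
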
def all_partial_orderings(tuples, elements):
    edges = defaultdict(set)
    for parent, child in tuples:
        edges[child].add(parent)

    def _all_orderings(elements):
        if len(elements) == 1:
            yield list(elements)
        else:
            for elem in elements:
                subset = set(elements).difference([elem])
                if not subset.intersection(edges[elem]):
                    for sub_ordering in _all_orderings(subset):
                        yield [elem] + sub_ordering

    return iter(_all_orderings(elements))


def function_named(fn, name):
    """Return a function with a given __name__.

    Will assign to __name__ and return the original function if possible on
    the Python implementation, otherwise a new function will be constructed.

    This function should be phased out as much as possible
    in favor of @decorator.  Tests that "generate" many named tests
    should be modernized.

    """
    try:
        fn.__name__ = name
    except TypeError:
        fn = types.FunctionType(
            fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
        )
    return fn


def run_as_contextmanager(ctx, fn, *arg, **kw):
    """Run the given function under the given contextmanager,
    simulating the behavior of 'with' to support older
    Python versions.

    This is not necessary anymore as we have placed 2.6
    as minimum Python version, however some tests are still using
    this structure.

    """

    obj = ctx.__enter__()
    try:
        result = fn(obj, *arg, **kw)
        ctx.__exit__(None, None, None)
        return result
    except:
        exc_info = sys.exc_info()
        raise_ = ctx.__exit__(*exc_info)
        if not raise_:
            raise
        else:
            return raise_


def rowset(results):
    """Converts the results of sql execution into a plain set of column tuples.

    Useful for asserting the results of an unordered query.
    """

    return {tuple(row) for row in results}


def fail(msg):
    assert False, msg


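# Hypothetical usage of rowset() above (not part of the original module);
# ``connection`` and ``t`` stand in for a test's connection and table
# fixtures.
#
#     result = connection.execute(t.select())
#     assert rowset(result) == {(1, "d1"), (2, "d2")}
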
@decorator
def provide_metadata(fn, *args, **kw):
    """Provide bound MetaData for a single test, dropping afterwards.

    Legacy; use the "metadata" pytest fixture.

    """

    from . import fixtures

    metadata = schema.MetaData()
    self = args[0]
    prev_meta = getattr(self, "metadata", None)
    self.metadata = metadata
    try:
        return fn(*args, **kw)
    finally:
        # close out some things that get in the way of dropping tables.
        # when using the "metadata" fixture, there is a set ordering
        # of things that makes sure things are cleaned up in order, however
        # the simple "decorator" nature of this legacy function means
        # we have to hardcode some of that cleanup ahead of time.

        # close ORM sessions
        fixtures.close_all_sessions()

        # integrate with the "connection" fixture as there are many
        # tests where it is used along with provide_metadata
        cfc = fixtures.base._connection_fixture_connection
        if cfc:
            # TODO: this warning can be used to find all the places
            # this is used with connection fixture
            # warn("mixing legacy provide metadata with connection fixture")
            drop_all_tables_from_metadata(metadata, cfc)
            # as the provide_metadata fixture is often used with "testing.db",
            # when we do the drop we have to commit the transaction so that
            # the DB is actually updated as the CREATE would have been
            # committed
            cfc.get_transaction().commit()
        else:
            drop_all_tables_from_metadata(metadata, config.db)
        self.metadata = prev_meta


def flag_combinations(*combinations):
    """A facade around @testing.combinations() oriented towards boolean
    keyword-based arguments.

    Basically generates a nice looking identifier based on the keywords
    and also sets up the argument names.

    E.g.::

        @testing.flag_combinations(
            dict(lazy=False, passive=False),
            dict(lazy=True, passive=False),
            dict(lazy=False, passive=True),
            dict(lazy=False, passive=True, raiseload=True),
        )
        def test_fn(lazy, passive, raiseload): ...

    would result in::

        @testing.combinations(
            ("", False, False, False),
            ("lazy", True, False, False),
            ("passive", False, True, False),
            ("passive_raiseload", False, True, True),
            id_="iaaa",
            argnames="lazy,passive,raiseload",
        )
        def test_fn(lazy, passive, raiseload): ...

    """

    keys = set()

    for d in combinations:
        keys.update(d)

    keys = sorted(keys)

    return config.combinations(
        *[
            ("_".join(k for k in keys if d.get(k, False)),)
            + tuple(d.get(k, False) for k in keys)
            for d in combinations
        ],
        id_="i" + ("a" * len(keys)),
        argnames=",".join(keys),
    )


def lambda_combinations(lambda_arg_sets, **kw):
    args = inspect_getfullargspec(lambda_arg_sets)

    arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]])

    def create_fixture(pos):
        def fixture(**kw):
            return lambda_arg_sets(**kw)[pos]

        fixture.__name__ = "fixture_%3.3d" % pos
        return fixture

    return config.combinations(
        *[(create_fixture(i),) for i in range(len(arg_sets))], **kw
    )


def resolve_lambda(__fn, **kw):
    """Given a no-arg lambda and a namespace, return a new lambda that
    has all the values filled in.

    This is used so that we can have module-level fixtures that
    refer to instance-level variables using lambdas.

    """

    pos_args = inspect_getfullargspec(__fn)[0]
    pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
    glb = dict(__fn.__globals__)
    glb.update(kw)
    new_fn = types.FunctionType(__fn.__code__, glb)
    return new_fn(**pass_pos_args)


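# Hypothetical illustration of resolve_lambda() above (not in the original
# source): free names in the lambda are satisfied from the keyword namespace
# by rebuilding the function with patched globals; ``some_table`` and
# ``my_table`` are stand-ins.
#
#     expr = lambda: some_table.c.id == 5
#     resolve_lambda(expr, some_table=my_table)  # -> my_table.c.id == 5
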
def metadata_fixture(ddl="function"):
    """Provide MetaData for a pytest fixture."""

    def decorate(fn):
        def run_ddl(self):
            metadata = self.metadata = schema.MetaData()
            try:
                result = fn(self, metadata)
                metadata.create_all(config.db)
                # TODO:
                # somehow get a per-function dml erase fixture here
                yield result
            finally:
                metadata.drop_all(config.db)

        return config.fixture(scope=ddl)(run_ddl)

    return decorate


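# Hypothetical usage of metadata_fixture() above (not part of the original
# module): the decorated function receives a fresh MetaData, and
# create_all / drop_all run around the requested fixture scope.
#
#     @metadata_fixture(ddl="class")
#     def tables(self, metadata):
#         return Table("t", metadata, Column("id", Integer, primary_key=True))
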
def force_drop_names(*names):
    """Force the given table names to be dropped after test complete,
    isolating for foreign key cycles

    """

    @decorator
    def go(fn, *args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            drop_all_tables(config.db, inspect(config.db), include_names=names)

    return go


class adict(dict):
    """Dict keys available as attributes.  Shadows."""

    def __getattribute__(self, key):
        try:
            return self[key]
        except KeyError:
            return dict.__getattribute__(self, key)

    def __call__(self, *keys):
        return tuple([self[key] for key in keys])

    get_all = __call__


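# Brief usage sketch for adict above (illustrative only): keys shadow
# attribute access, and calling the dict pulls several keys at once.
#
#     d = adict(user_id=5, name="x")
#     d.user_id             # -> 5
#     d("user_id", "name")  # -> (5, "x")
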
def drop_all_tables_from_metadata(metadata, engine_or_connection):
    from . import engines

    def go(connection):
        engines.testing_reaper.prepare_for_drop_tables(connection)

        if not connection.dialect.supports_alter:
            from . import assertions

            with assertions.expect_warnings(
                "Can't sort tables", assert_=False
            ):
                metadata.drop_all(connection)
        else:
            metadata.drop_all(connection)

    if not isinstance(engine_or_connection, Connection):
        with engine_or_connection.begin() as connection:
            go(connection)
    else:
        go(engine_or_connection)


def drop_all_tables(
    engine,
    inspector,
    schema=None,
    consider_schemas=(None,),
    include_names=None,
):
    if include_names is not None:
        include_names = set(include_names)

    if schema is not None:
        assert consider_schemas == (
            None,
        ), "consider_schemas and schema are mutually exclusive"
        consider_schemas = (schema,)

    with engine.begin() as conn:
        for table_key, fkcs in reversed(
            inspector.sort_tables_on_foreign_key_dependency(
                consider_schemas=consider_schemas
            )
        ):
            if table_key:
                if (
                    include_names is not None
                    and table_key[1] not in include_names
                ):
                    continue
                conn.execute(
                    DropTable(
                        Table(table_key[1], MetaData(), schema=table_key[0])
                    )
                )
            elif fkcs:
                if not engine.dialect.supports_alter:
                    continue
                for t_key, fkc in fkcs:
                    if (
                        include_names is not None
                        and t_key[1] not in include_names
                    ):
                        continue
                    tb = Table(
                        t_key[1],
                        MetaData(),
                        Column("x", Integer),
                        Column("y", Integer),
                        schema=t_key[0],
                    )
                    conn.execute(
                        DropConstraint(
                            ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
                        )
                    )


def teardown_events(event_cls):
    @decorator
    def decorate(fn, *arg, **kw):
        try:
            return fn(*arg, **kw)
        finally:
            event_cls._clear()

    return decorate


def total_size(o):
    """Returns the approximate memory footprint of an object and all of its
    contents.

    source: https://code.activestate.com/recipes/577504/


    """

    def dict_handler(d):
        return chain.from_iterable(d.items())

    all_handlers = {
        tuple: iter,
        list: iter,
        deque: iter,
        dict: dict_handler,
        set: iter,
        frozenset: iter,
    }
    seen = set()  # track which object id's have already been seen
    default_size = getsizeof(0)  # estimate sizeof object without __sizeof__

    def sizeof(o):
        if id(o) in seen:  # do not double count the same object
            return 0
        seen.add(id(o))
        s = getsizeof(o, default_size)

        for typ, handler in all_handlers.items():
            if isinstance(o, typ):
                s += sum(map(sizeof, handler(o)))
                break
        return s

    return sizeof(o)


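# Worked example for total_size() above (illustrative): container contents
# are sized recursively, and the ``seen`` id set ensures shared objects are
# counted only once.
#
#     x = [1, 2]
#     total_size([x, x])  # sizes the outer list, then x, 1 and 2, each once
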
def count_cache_key_tuples(tup):
    """given a cache key tuple, counts how many instances of actual
    tuples are found.

    used to alert large jumps in cache key complexity.

    """
    stack = [tup]

    sentinel = object()
    num_elements = 0

    while stack:
        elem = stack.pop(0)
        if elem is sentinel:
            num_elements += 1
        elif isinstance(elem, tuple):
            if elem:
                stack = list(elem) + [sentinel] + stack
    return num_elements


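# Worked example for count_cache_key_tuples() above (illustrative): each
# non-empty tuple pushes its elements plus a sentinel, so every nested tuple
# is counted exactly once.
#
#     count_cache_key_tuples((1, (2, 3), ((4,), 5)))
#     # -> 4: the outer tuple, (2, 3), ((4,), 5), and (4,)
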
@contextlib.contextmanager
def skip_if_timeout(seconds: float, cleanup: Any = None):

    now = time.time()
    yield
    sec = time.time() - now
    if sec > seconds:
        try:
            cleanup()
        finally:
            config.skip_test(
                f"test took too long ({sec:.4f} seconds > {seconds})"
            )
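# Hypothetical usage of skip_if_timeout() above (not part of the original
# module); ``run_slow_query`` is a stand-in.  ``cleanup`` is called
# unconditionally on overrun, so pass a callable rather than relying on the
# ``None`` default.
#
#     with skip_if_timeout(5.0, cleanup=connection.invalidate):
#         run_slow_query(connection)
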
venv/lib/python3.11/site-packages/sqlalchemy/testing/warnings.py
@ -0,0 +1,52 @@
# testing/warnings.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import warnings

from . import assertions
from .. import exc
from .. import exc as sa_exc
from ..exc import SATestSuiteWarning
from ..util.langhelpers import _warnings_warn


def warn_test_suite(message):
    _warnings_warn(message, category=SATestSuiteWarning)


def setup_filters():
    """hook for setting up warnings filters.

    SQLAlchemy-specific classes must only be here and not in pytest config,
    as we need to delay importing SQLAlchemy until conftest.py has been
    processed.

    NOTE: filters on subclasses of DeprecationWarning or
    PendingDeprecationWarning have no effect if added here, since pytest
    will add at each test the following filters
    ``always::PendingDeprecationWarning`` and ``always::DeprecationWarning``
    that will take precedence over any added here.

    """
    warnings.filterwarnings("error", category=exc.SAWarning)
    warnings.filterwarnings("always", category=exc.SATestSuiteWarning)


def assert_warnings(fn, warning_msgs, regex=False):
    """Assert that each of the given warnings is emitted by fn.

    Deprecated.  Please use assertions.expect_warnings().

    """

    with assertions._expect_warnings(
        sa_exc.SAWarning, warning_msgs, regex=regex
    ):
        return fn()
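# Hypothetical usage of assert_warnings() above (not part of the original
# module); ``do_deprecated_thing`` stands in for any callable expected to
# emit an SAWarning.
#
#     assert_warnings(
#         lambda: do_deprecated_thing(),
#         ["Expected warning message"],
#     )
#
# New tests should prefer the assertions.expect_warnings() context manager.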