Update 2025-04-13_16:25:39

This commit is contained in:
root
2025-04-13 16:25:41 +02:00
commit 4c711360d3
2979 changed files with 666585 additions and 0 deletions

View File

@ -0,0 +1,96 @@
# testing/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from unittest import mock
from . import config
from .assertions import assert_raises
from .assertions import assert_raises_context_ok
from .assertions import assert_raises_message
from .assertions import assert_raises_message_context_ok
from .assertions import assert_warns
from .assertions import assert_warns_message
from .assertions import AssertsCompiledSQL
from .assertions import AssertsExecutionResults
from .assertions import ComparesIndexes
from .assertions import ComparesTables
from .assertions import emits_warning
from .assertions import emits_warning_on
from .assertions import eq_
from .assertions import eq_ignore_whitespace
from .assertions import eq_regex
from .assertions import expect_deprecated
from .assertions import expect_deprecated_20
from .assertions import expect_raises
from .assertions import expect_raises_message
from .assertions import expect_warnings
from .assertions import in_
from .assertions import int_within_variance
from .assertions import is_
from .assertions import is_false
from .assertions import is_instance_of
from .assertions import is_none
from .assertions import is_not
from .assertions import is_not_
from .assertions import is_not_none
from .assertions import is_true
from .assertions import le_
from .assertions import ne_
from .assertions import not_in
from .assertions import not_in_
from .assertions import startswith_
from .assertions import uses_deprecated
from .config import add_to_marker
from .config import async_test
from .config import combinations
from .config import combinations_list
from .config import db
from .config import fixture
from .config import requirements as requires
from .config import skip_test
from .config import Variation
from .config import variation
from .config import variation_fixture
from .exclusions import _is_excluded
from .exclusions import _server_version
from .exclusions import against as _against
from .exclusions import db_spec
from .exclusions import exclude
from .exclusions import fails
from .exclusions import fails_if
from .exclusions import fails_on
from .exclusions import fails_on_everything_except
from .exclusions import future
from .exclusions import only_if
from .exclusions import only_on
from .exclusions import skip
from .exclusions import skip_if
from .schema import eq_clause_element
from .schema import eq_type_affinity
from .util import adict
from .util import fail
from .util import flag_combinations
from .util import force_drop_names
from .util import lambda_combinations
from .util import metadata_fixture
from .util import provide_metadata
from .util import resolve_lambda
from .util import rowset
from .util import run_as_contextmanager
from .util import skip_if_timeout
from .util import teardown_events
from .warnings import assert_warnings
from .warnings import warn_test_suite
def against(*queries):
return _against(config._current, *queries)
crashes = skip

View File

@ -0,0 +1,989 @@
# testing/assertions.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from collections import defaultdict
import contextlib
from copy import copy
from itertools import filterfalse
import re
import sys
import warnings
from . import assertsql
from . import config
from . import engines
from . import mock
from .exclusions import db_spec
from .util import fail
from .. import exc as sa_exc
from .. import schema
from .. import sql
from .. import types as sqltypes
from .. import util
from ..engine import default
from ..engine import url
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import decorator
def expect_warnings(*messages, **kw):
"""Context manager which expects one or more warnings.
With no arguments, squelches all SAWarning emitted via
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
pass string expressions that will match selected warnings via regex;
all non-matching warnings are sent through.
The expect version **asserts** that the warnings were in fact seen.
Note that the test suite sets SAWarning warnings to raise exceptions.
""" # noqa
return _expect_warnings_sqla_only(sa_exc.SAWarning, messages, **kw)
@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
"""Context manager which expects one or more warnings on specific
dialects.
The expect version **asserts** that the warnings were in fact seen.
"""
spec = db_spec(db)
if isinstance(db, str) and not spec(config._current):
yield
else:
with expect_warnings(*messages, **kw):
yield
def emits_warning(*messages):
"""Decorator form of expect_warnings().
Note that emits_warning does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with expect_warnings(assert_=False, *messages):
return fn(*args, **kw)
return decorate
def expect_deprecated(*messages, **kw):
return _expect_warnings_sqla_only(
sa_exc.SADeprecationWarning, messages, **kw
)
def expect_deprecated_20(*messages, **kw):
return _expect_warnings_sqla_only(
sa_exc.Base20DeprecationWarning, messages, **kw
)
def emits_warning_on(db, *messages):
"""Mark a test as emitting a warning on a specific dialect.
With no arguments, squelches all SAWarning failures. Or pass one or more
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
Note that emits_warning_on does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, assert_=False, *messages):
return fn(*args, **kw)
return decorate
def uses_deprecated(*messages):
"""Mark a test as immune from fatal deprecation warnings.
With no arguments, squelches all SADeprecationWarning failures.
Or pass one or more strings; these will be matched to the root
of the warning description by warnings.filterwarnings().
As a special case, you may pass a function name prefixed with //
and it will be re-written as needed to match the standard warning
verbiage emitted by the sqlalchemy.util.deprecated decorator.
Note that uses_deprecated does **not** assert that the warnings
were in fact seen.
"""
@decorator
def decorate(fn, *args, **kw):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
return decorate
_FILTERS = None
_SEEN = None
_EXC_CLS = None
def _expect_warnings_sqla_only(
exc_cls,
messages,
regex=True,
search_msg=False,
assert_=True,
):
"""SQLAlchemy internal use only _expect_warnings().
Alembic is using _expect_warnings() directly, and should be updated
to use this new interface.
"""
return _expect_warnings(
exc_cls,
messages,
regex=regex,
search_msg=search_msg,
assert_=assert_,
raise_on_any_unexpected=True,
)
@contextlib.contextmanager
def _expect_warnings(
exc_cls,
messages,
regex=True,
search_msg=False,
assert_=True,
raise_on_any_unexpected=False,
squelch_other_warnings=False,
):
global _FILTERS, _SEEN, _EXC_CLS
if regex or search_msg:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
else:
filters = list(messages)
if _FILTERS is not None:
# nested call; update _FILTERS and _SEEN, return. outer
# block will assert our messages
assert _SEEN is not None
assert _EXC_CLS is not None
_FILTERS.extend(filters)
_SEEN.update(filters)
_EXC_CLS += (exc_cls,)
yield
else:
seen = _SEEN = set(filters)
_FILTERS = filters
_EXC_CLS = (exc_cls,)
if raise_on_any_unexpected:
def real_warn(msg, *arg, **kw):
raise AssertionError("Got unexpected warning: %r" % msg)
else:
real_warn = warnings.warn
def our_warn(msg, *arg, **kw):
if isinstance(msg, _EXC_CLS):
exception = type(msg)
msg = str(msg)
elif arg:
exception = arg[0]
else:
exception = None
if not exception or not issubclass(exception, _EXC_CLS):
if not squelch_other_warnings:
return real_warn(msg, *arg, **kw)
else:
return
if not filters and not raise_on_any_unexpected:
return
for filter_ in filters:
if (
(search_msg and filter_.search(msg))
or (regex and filter_.match(msg))
or (not regex and filter_ == msg)
):
seen.discard(filter_)
break
else:
if not squelch_other_warnings:
real_warn(msg, *arg, **kw)
with mock.patch("warnings.warn", our_warn):
try:
yield
finally:
_SEEN = _FILTERS = _EXC_CLS = None
if assert_:
assert not seen, "Warnings were not seen: %s" % ", ".join(
"%r" % (s.pattern if regex else s) for s in seen
)
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.
Hardcoded at the moment, a modular system can be built here
to support things like PG prepared transactions, tables all
dropped, etc.
"""
_assert_no_stray_pool_connections()
def _assert_no_stray_pool_connections():
engines.testing_reaper.assert_all_closed()
def int_within_variance(expected, received, variance):
deviance = int(expected * variance)
assert (
abs(received - expected) < deviance
), "Given int value %s is not within %d%% of expected value %s" % (
received,
variance * 100,
expected,
)
def eq_regex(a, b, msg=None, flags=0):
assert re.match(b, a, flags), msg or "%r !~ %r" % (a, b)
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
def ne_(a, b, msg=None):
"""Assert a != b, with repr messaging on failure."""
assert a != b, msg or "%r == %r" % (a, b)
def le_(a, b, msg=None):
"""Assert a <= b, with repr messaging on failure."""
assert a <= b, msg or "%r != %r" % (a, b)
def is_instance_of(a, b, msg=None):
assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
def is_none(a, msg=None):
is_(a, None, msg=msg)
def is_not_none(a, msg=None):
is_not(a, None, msg=msg)
def is_true(a, msg=None):
is_(bool(a), True, msg=msg)
def is_false(a, msg=None):
is_(bool(a), False, msg=msg)
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
def is_not(a, b, msg=None):
"""Assert a is not b, with repr messaging on failure."""
assert a is not b, msg or "%r is %r" % (a, b)
# deprecated. See #5429
is_not_ = is_not
def in_(a, b, msg=None):
"""Assert a in b, with repr messaging on failure."""
assert a in b, msg or "%r not in %r" % (a, b)
def not_in(a, b, msg=None):
"""Assert a in not b, with repr messaging on failure."""
assert a not in b, msg or "%r is in %r" % (a, b)
# deprecated. See #5429
not_in_ = not_in
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
a,
fragment,
)
def eq_ignore_whitespace(a, b, msg=None):
a = re.sub(r"^\s+?|\n", "", a)
a = re.sub(r" {2,}", " ", a)
a = re.sub(r"\t", "", a)
b = re.sub(r"^\s+?|\n", "", b)
b = re.sub(r" {2,}", " ", b)
b = re.sub(r"\t", "", b)
assert a == b, msg or "%r != %r" % (a, b)
def _assert_proper_exception_context(exception):
"""assert that any exception we're catching does not have a __context__
without a __cause__, and that __suppress_context__ is never set.
Python 3 will report nested as exceptions as "during the handling of
error X, error Y occurred". That's not what we want to do. we want
these exceptions in a cause chain.
"""
if (
exception.__context__ is not exception.__cause__
and not exception.__suppress_context__
):
assert False, (
"Exception %r was correctly raised but did not set a cause, "
"within context %r as its cause."
% (exception, exception.__context__)
)
def assert_raises(except_cls, callable_, *args, **kw):
return _assert_raises(except_cls, callable_, args, kw, check_context=True)
def assert_raises_context_ok(except_cls, callable_, *args, **kw):
return _assert_raises(except_cls, callable_, args, kw)
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
return _assert_raises(
except_cls, callable_, args, kwargs, msg=msg, check_context=True
)
def assert_warns(except_cls, callable_, *args, **kwargs):
"""legacy adapter function for functions that were previously using
assert_raises with SAWarning or similar.
has some workarounds to accommodate the fact that the callable completes
with this approach rather than stopping at the exception raise.
"""
with _expect_warnings_sqla_only(except_cls, [".*"]):
return callable_(*args, **kwargs)
def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
"""legacy adapter function for functions that were previously using
assert_raises with SAWarning or similar.
has some workarounds to accommodate the fact that the callable completes
with this approach rather than stopping at the exception raise.
Also uses regex.search() to match the given message to the error string
rather than regex.match().
"""
with _expect_warnings_sqla_only(
except_cls,
[msg],
search_msg=True,
regex=False,
):
return callable_(*args, **kwargs)
def assert_raises_message_context_ok(
except_cls, msg, callable_, *args, **kwargs
):
return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
def _assert_raises(
except_cls, callable_, args, kwargs, msg=None, check_context=False
):
with _expect_raises(except_cls, msg, check_context) as ec:
callable_(*args, **kwargs)
return ec.error
class _ErrorContainer:
error = None
@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
if (
isinstance(except_cls, type)
and issubclass(except_cls, Warning)
or isinstance(except_cls, Warning)
):
raise TypeError(
"Use expect_warnings for warnings, not "
"expect_raises / assert_raises"
)
ec = _ErrorContainer()
if check_context:
are_we_already_in_a_traceback = sys.exc_info()[0]
try:
yield ec
success = False
except except_cls as err:
ec.error = err
success = True
if msg is not None:
# I'm often pdbing here, and "err" above isn't
# in scope, so assign the string explicitly
error_as_string = str(err)
assert re.search(msg, error_as_string, re.UNICODE), "%r !~ %s" % (
msg,
error_as_string,
)
if check_context and not are_we_already_in_a_traceback:
_assert_proper_exception_context(err)
print(str(err).encode("utf-8"))
# it's generally a good idea to not carry traceback objects outside
# of the except: block, but in this case especially we seem to have
# hit some bug in either python 3.10.0b2 or greenlet or both which
# this seems to fix:
# https://github.com/python-greenlet/greenlet/issues/242
del ec
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
def expect_raises(except_cls, check_context=True):
return _expect_raises(except_cls, check_context=check_context)
def expect_raises_message(except_cls, msg, check_context=True):
return _expect_raises(except_cls, msg=msg, check_context=check_context)
class AssertsCompiledSQL:
def assert_compile(
self,
clause,
result,
params=None,
checkparams=None,
for_executemany=False,
check_literal_execute=None,
check_post_param=None,
dialect=None,
checkpositional=None,
check_prefetch=None,
use_default_dialect=False,
allow_dialect_select=False,
supports_default_values=True,
supports_default_metavalue=True,
literal_binds=False,
render_postcompile=False,
schema_translate_map=None,
render_schema_translate=False,
default_schema_name=None,
from_linting=False,
check_param_order=True,
use_literal_execute_for_simple_int=False,
):
if use_default_dialect:
dialect = default.DefaultDialect()
dialect.supports_default_values = supports_default_values
dialect.supports_default_metavalue = supports_default_metavalue
elif allow_dialect_select:
dialect = None
else:
if dialect is None:
dialect = getattr(self, "__dialect__", None)
if dialect is None:
dialect = config.db.dialect
elif dialect == "default" or dialect == "default_qmark":
if dialect == "default":
dialect = default.DefaultDialect()
else:
dialect = default.DefaultDialect("qmark")
dialect.supports_default_values = supports_default_values
dialect.supports_default_metavalue = supports_default_metavalue
elif dialect == "default_enhanced":
dialect = default.StrCompileDialect()
elif isinstance(dialect, str):
dialect = url.URL.create(dialect).get_dialect()()
if default_schema_name:
dialect.default_schema_name = default_schema_name
kw = {}
compile_kwargs = {}
if schema_translate_map:
kw["schema_translate_map"] = schema_translate_map
if params is not None:
kw["column_keys"] = list(params)
if literal_binds:
compile_kwargs["literal_binds"] = True
if render_postcompile:
compile_kwargs["render_postcompile"] = True
if use_literal_execute_for_simple_int:
compile_kwargs["use_literal_execute_for_simple_int"] = True
if for_executemany:
kw["for_executemany"] = True
if render_schema_translate:
kw["render_schema_translate"] = True
if from_linting or getattr(self, "assert_from_linting", False):
kw["linting"] = sql.FROM_LINTING
from sqlalchemy import orm
if isinstance(clause, orm.Query):
stmt = clause._statement_20()
stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
clause = stmt
if compile_kwargs:
kw["compile_kwargs"] = compile_kwargs
class DontAccess:
def __getattribute__(self, key):
raise NotImplementedError(
"compiler accessed .statement; use "
"compiler.current_executable"
)
class CheckCompilerAccess:
def __init__(self, test_statement):
self.test_statement = test_statement
self._annotations = {}
self.supports_execution = getattr(
test_statement, "supports_execution", False
)
if self.supports_execution:
self._execution_options = test_statement._execution_options
if hasattr(test_statement, "_returning"):
self._returning = test_statement._returning
if hasattr(test_statement, "_inline"):
self._inline = test_statement._inline
if hasattr(test_statement, "_return_defaults"):
self._return_defaults = test_statement._return_defaults
@property
def _variant_mapping(self):
return self.test_statement._variant_mapping
def _default_dialect(self):
return self.test_statement._default_dialect()
def compile(self, dialect, **kw):
return self.test_statement.compile.__func__(
self, dialect=dialect, **kw
)
def _compiler(self, dialect, **kw):
return self.test_statement._compiler.__func__(
self, dialect, **kw
)
def _compiler_dispatch(self, compiler, **kwargs):
if hasattr(compiler, "statement"):
with mock.patch.object(
compiler, "statement", DontAccess()
):
return self.test_statement._compiler_dispatch(
compiler, **kwargs
)
else:
return self.test_statement._compiler_dispatch(
compiler, **kwargs
)
# no construct can assume it's the "top level" construct in all cases
# as anything can be nested. ensure constructs don't assume they
# are the "self.statement" element
c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
if isinstance(clause, sqltypes.TypeEngine):
cache_key_no_warnings = clause._static_cache_key
if cache_key_no_warnings:
hash(cache_key_no_warnings)
else:
cache_key_no_warnings = clause._generate_cache_key()
if cache_key_no_warnings:
hash(cache_key_no_warnings[0])
param_str = repr(getattr(c, "params", {}))
param_str = param_str.encode("utf-8").decode("ascii", "ignore")
print(("\nSQL String:\n" + str(c) + param_str).encode("utf-8"))
cc = re.sub(r"[\n\t]", "", str(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
if checkparams is not None:
if render_postcompile:
expanded_state = c.construct_expanded_state(
params, escape_names=False
)
eq_(expanded_state.parameters, checkparams)
else:
eq_(c.construct_params(params), checkparams)
if checkpositional is not None:
if render_postcompile:
expanded_state = c.construct_expanded_state(
params, escape_names=False
)
eq_(
tuple(
[
expanded_state.parameters[x]
for x in expanded_state.positiontup
]
),
checkpositional,
)
else:
p = c.construct_params(params, escape_names=False)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
if check_prefetch is not None:
eq_(c.prefetch, check_prefetch)
if check_literal_execute is not None:
eq_(
{
c.bind_names[b]: b.effective_value
for b in c.literal_execute_params
},
check_literal_execute,
)
if check_post_param is not None:
eq_(
{
c.bind_names[b]: b.effective_value
for b in c.post_compile_params
},
check_post_param,
)
if check_param_order and getattr(c, "params", None):
def get_dialect(paramstyle, positional):
cp = copy(dialect)
cp.paramstyle = paramstyle
cp.positional = positional
return cp
pyformat_dialect = get_dialect("pyformat", False)
pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
stmt = re.sub(r"[\n\t]", "", str(pyformat_c))
qmark_dialect = get_dialect("qmark", True)
qmark_c = clause.compile(dialect=qmark_dialect, **kw)
values = list(qmark_c.positiontup)
escaped = qmark_c.escaped_bind_names
for post_param in (
qmark_c.post_compile_params | qmark_c.literal_execute_params
):
name = qmark_c.bind_names[post_param]
if name in values:
values = [v for v in values if v != name]
positions = []
pos_by_value = defaultdict(list)
for v in values:
try:
if v in pos_by_value:
start = pos_by_value[v][-1]
else:
start = 0
esc = escaped.get(v, v)
pos = stmt.index("%%(%s)s" % (esc,), start) + 2
positions.append(pos)
pos_by_value[v].append(pos)
except ValueError:
msg = "Expected to find bindparam %r in %r" % (v, stmt)
assert False, msg
ordered = all(
positions[i - 1] < positions[i]
for i in range(1, len(positions))
)
expected = [v for _, v in sorted(zip(positions, values))]
msg = (
"Order of parameters %s does not match the order "
"in the statement %s. Statement %r" % (values, expected, stmt)
)
is_true(ordered, msg)
class ComparesTables:
def assert_tables_equal(
self,
table,
reflected_table,
strict_types=False,
strict_constraints=True,
):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
if strict_constraints:
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
assert isinstance(reflected_c.type, type(c.type)), msg % (
reflected_c.type,
c.type,
)
else:
self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
if strict_constraints:
eq_(
{f.column.name for f in c.foreign_keys},
{f.column.name for f in reflected_c.foreign_keys},
)
if c.server_default:
assert isinstance(
reflected_c.server_default, schema.FetchedValue
)
if strict_constraints:
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
assert c1.type._compare_type_affinity(
c2.type
), "On column %r, type '%s' doesn't correspond to type '%s'" % (
c1.name,
c1.type,
c2.type,
)
class AssertsExecutionResults:
def assert_result(self, result, class_, *objects):
result = list(result)
print(repr(result))
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list_):
self.assert_(
len(result) == len(list_),
"result list is not the same size as test list, "
+ "for class "
+ class_.__name__,
)
for i in range(0, len(list_)):
self.assert_row(class_, result[i], list_[i])
def assert_row(self, class_, rowobj, desc):
self.assert_(
rowobj.__class__ is class_, "item class is not " + repr(class_)
)
for key, value in desc.items():
if isinstance(value, tuple):
if isinstance(value[1], list):
self.assert_list(getattr(rowobj, key), value[0], value[1])
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
self.assert_(
getattr(rowobj, key) == value,
"attribute %s value %s does not match %s"
% (key, getattr(rowobj, key), value),
)
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
The algorithm is very expensive but not a big deal for the small
numbers of rows that the test suite manipulates.
"""
class immutabledict(dict):
def __hash__(self):
return id(self)
found = util.IdentitySet(result)
expected = {immutabledict(e) for e in expected}
for wrong in filterfalse(lambda o: isinstance(o, cls), found):
fail(
'Unexpected type "%s", expected "%s"'
% (type(wrong).__name__, cls.__name__)
)
if len(found) != len(expected):
fail(
'Unexpected object count "%s", expected "%s"'
% (len(found), len(expected))
)
NOVALUE = object()
def _compare_item(obj, spec):
for key, value in spec.items():
if isinstance(value, tuple):
try:
self.assert_unordered_result(
getattr(obj, key), value[0], *value[1]
)
except AssertionError:
return False
else:
if getattr(obj, key, NOVALUE) != value:
return False
return True
for expected_item in expected:
for found_item in found:
if _compare_item(found_item, expected_item):
found.remove(found_item)
break
else:
fail(
"Expected %s instance with attributes %s not found."
% (cls.__name__, repr(expected_item))
)
return True
def sql_execution_asserter(self, db=None):
if db is None:
from . import db as db
return assertsql.assert_engine(db)
def assert_sql_execution(self, db, callable_, *rules):
with self.sql_execution_asserter(db) as asserter:
result = callable_()
asserter.assert_(*rules)
return result
def assert_sql(self, db, callable_, rules):
newrules = []
for rule in rules:
if isinstance(rule, dict):
newrule = assertsql.AllOf(
*[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
)
else:
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
return self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
return self.assert_sql_execution(
db, callable_, assertsql.CountStatements(count)
)
@contextlib.contextmanager
def assert_execution(self, db, *rules):
with self.sql_execution_asserter(db) as asserter:
yield
asserter.assert_(*rules)
def assert_statement_count(self, db, count):
return self.assert_execution(db, assertsql.CountStatements(count))
@contextlib.contextmanager
def assert_statement_count_multi_db(self, dbs, counts):
recs = [
(self.sql_execution_asserter(db), db, count)
for (db, count) in zip(dbs, counts)
]
asserters = []
for ctx, db, count in recs:
asserters.append(ctx.__enter__())
try:
yield
finally:
for asserter, (ctx, db, count) in zip(asserters, recs):
ctx.__exit__(None, None, None)
asserter.assert_(assertsql.CountStatements(count))
class ComparesIndexes:
def compare_table_index_with_expected(
self, table: schema.Table, expected: list, dialect_name: str
):
eq_(len(table.indexes), len(expected))
idx_dict = {idx.name: idx for idx in table.indexes}
for exp in expected:
idx = idx_dict[exp["name"]]
eq_(idx.unique, exp["unique"])
cols = [c for c in exp["column_names"] if c is not None]
eq_(len(idx.columns), len(cols))
for c in cols:
is_true(c in idx.columns)
exprs = exp.get("expressions")
if exprs:
eq_(len(idx.expressions), len(exprs))
for idx_exp, expr, col in zip(
idx.expressions, exprs, exp["column_names"]
):
if col is None:
eq_(idx_exp.text, expr)
if (
exp.get("dialect_options")
and f"{dialect_name}_include" in exp["dialect_options"]
):
eq_(
idx.dialect_options[dialect_name]["include"],
exp["dialect_options"][f"{dialect_name}_include"],
)

View File

@ -0,0 +1,516 @@
# testing/assertsql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import collections
import contextlib
import itertools
import re
from .. import event
from ..engine import url
from ..engine.default import DefaultDialect
from ..schema import BaseDDLElement
class AssertRule:
is_consumed = False
errormessage = None
consume_statement = True
def process_statement(self, execute_observed):
pass
def no_more_statements(self):
assert False, (
"All statements are complete, but pending "
"assertion rules remain"
)
class SQLMatchRule(AssertRule):
pass
class CursorSQL(SQLMatchRule):
def __init__(self, statement, params=None, consume_statement=True):
self.statement = statement
self.params = params
self.consume_statement = consume_statement
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
self.params is not None and self.params != stmt.parameters
):
self.consume_statement = True
self.errormessage = (
"Testing for exact SQL %s parameters %s received %s %s"
% (
self.statement,
self.params,
stmt.statement,
stmt.parameters,
)
)
else:
execute_observed.statements.pop(0)
self.is_consumed = True
if not execute_observed.statements:
self.consume_statement = True
class CompiledSQL(SQLMatchRule):
def __init__(
self, statement, params=None, dialect="default", enable_returning=True
):
self.statement = statement
self.params = params
self.dialect = dialect
self.enable_returning = enable_returning
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r"[\n\t]", "", self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
if self.dialect == "default":
dialect = DefaultDialect()
# this is currently what tests are expecting
# dialect.supports_default_values = True
dialect.supports_default_metavalue = True
if self.enable_returning:
dialect.insert_returning = dialect.update_returning = (
dialect.delete_returning
) = True
dialect.use_insertmanyvalues = True
dialect.supports_multivalues_insert = True
dialect.update_returning_multifrom = True
dialect.delete_returning_multifrom = True
# dialect.favor_returning_over_lastrowid = True
# dialect.insert_null_pk_still_autoincrements = True
# this is calculated but we need it to be True for this
# to look like all the current RETURNING dialects
assert dialect.insert_executemany_returning
return dialect
else:
return url.URL.create(self.dialect).get_dialect()()
def _received_statement(self, execute_observed):
"""reconstruct the statement and params in terms
of a target dialect, which for CompiledSQL is just DefaultDialect."""
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
# received_statement runs a full compile(). we should not need to
# consider extracted_parameters; if we do this indicates some state
# is being sent from a previous cached query, which some misbehaviors
# in the ORM can cause, see #6881
cache_key = None # execute_observed.context.compiled.cache_key
extracted_parameters = (
None # execute_observed.context.extracted_parameters
)
if "schema_translate_map" in context.execution_options:
map_ = context.execution_options["schema_translate_map"]
else:
map_ = None
if isinstance(execute_observed.clauseelement, BaseDDLElement):
compiled = execute_observed.clauseelement.compile(
dialect=compare_dialect,
schema_translate_map=map_,
)
else:
compiled = execute_observed.clauseelement.compile(
cache_key=cache_key,
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
for_executemany=context.compiled.for_executemany,
schema_translate_map=map_,
)
_received_statement = re.sub(r"[\n\t]", "", str(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [
compiled.construct_params(
extracted_parameters=extracted_parameters
)
]
else:
_received_parameters = [
compiled.construct_params(
m, extracted_parameters=extracted_parameters
)
for m in parameters
]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
_received_statement, _received_parameters = self._received_statement(
execute_observed
)
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
if equivalent:
if params is not None:
all_params = list(params)
all_received = list(_received_parameters)
while all_params and all_received:
param = dict(all_params.pop(0))
for idx, received in enumerate(list(all_received)):
# do a positive compare only
for param_key in param:
# a key in param did not match current
# 'received'
if (
param_key not in received
or received[param_key] != param[param_key]
):
break
else:
# all keys in param matched 'received';
# onto next param
del all_received[idx]
break
else:
# param did not match any entry
# in all_received
equivalent = False
break
if all_params or all_received:
equivalent = False
if equivalent:
self.is_consumed = True
self.errormessage = None
else:
self.errormessage = self._failure_message(
execute_observed, params
) % {
"received_statement": _received_statement,
"received_parameters": _received_parameters,
}
def _all_params(self, context):
if self.params:
if callable(self.params):
params = self.params(context)
else:
params = self.params
if not isinstance(params, list):
params = [params]
return params
else:
return None
def _failure_message(self, execute_observed, expected_params):
return (
"Testing for compiled statement\n%r partial params %s, "
"received\n%%(received_statement)r with params "
"%%(received_parameters)r"
% (
self.statement.replace("%", "%%"),
repr(expected_params).replace("%", "%%"),
)
)
class RegexSQL(CompiledSQL):
def __init__(
self, regex, params=None, dialect="default", enable_returning=False
):
SQLMatchRule.__init__(self)
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
self.dialect = dialect
self.enable_returning = enable_returning
def _failure_message(self, execute_observed, expected_params):
return (
"Testing for compiled statement ~%r partial params %s, "
"received %%(received_statement)r with params "
"%%(received_parameters)r"
% (
self.orig_regex.replace("%", "%%"),
repr(expected_params).replace("%", "%%"),
)
)
def _compare_sql(self, execute_observed, received_statement):
return bool(self.regex.match(received_statement))
class DialectSQL(CompiledSQL):
def _compile_dialect(self, execute_observed):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
stmt = re.sub(r"[\n\t]", "", real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
received_stmt, received_params = super()._received_statement(
execute_observed
)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
if self._compare_no_space(real_stmt.statement, received_stmt):
break
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
"statements actually invoked" % received_stmt
)
return received_stmt, execute_observed.context.compiled_parameters
def _dialect_adjusted_statement(self, dialect):
paramstyle = dialect.paramstyle
stmt = re.sub(r"[\n\t]", "", self.statement)
# temporarily escape out PG double colons
stmt = stmt.replace("::", "!!")
if paramstyle == "pyformat":
stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
# positional params
repl = None
if paramstyle == "qmark":
repl = "?"
elif paramstyle == "format":
repl = r"%s"
elif paramstyle.startswith("numeric"):
counter = itertools.count(1)
num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
def repl(m):
return f"{num_identifier}{next(counter)}"
stmt = re.sub(r":([\w_]+)", repl, stmt)
# put them back
stmt = stmt.replace("!!", "::")
return stmt
def _compare_sql(self, execute_observed, received_statement):
stmt = self._dialect_adjusted_statement(
execute_observed.context.dialect
)
return received_statement == stmt
def _failure_message(self, execute_observed, expected_params):
return (
"Testing for compiled statement\n%r partial params %s, "
"received\n%%(received_statement)r with params "
"%%(received_parameters)r"
% (
self._dialect_adjusted_statement(
execute_observed.context.dialect
).replace("%", "%%"),
repr(expected_params).replace("%", "%%"),
)
)
class CountStatements(AssertRule):
def __init__(self, count):
self.count = count
self._statement_count = 0
def process_statement(self, execute_observed):
self._statement_count += 1
def no_more_statements(self):
if self.count != self._statement_count:
assert False, "desired statement count %d does not match %d" % (
self.count,
self._statement_count,
)
class AllOf(AssertRule):
def __init__(self, *rules):
self.rules = set(rules)
def process_statement(self, execute_observed):
for rule in list(self.rules):
rule.errormessage = None
rule.process_statement(execute_observed)
if rule.is_consumed:
self.rules.discard(rule)
if not self.rules:
self.is_consumed = True
break
elif not rule.errormessage:
# rule is not done yet
self.errormessage = None
break
else:
self.errormessage = list(self.rules)[0].errormessage
class EachOf(AssertRule):
def __init__(self, *rules):
self.rules = list(rules)
def process_statement(self, execute_observed):
if not self.rules:
self.is_consumed = True
self.consume_statement = False
while self.rules:
rule = self.rules[0]
rule.process_statement(execute_observed)
if rule.is_consumed:
self.rules.pop(0)
elif rule.errormessage:
self.errormessage = rule.errormessage
if rule.consume_statement:
break
if not self.rules:
self.is_consumed = True
def no_more_statements(self):
if self.rules and not self.rules[0].is_consumed:
self.rules[0].no_more_statements()
elif self.rules:
super().no_more_statements()
class Conditional(EachOf):
def __init__(self, condition, rules, else_rules):
if condition:
super().__init__(*rules)
else:
super().__init__(*else_rules)
class Or(AllOf):
def process_statement(self, execute_observed):
for rule in self.rules:
rule.process_statement(execute_observed)
if rule.is_consumed:
self.is_consumed = True
break
else:
self.errormessage = list(self.rules)[0].errormessage
class SQLExecuteObserved:
def __init__(self, context, clauseelement, multiparams, params):
self.context = context
self.clauseelement = clauseelement
if multiparams:
self.parameters = multiparams
elif params:
self.parameters = [params]
else:
self.parameters = []
self.statements = []
def __repr__(self):
return str(self.statements)
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
["statement", "parameters", "context", "executemany"],
)
):
pass
class SQLAsserter:
def __init__(self):
self.accumulated = []
def _close(self):
self._final = self.accumulated
del self.accumulated
def assert_(self, *rules):
rule = EachOf(*rules)
observed = list(self._final)
while observed:
statement = observed.pop(0)
rule.process_statement(statement)
if rule.is_consumed:
break
elif rule.errormessage:
assert False, rule.errormessage
if observed:
assert False, "Additional SQL statements remain:\n%s" % observed
elif not rule.is_consumed:
rule.no_more_statements()
@contextlib.contextmanager
def assert_engine(engine):
asserter = SQLAsserter()
orig = []
@event.listens_for(engine, "before_execute")
def connection_execute(
conn, clauseelement, multiparams, params, execution_options
):
# grab the original statement + params before any cursor
# execution
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
def cursor_execute(
conn, cursor, statement, parameters, context, executemany
):
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
if (
asserter.accumulated
and asserter.accumulated[-1].context is context
):
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
statement, parameters, context, executemany
)
)
try:
yield asserter
finally:
event.remove(engine, "after_cursor_execute", cursor_execute)
event.remove(engine, "before_execute", connection_execute)
asserter._close()

View File

@ -0,0 +1,135 @@
# testing/asyncio.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
# functions and wrappers to run tests, fixtures, provisioning and
# setup/teardown in an asyncio event loop, conditionally based on the
# current DB driver being used for a test.
# note that SQLAlchemy's asyncio integration also supports a method
# of running individual asyncio functions inside of separate event loops
# using "async_fallback" mode; however running whole functions in the event
# loop is a more accurate test for how SQLAlchemy's asyncio features
# would run in the real world.
from __future__ import annotations
from functools import wraps
import inspect
from . import config
from ..util.concurrency import _AsyncUtil
# may be set to False if the
# --disable-asyncio flag is passed to the test runner.
ENABLE_ASYNCIO = True
_async_util = _AsyncUtil() # it has lazy init so just always create one
def _shutdown():
"""called when the test finishes"""
_async_util.close()
def _run_coroutine_function(fn, *args, **kwargs):
return _async_util.run(fn, *args, **kwargs)
def _assume_async(fn, *args, **kwargs):
"""Run a function in an asyncio loop unconditionally.
This function is used for provisioning features like
testing a database connection for server info.
Note that for blocking IO database drivers, this means they block the
event loop.
"""
if not ENABLE_ASYNCIO:
return fn(*args, **kwargs)
return _async_util.run_in_greenlet(fn, *args, **kwargs)
def _maybe_async_provisioning(fn, *args, **kwargs):
"""Run a function in an asyncio loop if any current drivers might need it.
This function is used for provisioning features that take
place outside of a specific database driver being selected, so if the
current driver that happens to be used for the provisioning operation
is an async driver, it will run in asyncio and not fail.
Note that for blocking IO database drivers, this means they block the
event loop.
"""
if not ENABLE_ASYNCIO:
return fn(*args, **kwargs)
if config.any_async:
return _async_util.run_in_greenlet(fn, *args, **kwargs)
else:
return fn(*args, **kwargs)
def _maybe_async(fn, *args, **kwargs):
"""Run a function in an asyncio loop if the current selected driver is
async.
This function is used for test setup/teardown and tests themselves
where the current DB driver is known.
"""
if not ENABLE_ASYNCIO:
return fn(*args, **kwargs)
is_async = config._current.is_async
if is_async:
return _async_util.run_in_greenlet(fn, *args, **kwargs)
else:
return fn(*args, **kwargs)
def _maybe_async_wrapper(fn):
"""Apply the _maybe_async function to an existing function and return
as a wrapped callable, supporting generator functions as well.
This is currently used for pytest fixtures that support generator use.
"""
if inspect.isgeneratorfunction(fn):
_stop = object()
def call_next(gen):
try:
return next(gen)
# can't raise StopIteration in an awaitable.
except StopIteration:
return _stop
@wraps(fn)
def wrap_fixture(*args, **kwargs):
gen = fn(*args, **kwargs)
while True:
value = _maybe_async(call_next, gen)
if value is _stop:
break
yield value
else:
@wraps(fn)
def wrap_fixture(*args, **kwargs):
return _maybe_async(fn, *args, **kwargs)
return wrap_fixture

View File

@ -0,0 +1,423 @@
# testing/config.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from argparse import Namespace
import collections
import inspect
import typing
from typing import Any
from typing import Callable
from typing import Iterable
from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
from . import mock
from . import requirements as _requirements
from .util import fail
from .. import util
# default requirements; this is replaced by plugin_base when pytest
# is run
requirements = _requirements.SuiteRequirements()
db = None
db_url = None
db_opts = None
file_config = None
test_schema = None
test_schema_2 = None
any_async = False
_current = None
ident = "main"
options: Namespace = None # type: ignore
if typing.TYPE_CHECKING:
from .plugin.plugin_base import FixtureFunctions
_fixture_functions: FixtureFunctions
else:
class _NullFixtureFunctions:
def _null_decorator(self):
def go(fn):
return fn
return go
def skip_test_exception(self, *arg, **kw):
return Exception()
@property
def add_to_marker(self):
return mock.Mock()
def mark_base_test_class(self):
return self._null_decorator()
def combinations(self, *arg_sets, **kw):
return self._null_decorator()
def param_ident(self, *parameters):
return self._null_decorator()
def fixture(self, *arg, **kw):
return self._null_decorator()
def get_current_test_name(self):
return None
def async_test(self, fn):
return fn
# default fixture functions; these are replaced by plugin_base when
# pytest runs
_fixture_functions = _NullFixtureFunctions()
_FN = TypeVar("_FN", bound=Callable[..., Any])
def combinations(
*comb: Union[Any, Tuple[Any, ...]],
argnames: Optional[str] = None,
id_: Optional[str] = None,
**kw: str,
) -> Callable[[_FN], _FN]:
r"""Deliver multiple versions of a test based on positional combinations.
This is a facade over pytest.mark.parametrize.
:param \*comb: argument combinations. These are tuples that will be passed
positionally to the decorated function.
:param argnames: optional list of argument names. These are the names
of the arguments in the test function that correspond to the entries
in each argument tuple. pytest.mark.parametrize requires this, however
the combinations function will derive it automatically if not present
by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
first argument is "self" which is discarded.
:param id\_: optional id template. This is a string template that
describes how the "id" for each parameter set should be defined, if any.
The number of characters in the template should match the number of
entries in each argument tuple. Each character describes how the
corresponding entry in the argument tuple should be handled, as far as
whether or not it is included in the arguments passed to the function, as
well as if it is included in the tokens used to create the id of the
parameter set.
If omitted, the argument combinations are passed to parametrize as is. If
passed, each argument combination is turned into a pytest.param() object,
mapping the elements of the argument tuple to produce an id based on a
character value in the same position within the string template using the
following scheme:
.. sourcecode:: text
i - the given argument is a string that is part of the id only, don't
pass it as an argument
n - the given argument should be passed and it should be added to the
id by calling the .__name__ attribute
r - the given argument should be passed and it should be added to the
id by calling repr()
s - the given argument should be passed and it should be added to the
id by calling str()
a - (argument) the given argument should be passed and it should not
be used to generated the id
e.g.::
@testing.combinations(
(operator.eq, "eq"),
(operator.ne, "ne"),
(operator.gt, "gt"),
(operator.lt, "lt"),
id_="na",
)
def test_operator(self, opfunc, name):
pass
The above combination will call ``.__name__`` on the first member of
each tuple and use that as the "id" to pytest.param().
"""
return _fixture_functions.combinations(
*comb, id_=id_, argnames=argnames, **kw
)
def combinations_list(arg_iterable: Iterable[Tuple[Any, ...]], **kw):
"As combination, but takes a single iterable"
return combinations(*arg_iterable, **kw)
class Variation:
__slots__ = ("_name", "_argname")
def __init__(self, case, argname, case_names):
self._name = case
self._argname = argname
for casename in case_names:
setattr(self, casename, casename == case)
if typing.TYPE_CHECKING:
def __getattr__(self, key: str) -> bool: ...
@property
def name(self):
return self._name
def __bool__(self):
return self._name == self._argname
def __nonzero__(self):
return not self.__bool__()
def __str__(self):
return f"{self._argname}={self._name!r}"
def __repr__(self):
return str(self)
def fail(self) -> NoReturn:
fail(f"Unknown {self}")
@classmethod
def idfn(cls, variation):
return variation.name
@classmethod
def generate_cases(cls, argname, cases):
case_names = [
argname if c is True else "not_" + argname if c is False else c
for c in cases
]
typ = type(
argname,
(Variation,),
{
"__slots__": tuple(case_names),
},
)
return [typ(casename, argname, case_names) for casename in case_names]
def variation(argname_or_fn, cases=None):
"""a helper around testing.combinations that provides a single namespace
that can be used as a switch.
e.g.::
@testing.variation("querytyp", ["select", "subquery", "legacy_query"])
@testing.variation("lazy", ["select", "raise", "raise_on_sql"])
def test_thing(self, querytyp, lazy, decl_base):
class Thing(decl_base):
__tablename__ = "thing"
# use name directly
rel = relationship("Rel", lazy=lazy.name)
# use as a switch
if querytyp.select:
stmt = select(Thing)
elif querytyp.subquery:
stmt = select(Thing).subquery()
elif querytyp.legacy_query:
stmt = Session.query(Thing)
else:
querytyp.fail()
The variable provided is a slots object of boolean variables, as well
as the name of the case itself under the attribute ".name"
"""
if inspect.isfunction(argname_or_fn):
argname = argname_or_fn.__name__
cases = argname_or_fn(None)
@variation_fixture(argname, cases)
def go(self, request):
yield request.param
return go
else:
argname = argname_or_fn
cases_plus_limitations = [
(
entry
if (isinstance(entry, tuple) and len(entry) == 2)
else (entry, None)
)
for entry in cases
]
variations = Variation.generate_cases(
argname, [c for c, l in cases_plus_limitations]
)
return combinations(
*[
(
(variation._name, variation, limitation)
if limitation is not None
else (variation._name, variation)
)
for variation, (case, limitation) in zip(
variations, cases_plus_limitations
)
],
id_="ia",
argnames=argname,
)
def variation_fixture(argname, cases, scope="function"):
return fixture(
params=Variation.generate_cases(argname, cases),
ids=Variation.idfn,
scope=scope,
)
def fixture(*arg: Any, **kw: Any) -> Any:
return _fixture_functions.fixture(*arg, **kw)
def get_current_test_name() -> str:
return _fixture_functions.get_current_test_name()
def mark_base_test_class() -> Any:
return _fixture_functions.mark_base_test_class()
class _AddToMarker:
def __getattr__(self, attr: str) -> Any:
return getattr(_fixture_functions.add_to_marker, attr)
add_to_marker = _AddToMarker()
class Config:
def __init__(self, db, db_opts, options, file_config):
self._set_name(db)
self.db = db
self.db_opts = db_opts
self.options = options
self.file_config = file_config
self.test_schema = "test_schema"
self.test_schema_2 = "test_schema_2"
self.is_async = db.dialect.is_async and not util.asbool(
db.url.query.get("async_fallback", False)
)
_stack = collections.deque()
_configs = set()
def _set_name(self, db):
suffix = "_async" if db.dialect.is_async else ""
if db.dialect.server_version_info:
svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi)
else:
self.name = "%s+%s%s" % (db.name, db.driver, suffix)
@classmethod
def register(cls, db, db_opts, options, file_config):
"""add a config as one of the global configs.
If there are no configs set up yet, this config also
gets set as the "_current".
"""
global any_async
cfg = Config(db, db_opts, options, file_config)
# if any backends include an async driver, then ensure
# all setup/teardown and tests are wrapped in the maybe_async()
# decorator that will set up a greenlet context for async drivers.
any_async = any_async or cfg.is_async
cls._configs.add(cfg)
return cfg
@classmethod
def set_as_current(cls, config, namespace):
global db, _current, db_url, test_schema, test_schema_2, db_opts
_current = config
db_url = config.db.url
db_opts = config.db_opts
test_schema = config.test_schema
test_schema_2 = config.test_schema_2
namespace.db = db = config.db
@classmethod
def push_engine(cls, db, namespace):
assert _current, "Can't push without a default Config set up"
cls.push(
Config(
db, _current.db_opts, _current.options, _current.file_config
),
namespace,
)
@classmethod
def push(cls, config, namespace):
cls._stack.append(_current)
cls.set_as_current(config, namespace)
@classmethod
def pop(cls, namespace):
if cls._stack:
# a failed test w/ -x option can call reset() ahead of time
_current = cls._stack[-1]
del cls._stack[-1]
cls.set_as_current(_current, namespace)
@classmethod
def reset(cls, namespace):
if cls._stack:
cls.set_as_current(cls._stack[0], namespace)
cls._stack.clear()
@classmethod
def all_configs(cls):
return cls._configs
@classmethod
def all_dbs(cls):
for cfg in cls.all_configs():
yield cfg.db
def skip_test(self, msg):
skip_test(msg)
def skip_test(msg):
raise _fixture_functions.skip_test_exception(msg)
def async_test(fn):
return _fixture_functions.async_test(fn)

View File

@ -0,0 +1,474 @@
# testing/engines.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import collections
import re
import typing
from typing import Any
from typing import Dict
from typing import Optional
import warnings
import weakref
from . import config
from .util import decorator
from .util import gc_collect
from .. import event
from .. import pool
from ..util import await_only
from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ..engine import Engine
from ..engine.url import URL
from ..ext.asyncio import AsyncEngine
class ConnectionKiller:
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = collections.defaultdict(set)
self.dbapi_connections = set()
def add_pool(self, pool):
event.listen(pool, "checkout", self._add_conn)
event.listen(pool, "checkin", self._remove_conn)
event.listen(pool, "close", self._remove_conn)
event.listen(pool, "close_detached", self._remove_conn)
# note we are keeping "invalidated" here, as those are still
# opened connections we would like to roll back
def _add_conn(self, dbapi_con, con_record, con_proxy):
self.dbapi_connections.add(dbapi_con)
self.proxy_refs[con_proxy] = True
def _remove_conn(self, dbapi_conn, *arg):
self.dbapi_connections.discard(dbapi_conn)
def add_engine(self, engine, scope):
self.add_pool(engine.pool)
assert scope in ("class", "global", "function", "fixture")
self.testing_engines[scope].add(engine)
def _safe(self, fn):
try:
fn()
except Exception as e:
warnings.warn(
"testing_reaper couldn't rollback/close connection: %s" % e
)
def rollback_all(self):
for rec in list(self.proxy_refs):
if rec is not None and rec.is_valid:
self._safe(rec.rollback)
def checkin_all(self):
# run pool.checkin() for all ConnectionFairy instances we have
# tracked.
for rec in list(self.proxy_refs):
if rec is not None and rec.is_valid:
self.dbapi_connections.discard(rec.dbapi_connection)
self._safe(rec._checkin)
# for fairy refs that were GCed and could not close the connection,
# such as asyncio, roll back those remaining connections
for con in self.dbapi_connections:
self._safe(con.rollback)
self.dbapi_connections.clear()
def close_all(self):
self.checkin_all()
def prepare_for_drop_tables(self, connection):
# don't do aggressive checks for third party test suites
if not config.bootstrapped_as_sqlalchemy:
return
from . import provision
provision.prepare_for_drop_tables(connection.engine.url, connection)
def _drop_testing_engines(self, scope):
eng = self.testing_engines[scope]
for rec in list(eng):
for proxy_ref in list(self.proxy_refs):
if proxy_ref is not None and proxy_ref.is_valid:
if (
proxy_ref._pool is not None
and proxy_ref._pool is rec.pool
):
self._safe(proxy_ref._checkin)
if hasattr(rec, "sync_engine"):
await_only(rec.dispose())
else:
rec.dispose()
eng.clear()
def after_test(self):
self._drop_testing_engines("function")
def after_test_outside_fixtures(self, test):
# don't do aggressive checks for third party test suites
if not config.bootstrapped_as_sqlalchemy:
return
if test.__class__.__leave_connections_for_teardown__:
return
self.checkin_all()
# on PostgreSQL, this will test for any "idle in transaction"
# connections. useful to identify tests with unusual patterns
# that can't be cleaned up correctly.
from . import provision
with config.db.connect() as conn:
provision.prepare_for_drop_tables(conn.engine.url, conn)
def stop_test_class_inside_fixtures(self):
self.checkin_all()
self._drop_testing_engines("function")
self._drop_testing_engines("class")
def stop_test_class_outside_fixtures(self):
# ensure no refs to checked out connections at all.
if pool.base._strong_ref_connection_records:
gc_collect()
if pool.base._strong_ref_connection_records:
ln = len(pool.base._strong_ref_connection_records)
pool.base._strong_ref_connection_records.clear()
assert (
False
), "%d connection recs not cleared after test suite" % (ln)
def final_cleanup(self):
self.checkin_all()
for scope in self.testing_engines:
self._drop_testing_engines(scope)
def assert_all_closed(self):
for rec in self.proxy_refs:
if rec.is_valid:
assert False
testing_reaper = ConnectionKiller()
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
fn(*args, **kw)
finally:
testing_reaper.assert_all_closed()
@decorator
def rollback_open_connections(fn, *args, **kw):
"""Decorator that rolls back all open connections after fn execution."""
try:
fn(*args, **kw)
finally:
testing_reaper.rollback_all()
@decorator
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
testing_reaper.checkin_all()
fn(*args, **kw)
@decorator
def close_open_connections(fn, *args, **kw):
"""Decorator that closes all connections after fn execution."""
try:
fn(*args, **kw)
finally:
testing_reaper.checkin_all()
def all_dialects(exclude=None):
import sqlalchemy.dialects as d
for name in d.__all__:
# TEMPORARY
if exclude and name in exclude:
continue
mod = getattr(d, name, None)
if not mod:
mod = getattr(
__import__("sqlalchemy.dialects.%s" % name).dialects, name
)
yield mod.dialect()
class ReconnectFixture:
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
self.is_stopped = False
def __getattr__(self, key):
return getattr(self.dbapi, key)
def connect(self, *args, **kwargs):
conn = self.dbapi.connect(*args, **kwargs)
if self.is_stopped:
self._safe(conn.close)
curs = conn.cursor() # should fail on Oracle etc.
# should fail for everything that didn't fail
# above, connection is closed
curs.execute("select 1")
assert False, "simulated connect failure didn't work"
else:
self.connections.append(conn)
return conn
def _safe(self, fn):
try:
fn()
except Exception as e:
warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
def shutdown(self, stop=False):
# TODO: this doesn't cover all cases
# as nicely as we'd like, namely MySQLdb.
# would need to implement R. Brewer's
# proxy server idea to get better
# coverage.
self.is_stopped = stop
for c in list(self.connections):
self._safe(c.close)
self.connections = []
def restart(self):
self.is_stopped = False
def reconnecting_engine(url=None, options=None):
url = url or config.db.url
dbapi = config.db.dialect.dbapi
if not options:
options = {}
options["module"] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
def dispose():
engine.dialect.dbapi.shutdown()
engine.dialect.dbapi.is_stopped = False
_dispose()
engine.test_shutdown = engine.dialect.dbapi.shutdown
engine.test_restart = engine.dialect.dbapi.restart
engine.dispose = dispose
return engine
@typing.overload
def testing_engine(
url: Optional[URL] = None,
options: Optional[Dict[str, Any]] = None,
asyncio: Literal[False] = False,
transfer_staticpool: bool = False,
) -> Engine: ...
@typing.overload
def testing_engine(
url: Optional[URL] = None,
options: Optional[Dict[str, Any]] = None,
asyncio: Literal[True] = True,
transfer_staticpool: bool = False,
) -> AsyncEngine: ...
def testing_engine(
url=None,
options=None,
asyncio=False,
transfer_staticpool=False,
share_pool=False,
_sqlite_savepoint=False,
):
if asyncio:
assert not _sqlite_savepoint
from sqlalchemy.ext.asyncio import (
create_async_engine as create_engine,
)
else:
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
if not options:
use_reaper = True
scope = "function"
sqlite_savepoint = False
else:
use_reaper = options.pop("use_reaper", True)
scope = options.pop("scope", "function")
sqlite_savepoint = options.pop("sqlite_savepoint", False)
url = url or config.db.url
url = make_url(url)
if (
config.db is None or url.drivername == config.db.url.drivername
) and config.db_opts:
use_options = config.db_opts.copy()
else:
use_options = {}
if options is not None:
use_options.update(options)
engine = create_engine(url, **use_options)
if sqlite_savepoint and engine.name == "sqlite":
# apply SQLite savepoint workaround
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
dbapi_connection.isolation_level = None
@event.listens_for(engine, "begin")
def do_begin(conn):
conn.exec_driver_sql("BEGIN")
if transfer_staticpool:
from sqlalchemy.pool import StaticPool
if config.db is not None and isinstance(config.db.pool, StaticPool):
use_reaper = False
engine.pool._transfer_from(config.db.pool)
elif share_pool:
engine.pool = config.db.pool
if scope == "global":
if asyncio:
engine.sync_engine._has_events = True
else:
engine._has_events = (
True # enable event blocks, helps with profiling
)
if (
isinstance(engine.pool, pool.QueuePool)
and "pool" not in use_options
and "pool_timeout" not in use_options
and "max_overflow" not in use_options
):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
if use_reaper:
testing_reaper.add_engine(engine, scope)
return engine
def mock_engine(dialect_name=None):
"""Provides a mocking engine based on the current testing.db.
This is normally used to test DDL generation flow as emitted
by an Engine.
It should not be used in other cases, as assert_compile() and
assert_sql_execution() are much better choices with fewer
moving parts.
"""
from sqlalchemy import create_mock_engine
if not dialect_name:
dialect_name = config.db.name
buffer = []
def executor(sql, *a, **kw):
buffer.append(sql)
def assert_sql(stmts):
recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
assert recv == stmts, recv
def print_sql():
d = engine.dialect
return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
engine = create_mock_engine(dialect_name + "://", executor)
assert not hasattr(engine, "mock")
engine.mock = buffer
engine.assert_sql = assert_sql
engine.print_sql = print_sql
return engine
class DBAPIProxyCursor:
"""Proxy a DBAPI cursor.
Tests can provide subclasses of this to intercept
DBAPI-level cursor operations.
"""
def __init__(self, engine, conn, *args, **kwargs):
self.engine = engine
self.connection = conn
self.cursor = conn.cursor(*args, **kwargs)
def execute(self, stmt, parameters=None, **kw):
if parameters:
return self.cursor.execute(stmt, parameters, **kw)
else:
return self.cursor.execute(stmt, **kw)
def executemany(self, stmt, params, **kw):
return self.cursor.executemany(stmt, params, **kw)
def __iter__(self):
return iter(self.cursor)
def __getattr__(self, key):
return getattr(self.cursor, key)
class DBAPIProxyConnection:
"""Proxy a DBAPI connection.
Tests can provide subclasses of this to intercept
DBAPI-level connection operations.
"""
def __init__(self, engine, conn, cursor_cls):
self.conn = conn
self.engine = engine
self.cursor_cls = cursor_cls
def cursor(self, *args, **kwargs):
return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
def close(self):
self.conn.close()
def __getattr__(self, key):
return getattr(self.conn, key)

View File

@ -0,0 +1,117 @@
# testing/entities.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import sqlalchemy as sa
from .. import exc as sa_exc
from ..orm.writeonly import WriteOnlyCollection
_repr_stack = set()
class BasicEntity:
def __init__(self, **kw):
for key, value in kw.items():
setattr(self, key, value)
def __repr__(self):
if id(self) in _repr_stack:
return object.__repr__(self)
_repr_stack.add(id(self))
try:
return "%s(%s)" % (
(self.__class__.__name__),
", ".join(
[
"%s=%r" % (key, getattr(self, key))
for key in sorted(self.__dict__.keys())
if not key.startswith("_")
]
),
)
finally:
_repr_stack.remove(id(self))
_recursion_stack = set()
class ComparableMixin:
def __ne__(self, other):
return not self.__eq__(other)
def __eq__(self, other):
"""'Deep, sparse compare.
Deeply compare two entities, following the non-None attributes of the
non-persisted object, if possible.
"""
if other is self:
return True
elif not self.__class__ == other.__class__:
return False
if id(self) in _recursion_stack:
return True
_recursion_stack.add(id(self))
try:
# pick the entity that's not SA persisted as the source
try:
self_key = sa.orm.attributes.instance_state(self).key
except sa.orm.exc.NO_STATE:
self_key = None
if other is None:
a = self
b = other
elif self_key is not None:
a = other
b = self
else:
a = self
b = other
for attr in list(a.__dict__):
if attr.startswith("_"):
continue
value = getattr(a, attr)
if isinstance(value, WriteOnlyCollection):
continue
try:
# handle lazy loader errors
battr = getattr(b, attr)
except (AttributeError, sa_exc.UnboundExecutionError):
return False
if hasattr(value, "__iter__") and not isinstance(value, str):
if hasattr(value, "__getitem__") and not hasattr(
value, "keys"
):
if list(value) != list(battr):
return False
else:
if set(value) != set(battr):
return False
else:
if value is not None and value != battr:
return False
return True
finally:
_recursion_stack.remove(id(self))
class ComparableEntity(ComparableMixin, BasicEntity):
def __hash__(self):
return hash(self.__class__)

View File

@ -0,0 +1,435 @@
# testing/exclusions.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import contextlib
import operator
import re
import sys
from . import config
from .. import util
from ..util import decorator
from ..util.compat import inspect_getfullargspec
def skip_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
rule.skips.add(pred)
return rule
def fails_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
rule.fails.add(pred)
return rule
class compound:
def __init__(self):
self.fails = set()
self.skips = set()
def __add__(self, other):
return self.add(other)
def as_skips(self):
rule = compound()
rule.skips.update(self.skips)
rule.skips.update(self.fails)
return rule
def add(self, *others):
copy = compound()
copy.fails.update(self.fails)
copy.skips.update(self.skips)
for other in others:
copy.fails.update(other.fails)
copy.skips.update(other.skips)
return copy
def not_(self):
copy = compound()
copy.fails.update(NotPredicate(fail) for fail in self.fails)
copy.skips.update(NotPredicate(skip) for skip in self.skips)
return copy
@property
def enabled(self):
return self.enabled_for_config(config._current)
def enabled_for_config(self, config):
for predicate in self.skips.union(self.fails):
if predicate(config):
return False
else:
return True
def matching_config_reasons(self, config):
return [
predicate._as_string(config)
for predicate in self.skips.union(self.fails)
if predicate(config)
]
def _extend(self, other):
self.skips.update(other.skips)
self.fails.update(other.fails)
def __call__(self, fn):
if hasattr(fn, "_sa_exclusion_extend"):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@contextlib.contextmanager
def fail_if(self):
all_fails = compound()
all_fails.fails.update(self.skips.union(self.fails))
try:
yield
except Exception as ex:
all_fails._expect_failure(config._current, ex)
else:
all_fails._expect_success(config._current)
def _do(self, cfg, fn, *args, **kw):
for skip in self.skips:
if skip(cfg):
msg = "'%s' : %s" % (
config.get_current_test_name(),
skip._as_string(cfg),
)
config.skip_test(msg)
try:
return_value = fn(*args, **kw)
except Exception as ex:
self._expect_failure(cfg, ex, name=fn.__name__)
else:
self._expect_success(cfg, name=fn.__name__)
return return_value
def _expect_failure(self, config, ex, name="block"):
for fail in self.fails:
if fail(config):
print(
"%s failed as expected (%s): %s "
% (name, fail._as_string(config), ex)
)
break
else:
raise ex.with_traceback(sys.exc_info()[2])
def _expect_success(self, config, name="block"):
if not self.fails:
return
for fail in self.fails:
if fail(config):
raise AssertionError(
"Unexpected success for '%s' (%s)"
% (
name,
" and ".join(
fail._as_string(config) for fail in self.fails
),
)
)
def only_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return skip_if(NotPredicate(predicate), reason)
def succeeds_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return fails_if(NotPredicate(predicate), reason)
class Predicate:
@classmethod
def as_predicate(cls, predicate, description=None):
if isinstance(predicate, compound):
return cls.as_predicate(predicate.enabled_for_config, description)
elif isinstance(predicate, Predicate):
if description and predicate.description is None:
predicate.description = description
return predicate
elif isinstance(predicate, (list, set)):
return OrPredicate(
[cls.as_predicate(pred) for pred in predicate], description
)
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, str):
tokens = re.match(
r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
)
if not tokens:
raise ValueError(
"Couldn't locate DB name in predicate: %r" % predicate
)
db = tokens.group(1)
op = tokens.group(2)
spec = (
tuple(int(d) for d in tokens.group(3).split("."))
if tokens.group(3)
else None
)
return SpecPredicate(db, op, spec, description=description)
elif callable(predicate):
return LambdaPredicate(predicate, description)
else:
assert False, "unknown predicate type: %s" % predicate
def _format_description(self, config, negate=False):
bool_ = self(config)
if negate:
bool_ = not negate
return self.description % {
"driver": (
config.db.url.get_driver_name() if config else "<no driver>"
),
"database": (
config.db.url.get_backend_name() if config else "<no database>"
),
"doesnt_support": "doesn't support" if bool_ else "does support",
"does_support": "does support" if bool_ else "doesn't support",
}
def _as_string(self, config=None, negate=False):
raise NotImplementedError()
class BooleanPredicate(Predicate):
def __init__(self, value, description=None):
self.value = value
self.description = description or "boolean %s" % value
def __call__(self, config):
return self.value
def _as_string(self, config, negate=False):
return self._format_description(config, negate=negate)
class SpecPredicate(Predicate):
def __init__(self, db, op=None, spec=None, description=None):
self.db = db
self.op = op
self.spec = spec
self.description = description
_ops = {
"<": operator.lt,
">": operator.gt,
"==": operator.eq,
"!=": operator.ne,
"<=": operator.le,
">=": operator.ge,
"in": operator.contains,
"between": lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, config):
if config is None:
return False
engine = config.db
if "+" in self.db:
dialect, driver = self.db.split("+")
else:
dialect, driver = self.db, None
if dialect and engine.name != dialect:
return False
if driver is not None and engine.driver != driver:
return False
if self.op is not None:
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
oper = (
hasattr(self.op, "__call__") and self.op or self._ops[self.op]
)
return oper(version, self.spec)
else:
return True
def _as_string(self, config, negate=False):
if self.description is not None:
return self._format_description(config)
elif self.op is None:
if negate:
return "not %s" % self.db
else:
return "%s" % self.db
else:
if negate:
return "not %s %s %s" % (self.db, self.op, self.spec)
else:
return "%s %s %s" % (self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
def __init__(self, lambda_, description=None, args=None, kw=None):
spec = inspect_getfullargspec(lambda_)
if not spec[0]:
self.lambda_ = lambda db: lambda_()
else:
self.lambda_ = lambda_
self.args = args or ()
self.kw = kw or {}
if description:
self.description = description
elif lambda_.__doc__:
self.description = lambda_.__doc__
else:
self.description = "custom function"
def __call__(self, config):
return self.lambda_(config)
def _as_string(self, config, negate=False):
return self._format_description(config)
class NotPredicate(Predicate):
def __init__(self, predicate, description=None):
self.predicate = predicate
self.description = description
def __call__(self, config):
return not self.predicate(config)
def _as_string(self, config, negate=False):
if self.description:
return self._format_description(config, not negate)
else:
return self.predicate._as_string(config, not negate)
class OrPredicate(Predicate):
def __init__(self, predicates, description=None):
self.predicates = predicates
self.description = description
def __call__(self, config):
for pred in self.predicates:
if pred(config):
return True
return False
def _eval_str(self, config, negate=False):
if negate:
conjunction = " and "
else:
conjunction = " or "
return conjunction.join(
p._as_string(config, negate=negate) for p in self.predicates
)
def _negation_str(self, config):
if self.description is not None:
return "Not " + self._format_description(config)
else:
return self._eval_str(config, negate=True)
def _as_string(self, config, negate=False):
if negate:
return self._negation_str(config)
else:
if self.description is not None:
return self._format_description(config)
else:
return self._eval_str(config)
_as_predicate = Predicate.as_predicate
def _is_excluded(db, op, spec):
return SpecPredicate(db, op, spec)(config._current)
def _server_version(engine):
"""Return a server_version_info tuple."""
# force metadata to be retrieved
conn = engine.connect()
version = getattr(engine.dialect, "server_version_info", None)
if version is None:
version = ()
conn.close()
return version
def db_spec(*dbs):
return OrPredicate([Predicate.as_predicate(db) for db in dbs])
def open(): # noqa
return skip_if(BooleanPredicate(False, "mark as execute"))
def closed():
return skip_if(BooleanPredicate(True, "marked as skip"))
def fails(reason=None):
return fails_if(BooleanPredicate(True, reason or "expected to fail"))
def future():
return fails_if(BooleanPredicate(True, "Future feature"))
def fails_on(db, reason=None):
return fails_if(db, reason)
def fails_on_everything_except(*dbs):
return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
def skip(db, reason=None):
return skip_if(db, reason)
def only_on(dbs, reason=None):
return only_if(
OrPredicate(
[Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
)
)
def exclude(db, op, spec, reason=None):
return skip_if(SpecPredicate(db, op, spec), reason)
def against(config, *queries):
assert queries, "no queries sent!"
return OrPredicate([Predicate.as_predicate(query) for query in queries])(
config
)

View File

@ -0,0 +1,28 @@
# testing/fixtures/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .base import FutureEngineMixin as FutureEngineMixin
from .base import TestBase as TestBase
from .mypy import MypyTest as MypyTest
from .orm import after_test as after_test
from .orm import close_all_sessions as close_all_sessions
from .orm import DeclarativeMappedTest as DeclarativeMappedTest
from .orm import fixture_session as fixture_session
from .orm import MappedTest as MappedTest
from .orm import ORMTest as ORMTest
from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally
from .orm import (
stop_test_class_inside_fixtures as stop_test_class_inside_fixtures,
)
from .sql import CacheKeyFixture as CacheKeyFixture
from .sql import (
ComputedReflectionFixtureTest as ComputedReflectionFixtureTest,
)
from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture
from .sql import NoCache as NoCache
from .sql import RemovesEvents as RemovesEvents
from .sql import TablesTest as TablesTest

View File

@ -0,0 +1,366 @@
# testing/fixtures/base.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import sqlalchemy as sa
from .. import assertions
from .. import config
from ..assertions import eq_
from ..util import drop_all_tables_from_metadata
from ... import Column
from ... import func
from ... import Integer
from ... import select
from ... import Table
from ...orm import DeclarativeBase
from ...orm import MappedAsDataclass
from ...orm import registry
@config.mark_base_test_class()
class TestBase:
# A sequence of requirement names matching testing.requires decorators
__requires__ = ()
# A sequence of dialect names to exclude from the test class.
__unsupported_on__ = ()
# If present, test class is only runnable for the *single* specified
# dialect. If you need multiple, use __unsupported_on__ and invert.
__only_on__ = None
# A sequence of no-arg callables. If any are True, the entire testcase is
# skipped.
__skip_if__ = None
# if True, the testing reaper will not attempt to touch connection
# state after a test is completed and before the outer teardown
# starts
__leave_connections_for_teardown__ = False
def assert_(self, val, msg=None):
assert val, msg
@config.fixture()
def nocache(self):
_cache = config.db._compiled_cache
config.db._compiled_cache = None
yield
config.db._compiled_cache = _cache
@config.fixture()
def connection_no_trans(self):
eng = getattr(self, "bind", None) or config.db
with eng.connect() as conn:
yield conn
@config.fixture()
def connection(self):
global _connection_fixture_connection
eng = getattr(self, "bind", None) or config.db
conn = eng.connect()
trans = conn.begin()
_connection_fixture_connection = conn
yield conn
_connection_fixture_connection = None
if trans.is_active:
trans.rollback()
# trans would not be active here if the test is using
# the legacy @provide_metadata decorator still, as it will
# run a close all connections.
conn.close()
@config.fixture()
def close_result_when_finished(self):
to_close = []
to_consume = []
def go(result, consume=False):
to_close.append(result)
if consume:
to_consume.append(result)
yield go
for r in to_consume:
try:
r.all()
except:
pass
for r in to_close:
try:
r.close()
except:
pass
@config.fixture()
def registry(self, metadata):
reg = registry(
metadata=metadata,
type_annotation_map={
str: sa.String().with_variant(
sa.String(50), "mysql", "mariadb", "oracle"
)
},
)
yield reg
reg.dispose()
@config.fixture
def decl_base(self, metadata):
_md = metadata
class Base(DeclarativeBase):
metadata = _md
type_annotation_map = {
str: sa.String().with_variant(
sa.String(50), "mysql", "mariadb", "oracle"
)
}
yield Base
Base.registry.dispose()
@config.fixture
def dc_decl_base(self, metadata):
_md = metadata
class Base(MappedAsDataclass, DeclarativeBase):
metadata = _md
type_annotation_map = {
str: sa.String().with_variant(
sa.String(50), "mysql", "mariadb"
)
}
yield Base
Base.registry.dispose()
@config.fixture()
def future_connection(self, future_engine, connection):
# integrate the future_engine and connection fixtures so
# that users of the "connection" fixture will get at the
# "future" connection
yield connection
@config.fixture()
def future_engine(self):
yield
@config.fixture()
def testing_engine(self):
from .. import engines
def gen_testing_engine(
url=None,
options=None,
future=None,
asyncio=False,
transfer_staticpool=False,
share_pool=False,
):
if options is None:
options = {}
options["scope"] = "fixture"
return engines.testing_engine(
url=url,
options=options,
asyncio=asyncio,
transfer_staticpool=transfer_staticpool,
share_pool=share_pool,
)
yield gen_testing_engine
engines.testing_reaper._drop_testing_engines("fixture")
@config.fixture()
def async_testing_engine(self, testing_engine):
def go(**kw):
kw["asyncio"] = True
return testing_engine(**kw)
return go
@config.fixture()
def metadata(self, request):
"""Provide bound MetaData for a single test, dropping afterwards."""
from ...sql import schema
metadata = schema.MetaData()
request.instance.metadata = metadata
yield metadata
del request.instance.metadata
if (
_connection_fixture_connection
and _connection_fixture_connection.in_transaction()
):
trans = _connection_fixture_connection.get_transaction()
trans.rollback()
with _connection_fixture_connection.begin():
drop_all_tables_from_metadata(
metadata, _connection_fixture_connection
)
else:
drop_all_tables_from_metadata(metadata, config.db)
@config.fixture(
params=[
(rollback, second_operation, begin_nested)
for rollback in (True, False)
for second_operation in ("none", "execute", "begin")
for begin_nested in (
True,
False,
)
]
)
def trans_ctx_manager_fixture(self, request, metadata):
rollback, second_operation, begin_nested = request.param
t = Table("test", metadata, Column("data", Integer))
eng = getattr(self, "bind", None) or config.db
t.create(eng)
def run_test(subject, trans_on_subject, execute_on_subject):
with subject.begin() as trans:
if begin_nested:
if not config.requirements.savepoints.enabled:
config.skip_test("savepoints not enabled")
if execute_on_subject:
nested_trans = subject.begin_nested()
else:
nested_trans = trans.begin_nested()
with nested_trans:
if execute_on_subject:
subject.execute(t.insert(), {"data": 10})
else:
trans.execute(t.insert(), {"data": 10})
# for nested trans, we always commit/rollback on the
# "nested trans" object itself.
# only Session(future=False) will affect savepoint
# transaction for session.commit/rollback
if rollback:
nested_trans.rollback()
else:
nested_trans.commit()
if second_operation != "none":
with assertions.expect_raises_message(
sa.exc.InvalidRequestError,
"Can't operate on closed transaction "
"inside context "
"manager. Please complete the context "
"manager "
"before emitting further commands.",
):
if second_operation == "execute":
if execute_on_subject:
subject.execute(
t.insert(), {"data": 12}
)
else:
trans.execute(t.insert(), {"data": 12})
elif second_operation == "begin":
if execute_on_subject:
subject.begin_nested()
else:
trans.begin_nested()
# outside the nested trans block, but still inside the
# transaction block, we can run SQL, and it will be
# committed
if execute_on_subject:
subject.execute(t.insert(), {"data": 14})
else:
trans.execute(t.insert(), {"data": 14})
else:
if execute_on_subject:
subject.execute(t.insert(), {"data": 10})
else:
trans.execute(t.insert(), {"data": 10})
if trans_on_subject:
if rollback:
subject.rollback()
else:
subject.commit()
else:
if rollback:
trans.rollback()
else:
trans.commit()
if second_operation != "none":
with assertions.expect_raises_message(
sa.exc.InvalidRequestError,
"Can't operate on closed transaction inside "
"context "
"manager. Please complete the context manager "
"before emitting further commands.",
):
if second_operation == "execute":
if execute_on_subject:
subject.execute(t.insert(), {"data": 12})
else:
trans.execute(t.insert(), {"data": 12})
elif second_operation == "begin":
if hasattr(trans, "begin"):
trans.begin()
else:
subject.begin()
elif second_operation == "begin_nested":
if execute_on_subject:
subject.begin_nested()
else:
trans.begin_nested()
expected_committed = 0
if begin_nested:
# begin_nested variant, we inserted a row after the nested
# block
expected_committed += 1
if not rollback:
# not rollback variant, our row inserted in the target
# block itself would be committed
expected_committed += 1
if execute_on_subject:
eq_(
subject.scalar(select(func.count()).select_from(t)),
expected_committed,
)
else:
with subject.connect() as conn:
eq_(
conn.scalar(select(func.count()).select_from(t)),
expected_committed,
)
return run_test
_connection_fixture_connection = None
class FutureEngineMixin:
"""alembic's suite still using this"""

View File

@ -0,0 +1,312 @@
# testing/fixtures/mypy.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import inspect
import os
from pathlib import Path
import re
import shutil
import sys
import tempfile
from .base import TestBase
from .. import config
from ..assertions import eq_
from ... import util
@config.add_to_marker.mypy
class MypyTest(TestBase):
__requires__ = ("no_sqlalchemy2_stubs",)
@config.fixture(scope="function")
def per_func_cachedir(self):
yield from self._cachedir()
@config.fixture(scope="class")
def cachedir(self):
yield from self._cachedir()
def _cachedir(self):
# as of mypy 0.971 i think we need to keep mypy_path empty
mypy_path = ""
with tempfile.TemporaryDirectory() as cachedir:
with open(
Path(cachedir) / "sqla_mypy_config.cfg", "w"
) as config_file:
config_file.write(
f"""
[mypy]\n
plugins = sqlalchemy.ext.mypy.plugin\n
show_error_codes = True\n
{mypy_path}
disable_error_code = no-untyped-call
[mypy-sqlalchemy.*]
ignore_errors = True
"""
)
with open(
Path(cachedir) / "plain_mypy_config.cfg", "w"
) as config_file:
config_file.write(
f"""
[mypy]\n
show_error_codes = True\n
{mypy_path}
disable_error_code = var-annotated,no-untyped-call
[mypy-sqlalchemy.*]
ignore_errors = True
"""
)
yield cachedir
@config.fixture()
def mypy_runner(self, cachedir):
from mypy import api
def run(path, use_plugin=False, use_cachedir=None):
if use_cachedir is None:
use_cachedir = cachedir
args = [
"--strict",
"--raise-exceptions",
"--cache-dir",
use_cachedir,
"--config-file",
os.path.join(
use_cachedir,
(
"sqla_mypy_config.cfg"
if use_plugin
else "plain_mypy_config.cfg"
),
),
]
# mypy as of 0.990 is more aggressively blocking messaging
# for paths that are in sys.path, and as pytest puts currdir,
# test/ etc in sys.path, just copy the source file to the
# tempdir we are working in so that we don't have to try to
# manipulate sys.path and/or guess what mypy is doing
filename = os.path.basename(path)
test_program = os.path.join(use_cachedir, filename)
if path != test_program:
shutil.copyfile(path, test_program)
args.append(test_program)
# I set this locally but for the suite here needs to be
# disabled
os.environ.pop("MYPY_FORCE_COLOR", None)
stdout, stderr, exitcode = api.run(args)
return stdout, stderr, exitcode
return run
@config.fixture
def mypy_typecheck_file(self, mypy_runner):
def run(path, use_plugin=False):
expected_messages = self._collect_messages(path)
stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
self._check_output(
path, expected_messages, stdout, stderr, exitcode
)
return run
@staticmethod
def file_combinations(dirname):
if os.path.isabs(dirname):
path = dirname
else:
caller_path = inspect.stack()[1].filename
path = os.path.join(os.path.dirname(caller_path), dirname)
files = list(Path(path).glob("**/*.py"))
for extra_dir in config.options.mypy_extra_test_paths:
if extra_dir and os.path.isdir(extra_dir):
files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
return files
def _collect_messages(self, path):
from sqlalchemy.ext.mypy.util import mypy_14
expected_messages = []
expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
with open(path) as file_:
current_assert_messages = []
for num, line in enumerate(file_, 1):
m = py_ver_re.match(line)
if m:
major, _, minor = m.group(1).partition(".")
if sys.version_info < (int(major), int(minor)):
config.skip_test(
"Requires python >= %s" % (m.group(1))
)
continue
m = expected_re.match(line)
if m:
is_mypy = bool(m.group(1))
is_re = bool(m.group(2))
is_type = bool(m.group(3))
expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
if is_type:
if not is_re:
# the goal here is that we can cut-and-paste
# from vscode -> pylance into the
# EXPECTED_TYPE: line, then the test suite will
# validate that line against what mypy produces
expected_msg = re.sub(
r"([\[\]])",
lambda m: rf"\{m.group(0)}",
expected_msg,
)
# note making sure preceding text matches
# with a dot, so that an expect for "Select"
# does not match "TypedSelect"
expected_msg = re.sub(
r"([\w_]+)",
lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
expected_msg,
)
expected_msg = re.sub(
"List", "builtins.list", expected_msg
)
expected_msg = re.sub(
r"\b(int|str|float|bool)\b",
lambda m: rf"builtins.{m.group(0)}\*?",
expected_msg,
)
# expected_msg = re.sub(
# r"(Sequence|Tuple|List|Union)",
# lambda m: fr"typing.{m.group(0)}\*?",
# expected_msg,
# )
is_mypy = is_re = True
expected_msg = f'Revealed type is "{expected_msg}"'
if mypy_14 and util.py39:
# use_lowercase_names, py39 and above
# https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501
# skip first character which could be capitalized
# "List item x not found" type of message
expected_msg = expected_msg[0] + re.sub(
(
r"\b(List|Tuple|Dict|Set)\b"
if is_type
else r"\b(List|Tuple|Dict|Set|Type)\b"
),
lambda m: m.group(1).lower(),
expected_msg[1:],
)
if mypy_14 and util.py310:
# use_or_syntax, py310 and above
# https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501
expected_msg = re.sub(
r"Optional\[(.*?)\]",
lambda m: f"{m.group(1)} | None",
expected_msg,
)
current_assert_messages.append(
(is_mypy, is_re, expected_msg.strip())
)
elif current_assert_messages:
expected_messages.extend(
(num, is_mypy, is_re, expected_msg)
for (
is_mypy,
is_re,
expected_msg,
) in current_assert_messages
)
current_assert_messages[:] = []
return expected_messages
def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
not_located = []
filename = os.path.basename(path)
if expected_messages:
# mypy 0.990 changed how return codes work, so don't assume a
# 1 or a 0 return code here, could be either depending on if
# errors were generated or not
output = []
raw_lines = stdout.split("\n")
while raw_lines:
e = raw_lines.pop(0)
if re.match(r".+\.py:\d+: error: .*", e):
output.append(("error", e))
elif re.match(
r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
):
while raw_lines:
ol = raw_lines.pop(0)
if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
break
elif re.match(
r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
):
pass
elif re.match(r".+\.py:\d+: note: .*", e):
output.append(("note", e))
for num, is_mypy, is_re, msg in expected_messages:
msg = msg.replace("'", '"')
prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
for idx, (typ, errmsg) in enumerate(output):
if is_re:
if re.match(
rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
errmsg,
):
break
elif (
f"{filename}:{num}: {typ}: {prefix}{msg}"
in errmsg.replace("'", '"')
):
break
else:
not_located.append(msg)
continue
del output[idx]
if not_located:
missing = "\n".join(not_located)
print("Couldn't locate expected messages:", missing, sep="\n")
if output:
extra = "\n".join(msg for _, msg in output)
print("Remaining messages:", extra, sep="\n")
assert False, "expected messages not found, see stdout"
if output:
print(f"{len(output)} messages from mypy were not consumed:")
print("\n".join(msg for _, msg in output))
assert False, "errors and/or notes remain, see stdout"
else:
if exitcode != 0:
print(stdout, stderr, sep="\n")
eq_(exitcode, 0, msg=stdout)

View File

@ -0,0 +1,227 @@
# testing/fixtures/orm.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
import sqlalchemy as sa
from .base import TestBase
from .sql import TablesTest
from .. import assertions
from .. import config
from .. import schema
from ..entities import BasicEntity
from ..entities import ComparableEntity
from ..util import adict
from ... import orm
from ...orm import DeclarativeBase
from ...orm import events as orm_events
from ...orm import registry
class ORMTest(TestBase):
@config.fixture
def fixture_session(self):
return fixture_session()
class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = "once"
# 'once', 'each', None
run_setup_mappers = "each"
classes: Any = None
@config.fixture(autouse=True, scope="class")
def _setup_tables_test_class(self):
cls = self.__class__
cls._init_class()
if cls.classes is None:
cls.classes = adict()
cls._setup_once_tables()
cls._setup_once_classes()
cls._setup_once_mappers()
cls._setup_once_inserts()
yield
cls._teardown_once_class()
cls._teardown_once_metadata_bind()
@config.fixture(autouse=True, scope="function")
def _setup_tables_test_instance(self):
self._setup_each_tables()
self._setup_each_classes()
self._setup_each_mappers()
self._setup_each_inserts()
yield
orm.session.close_all_sessions()
self._teardown_each_mappers()
self._teardown_each_classes()
self._teardown_each_tables()
@classmethod
def _teardown_once_class(cls):
cls.classes.clear()
@classmethod
def _setup_once_classes(cls):
if cls.run_setup_classes == "once":
cls._with_register_classes(cls.setup_classes)
@classmethod
def _setup_once_mappers(cls):
if cls.run_setup_mappers == "once":
cls.mapper_registry, cls.mapper = cls._generate_registry()
cls._with_register_classes(cls.setup_mappers)
def _setup_each_mappers(self):
if self.run_setup_mappers != "once":
(
self.__class__.mapper_registry,
self.__class__.mapper,
) = self._generate_registry()
if self.run_setup_mappers == "each":
self._with_register_classes(self.setup_mappers)
def _setup_each_classes(self):
if self.run_setup_classes == "each":
self._with_register_classes(self.setup_classes)
@classmethod
def _generate_registry(cls):
decl = registry(metadata=cls._tables_metadata)
return decl, decl.map_imperatively
@classmethod
def _with_register_classes(cls, fn):
"""Run a setup method, framing the operation with a Base class
that will catch new subclasses to be established within
the "classes" registry.
"""
cls_registry = cls.classes
class _Base:
def __init_subclass__(cls) -> None:
assert cls_registry is not None
cls_registry[cls.__name__] = cls
super().__init_subclass__()
class Basic(BasicEntity, _Base):
pass
class Comparable(ComparableEntity, _Base):
pass
cls.Basic = Basic
cls.Comparable = Comparable
fn()
def _teardown_each_mappers(self):
# some tests create mappers in the test bodies
# and will define setup_mappers as None -
# clear mappers in any case
if self.run_setup_mappers != "once":
orm.clear_mappers()
def _teardown_each_classes(self):
if self.run_setup_classes != "once":
self.classes.clear()
@classmethod
def setup_classes(cls):
pass
@classmethod
def setup_mappers(cls):
pass
class DeclarativeMappedTest(MappedTest):
run_setup_classes = "once"
run_setup_mappers = "once"
@classmethod
def _setup_once_tables(cls):
pass
@classmethod
def _with_register_classes(cls, fn):
cls_registry = cls.classes
class _DeclBase(DeclarativeBase):
__table_cls__ = schema.Table
metadata = cls._tables_metadata
type_annotation_map = {
str: sa.String().with_variant(
sa.String(50), "mysql", "mariadb", "oracle"
)
}
def __init_subclass__(cls, **kw) -> None:
assert cls_registry is not None
cls_registry[cls.__name__] = cls
super().__init_subclass__(**kw)
cls.DeclarativeBasic = _DeclBase
# sets up cls.Basic which is helpful for things like composite
# classes
super()._with_register_classes(fn)
if cls._tables_metadata.tables and cls.run_create_tables:
cls._tables_metadata.create_all(config.db)
class RemoveORMEventsGlobally:
@config.fixture(autouse=True)
def _remove_listeners(self):
yield
orm_events.MapperEvents._clear()
orm_events.InstanceEvents._clear()
orm_events.SessionEvents._clear()
orm_events.InstrumentationEvents._clear()
orm_events.QueryEvents._clear()
_fixture_sessions = set()
def fixture_session(**kw):
kw.setdefault("autoflush", True)
kw.setdefault("expire_on_commit", True)
bind = kw.pop("bind", config.db)
sess = orm.Session(bind, **kw)
_fixture_sessions.add(sess)
return sess
def close_all_sessions():
# will close all still-referenced sessions
orm.close_all_sessions()
_fixture_sessions.clear()
def stop_test_class_inside_fixtures(cls):
close_all_sessions()
orm.clear_mappers()
def after_test():
if _fixture_sessions:
close_all_sessions()

View File

@ -0,0 +1,503 @@
# testing/fixtures/sql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import itertools
import random
import re
import sys
import sqlalchemy as sa
from .base import TestBase
from .. import config
from .. import mock
from ..assertions import eq_
from ..assertions import ne_
from ..util import adict
from ..util import drop_all_tables_from_metadata
from ... import event
from ... import util
from ...schema import sort_tables_and_constraints
from ...sql import visitors
from ...sql.elements import ClauseElement
class TablesTest(TestBase):
# 'once', None
run_setup_bind = "once"
# 'once', 'each', None
run_define_tables = "once"
# 'once', 'each', None
run_create_tables = "once"
# 'once', 'each', None
run_inserts = "each"
# 'each', None
run_deletes = "each"
# 'once', None
run_dispose_bind = None
bind = None
_tables_metadata = None
tables = None
other = None
sequences = None
@config.fixture(autouse=True, scope="class")
def _setup_tables_test_class(self):
cls = self.__class__
cls._init_class()
cls._setup_once_tables()
cls._setup_once_inserts()
yield
cls._teardown_once_metadata_bind()
@config.fixture(autouse=True, scope="function")
def _setup_tables_test_instance(self):
self._setup_each_tables()
self._setup_each_inserts()
yield
self._teardown_each_tables()
@property
def tables_test_metadata(self):
return self._tables_metadata
@classmethod
def _init_class(cls):
if cls.run_define_tables == "each":
if cls.run_create_tables == "once":
cls.run_create_tables = "each"
assert cls.run_inserts in ("each", None)
cls.other = adict()
cls.tables = adict()
cls.sequences = adict()
cls.bind = cls.setup_bind()
cls._tables_metadata = sa.MetaData()
@classmethod
def _setup_once_inserts(cls):
if cls.run_inserts == "once":
cls._load_fixtures()
with cls.bind.begin() as conn:
cls.insert_data(conn)
@classmethod
def _setup_once_tables(cls):
if cls.run_define_tables == "once":
cls.define_tables(cls._tables_metadata)
if cls.run_create_tables == "once":
cls._tables_metadata.create_all(cls.bind)
cls.tables.update(cls._tables_metadata.tables)
cls.sequences.update(cls._tables_metadata._sequences)
def _setup_each_tables(self):
if self.run_define_tables == "each":
self.define_tables(self._tables_metadata)
if self.run_create_tables == "each":
self._tables_metadata.create_all(self.bind)
self.tables.update(self._tables_metadata.tables)
self.sequences.update(self._tables_metadata._sequences)
elif self.run_create_tables == "each":
self._tables_metadata.create_all(self.bind)
def _setup_each_inserts(self):
if self.run_inserts == "each":
self._load_fixtures()
with self.bind.begin() as conn:
self.insert_data(conn)
def _teardown_each_tables(self):
if self.run_define_tables == "each":
self.tables.clear()
if self.run_create_tables == "each":
drop_all_tables_from_metadata(self._tables_metadata, self.bind)
self._tables_metadata.clear()
elif self.run_create_tables == "each":
drop_all_tables_from_metadata(self._tables_metadata, self.bind)
savepoints = getattr(config.requirements, "savepoints", False)
if savepoints:
savepoints = savepoints.enabled
# no need to run deletes if tables are recreated on setup
if (
self.run_define_tables != "each"
and self.run_create_tables != "each"
and self.run_deletes == "each"
):
with self.bind.begin() as conn:
for table in reversed(
[
t
for (t, fks) in sort_tables_and_constraints(
self._tables_metadata.tables.values()
)
if t is not None
]
):
try:
if savepoints:
with conn.begin_nested():
conn.execute(table.delete())
else:
conn.execute(table.delete())
except sa.exc.DBAPIError as ex:
print(
("Error emptying table %s: %r" % (table, ex)),
file=sys.stderr,
)
@classmethod
def _teardown_once_metadata_bind(cls):
if cls.run_create_tables:
drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
if cls.run_dispose_bind == "once":
cls.dispose_bind(cls.bind)
cls._tables_metadata.bind = None
if cls.run_setup_bind is not None:
cls.bind = None
@classmethod
def setup_bind(cls):
return config.db
@classmethod
def dispose_bind(cls, bind):
if hasattr(bind, "dispose"):
bind.dispose()
elif hasattr(bind, "close"):
bind.close()
@classmethod
def define_tables(cls, metadata):
pass
@classmethod
def fixtures(cls):
return {}
@classmethod
def insert_data(cls, connection):
pass
def sql_count_(self, count, fn):
self.assert_sql_count(self.bind, fn, count)
def sql_eq_(self, callable_, statements):
self.assert_sql(self.bind, callable_, statements)
@classmethod
def _load_fixtures(cls):
"""Insert rows as represented by the fixtures() method."""
headers, rows = {}, {}
for table, data in cls.fixtures().items():
if len(data) < 2:
continue
if isinstance(table, str):
table = cls.tables[table]
headers[table] = data[0]
rows[table] = data[1:]
for table, fks in sort_tables_and_constraints(
cls._tables_metadata.tables.values()
):
if table is None:
continue
if table not in headers:
continue
with cls.bind.begin() as conn:
conn.execute(
table.insert(),
[
dict(zip(headers[table], column_values))
for column_values in rows[table]
],
)
class NoCache:
@config.fixture(autouse=True, scope="function")
def _disable_cache(self):
_cache = config.db._compiled_cache
config.db._compiled_cache = None
yield
config.db._compiled_cache = _cache
class RemovesEvents:
@util.memoized_property
def _event_fns(self):
return set()
def event_listen(self, target, name, fn, **kw):
self._event_fns.add((target, name, fn))
event.listen(target, name, fn, **kw)
@config.fixture(autouse=True, scope="function")
def _remove_events(self):
yield
for key in self._event_fns:
event.remove(*key)
class ComputedReflectionFixtureTest(TablesTest):
run_inserts = run_deletes = None
__backend__ = True
__requires__ = ("computed_columns", "table_reflection")
regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
def normalize(self, text):
return self.regexp.sub("", text).lower()
@classmethod
def define_tables(cls, metadata):
from ... import Integer
from ... import testing
from ...schema import Column
from ...schema import Computed
from ...schema import Table
Table(
"computed_default_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_col", Integer, Computed("normal + 42")),
Column("with_default", Integer, server_default="42"),
)
t = Table(
"computed_column_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_no_flag", Integer, Computed("normal + 42")),
)
if testing.requires.schemas.enabled:
t2 = Table(
"computed_column_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_no_flag", Integer, Computed("normal / 42")),
schema=config.test_schema,
)
if testing.requires.computed_columns_virtual.enabled:
t.append_column(
Column(
"computed_virtual",
Integer,
Computed("normal + 2", persisted=False),
)
)
if testing.requires.schemas.enabled:
t2.append_column(
Column(
"computed_virtual",
Integer,
Computed("normal / 2", persisted=False),
)
)
if testing.requires.computed_columns_stored.enabled:
t.append_column(
Column(
"computed_stored",
Integer,
Computed("normal - 42", persisted=True),
)
)
if testing.requires.schemas.enabled:
t2.append_column(
Column(
"computed_stored",
Integer,
Computed("normal * 42", persisted=True),
)
)
class CacheKeyFixture:
def _compare_equal(self, a, b, compare_values):
a_key = a._generate_cache_key()
b_key = b._generate_cache_key()
if a_key is None:
assert a._annotations.get("nocache")
assert b_key is None
else:
eq_(a_key.key, b_key.key)
eq_(hash(a_key.key), hash(b_key.key))
for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
assert a_param.compare(b_param, compare_values=compare_values)
return a_key, b_key
def _run_cache_key_fixture(self, fixture, compare_values):
case_a = fixture()
case_b = fixture()
for a, b in itertools.combinations_with_replacement(
range(len(case_a)), 2
):
if a == b:
a_key, b_key = self._compare_equal(
case_a[a], case_b[b], compare_values
)
if a_key is None:
continue
else:
a_key = case_a[a]._generate_cache_key()
b_key = case_b[b]._generate_cache_key()
if a_key is None or b_key is None:
if a_key is None:
assert case_a[a]._annotations.get("nocache")
if b_key is None:
assert case_b[b]._annotations.get("nocache")
continue
if a_key.key == b_key.key:
for a_param, b_param in zip(
a_key.bindparams, b_key.bindparams
):
if not a_param.compare(
b_param, compare_values=compare_values
):
break
else:
# this fails unconditionally since we could not
# find bound parameter values that differed.
# Usually we intended to get two distinct keys here
# so the failure will be more descriptive using the
# ne_() assertion.
ne_(a_key.key, b_key.key)
else:
ne_(a_key.key, b_key.key)
# ClauseElement-specific test to ensure the cache key
# collected all the bound parameters that aren't marked
# as "literal execute"
if isinstance(case_a[a], ClauseElement) and isinstance(
case_b[b], ClauseElement
):
assert_a_params = []
assert_b_params = []
for elem in visitors.iterate(case_a[a]):
if elem.__visit_name__ == "bindparam":
assert_a_params.append(elem)
for elem in visitors.iterate(case_b[b]):
if elem.__visit_name__ == "bindparam":
assert_b_params.append(elem)
# note we're asserting the order of the params as well as
# if there are dupes or not. ordering has to be
# deterministic and matches what a traversal would provide.
eq_(
sorted(a_key.bindparams, key=lambda b: b.key),
sorted(
util.unique_list(assert_a_params), key=lambda b: b.key
),
)
eq_(
sorted(b_key.bindparams, key=lambda b: b.key),
sorted(
util.unique_list(assert_b_params), key=lambda b: b.key
),
)
def _run_cache_key_equal_fixture(self, fixture, compare_values):
case_a = fixture()
case_b = fixture()
for a, b in itertools.combinations_with_replacement(
range(len(case_a)), 2
):
self._compare_equal(case_a[a], case_b[b], compare_values)
def insertmanyvalues_fixture(
connection, randomize_rows=False, warn_on_downgraded=False
):
dialect = connection.dialect
orig_dialect = dialect._deliver_insertmanyvalues_batches
orig_conn = connection._exec_insertmany_context
class RandomCursor:
__slots__ = ("cursor",)
def __init__(self, cursor):
self.cursor = cursor
# only this method is called by the deliver method.
# by not having the other methods we assert that those aren't being
# used
@property
def description(self):
return self.cursor.description
def fetchall(self):
rows = self.cursor.fetchall()
rows = list(rows)
random.shuffle(rows)
return rows
def _deliver_insertmanyvalues_batches(
connection,
cursor,
statement,
parameters,
generic_setinputsizes,
context,
):
if randomize_rows:
cursor = RandomCursor(cursor)
for batch in orig_dialect(
connection,
cursor,
statement,
parameters,
generic_setinputsizes,
context,
):
if warn_on_downgraded and batch.is_downgraded:
util.warn("Batches were downgraded for sorted INSERT")
yield batch
def _exec_insertmany_context(dialect, context):
with mock.patch.object(
dialect,
"_deliver_insertmanyvalues_batches",
new=_deliver_insertmanyvalues_batches,
):
return orig_conn(dialect, context)
connection._exec_insertmany_context = _exec_insertmany_context

View File

@ -0,0 +1,155 @@
# testing/pickleable.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Classes used in pickling tests, need to be at the module level for
unpickling.
"""
from __future__ import annotations
from .entities import ComparableEntity
from ..schema import Column
from ..types import String
class User(ComparableEntity):
pass
class Order(ComparableEntity):
pass
class Dingaling(ComparableEntity):
pass
class EmailUser(User):
pass
class Address(ComparableEntity):
pass
# TODO: these are kind of arbitrary....
class Child1(ComparableEntity):
pass
class Child2(ComparableEntity):
pass
class Parent(ComparableEntity):
pass
class Screen:
def __init__(self, obj, parent=None):
self.obj = obj
self.parent = parent
class Mixin:
email_address = Column(String)
class AddressWMixin(Mixin, ComparableEntity):
pass
class Foo:
def __init__(self, moredata, stuff="im stuff"):
self.data = "im data"
self.stuff = stuff
self.moredata = moredata
__hash__ = object.__hash__
def __eq__(self, other):
return (
other.data == self.data
and other.stuff == self.stuff
and other.moredata == self.moredata
)
class Bar:
def __init__(self, x, y):
self.x = x
self.y = y
__hash__ = object.__hash__
def __eq__(self, other):
return (
other.__class__ is self.__class__
and other.x == self.x
and other.y == self.y
)
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class OldSchool:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
return (
other.__class__ is self.__class__
and other.x == self.x
and other.y == self.y
)
class OldSchoolWithoutCompare:
def __init__(self, x, y):
self.x = x
self.y = y
class BarWithoutCompare:
def __init__(self, x, y):
self.x = x
self.y = y
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class NotComparable:
def __init__(self, data):
self.data = data
def __hash__(self):
return id(self)
def __eq__(self, other):
return NotImplemented
def __ne__(self, other):
return NotImplemented
class BrokenComparable:
def __init__(self, data):
self.data = data
def __hash__(self):
return id(self)
def __eq__(self, other):
raise NotImplementedError
def __ne__(self, other):
raise NotImplementedError

View File

@ -0,0 +1,6 @@
# testing/plugin/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

View File

@ -0,0 +1,51 @@
# testing/plugin/bootstrap.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
Bootstrapper for test framework plugins.
The entire rationale for this system is to get the modules in plugin/
imported without importing all of the supporting library, so that we can
set up things for testing before coverage starts.
The rationale for all of plugin/ being *in* the supporting library in the
first place is so that the testing and plugin suite is available to other
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
of the same test environment and standard suites available to
SQLAlchemy/Alembic themselves without the need to ship/install a separate
package outside of SQLAlchemy.
"""
import importlib.util
import os
import sys
bootstrap_file = locals()["bootstrap_file"]
to_bootstrap = locals()["to_bootstrap"]
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
spec = importlib.util.spec_from_file_location(name, path)
assert spec is not None
assert spec.loader is not None
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa

View File

@ -0,0 +1,779 @@
# testing/plugin/plugin_base.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import abc
from argparse import Namespace
import configparser
import logging
import os
from pathlib import Path
import re
import sys
from typing import Any
from sqlalchemy.testing import asyncio
"""Testing extensions.
this module is designed to work as a testing-framework-agnostic library,
created so that multiple test frameworks can be supported at once
(mostly so that we can migrate to new ones). The current target
is pytest.
"""
# flag which indicates we are in the SQLAlchemy testing suite,
# and not that of Alembic or a third party dialect.
bootstrapped_as_sqlalchemy = False
log = logging.getLogger("sqlalchemy.testing.plugin_base")
# late imports
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
provision = None
assertions = None
requirements = None
config = None
testing = None
util = None
file_config = None
logging = None
include_tags = set()
exclude_tags = set()
options: Namespace = None # type: ignore
def setup_options(make_option):
make_option(
"--log-info",
action="callback",
type=str,
callback=_log,
help="turn on info logging for <LOG> (multiple OK)",
)
make_option(
"--log-debug",
action="callback",
type=str,
callback=_log,
help="turn on debug logging for <LOG> (multiple OK)",
)
make_option(
"--db",
action="append",
type=str,
dest="db",
help="Use prefab database uri. Multiple OK, "
"first one is run by default.",
)
make_option(
"--dbs",
action="callback",
zeroarg_callback=_list_dbs,
help="List available prefab dbs",
)
make_option(
"--dburi",
action="append",
type=str,
dest="dburi",
help="Database uri. Multiple OK, first one is run by default.",
)
make_option(
"--dbdriver",
action="append",
type=str,
dest="dbdriver",
help="Additional database drivers to include in tests. "
"These are linked to the existing database URLs by the "
"provisioning system.",
)
make_option(
"--dropfirst",
action="store_true",
dest="dropfirst",
help="Drop all tables in the target database first",
)
make_option(
"--disable-asyncio",
action="store_true",
help="disable test / fixtures / provisoning running in asyncio",
)
make_option(
"--backend-only",
action="callback",
zeroarg_callback=_set_tag_include("backend"),
help=(
"Run only tests marked with __backend__ or __sparse_backend__; "
"this is now equivalent to the pytest -m backend mark expression"
),
)
make_option(
"--nomemory",
action="callback",
zeroarg_callback=_set_tag_exclude("memory_intensive"),
help="Don't run memory profiling tests; "
"this is now equivalent to the pytest -m 'not memory_intensive' "
"mark expression",
)
make_option(
"--notimingintensive",
action="callback",
zeroarg_callback=_set_tag_exclude("timing_intensive"),
help="Don't run timing intensive tests; "
"this is now equivalent to the pytest -m 'not timing_intensive' "
"mark expression",
)
make_option(
"--nomypy",
action="callback",
zeroarg_callback=_set_tag_exclude("mypy"),
help="Don't run mypy typing tests; "
"this is now equivalent to the pytest -m 'not mypy' mark expression",
)
make_option(
"--profile-sort",
type=str,
default="cumulative",
dest="profilesort",
help="Type of sort for profiling standard output",
)
make_option(
"--profile-dump",
type=str,
dest="profiledump",
help="Filename where a single profile run will be dumped",
)
make_option(
"--low-connections",
action="store_true",
dest="low_connections",
help="Use a low number of distinct connections - "
"i.e. for Oracle TNS",
)
make_option(
"--write-idents",
type=str,
dest="write_idents",
help="write out generated follower idents to <file>, "
"when -n<num> is used",
)
make_option(
"--requirements",
action="callback",
type=str,
callback=_requirements_opt,
help="requirements class for testing, overrides setup.cfg",
)
make_option(
"--include-tag",
action="callback",
callback=_include_tag,
type=str,
help="Include tests with tag <tag>; "
"legacy, use pytest -m 'tag' instead",
)
make_option(
"--exclude-tag",
action="callback",
callback=_exclude_tag,
type=str,
help="Exclude tests with tag <tag>; "
"legacy, use pytest -m 'not tag' instead",
)
make_option(
"--write-profiles",
action="store_true",
dest="write_profiles",
default=False,
help="Write/update failing profiling data.",
)
make_option(
"--force-write-profiles",
action="store_true",
dest="force_write_profiles",
default=False,
help="Unconditionally write/update profiling data.",
)
make_option(
"--dump-pyannotate",
type=str,
dest="dump_pyannotate",
help="Run pyannotate and dump json info to given file",
)
make_option(
"--mypy-extra-test-path",
type=str,
action="append",
default=[],
dest="mypy_extra_test_paths",
help="Additional test directories to add to the mypy tests. "
"This is used only when running mypy tests. Multiple OK",
)
# db specific options
make_option(
"--postgresql-templatedb",
type=str,
help="name of template database to use for PostgreSQL "
"CREATE DATABASE (defaults to current database)",
)
make_option(
"--oracledb-thick-mode",
action="store_true",
help="enables the 'thick mode' when testing with oracle+oracledb",
)
def configure_follower(follower_ident):
"""Configure required state for a follower.
This invokes in the parent process and typically includes
database creation.
"""
from sqlalchemy.testing import provision
provision.FOLLOWER_IDENT = follower_ident
def memoize_important_follower_config(dict_):
"""Store important configuration we will need to send to a follower.
This invokes in the parent process after normal config is set up.
Hook is currently not used.
"""
def restore_important_follower_config(dict_):
"""Restore important configuration needed by a follower.
This invokes in the follower process.
Hook is currently not used.
"""
def read_config(root_path):
global file_config
file_config = configparser.ConfigParser()
file_config.read(
[str(root_path / "setup.cfg"), str(root_path / "test.cfg")]
)
def pre_begin(opt):
"""things to set up early, before coverage might be setup."""
global options
options = opt
for fn in pre_configure:
fn(options, file_config)
def set_coverage_flag(value):
options.has_coverage = value
def post_begin():
"""things to set up later, once we know coverage is running."""
# Lazy setup of other options (post coverage)
for fn in post_configure:
fn(options, file_config)
# late imports, has to happen after config.
global util, fixtures, engines, exclusions, assertions, provision
global warnings, profiling, config, testing
from sqlalchemy import testing # noqa
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
from sqlalchemy.testing import assertions, warnings, profiling # noqa
from sqlalchemy.testing import config, provision # noqa
from sqlalchemy import util # noqa
warnings.setup_filters()
def _log(opt_str, value, parser):
global logging
if not logging:
import logging
logging.basicConfig()
if opt_str.endswith("-info"):
logging.getLogger(value).setLevel(logging.INFO)
elif opt_str.endswith("-debug"):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
if file_config is None:
# assume the current working directory is the one containing the
# setup file
read_config(Path.cwd())
print("Available --db options (use --dburi to override)")
for macro in sorted(file_config.options("db")):
print("%20s\t%s" % (macro, file_config.get("db", macro)))
sys.exit(0)
def _requirements_opt(opt_str, value, parser):
_setup_requirements(value)
def _set_tag_include(tag):
def _do_include_tag(opt_str, value, parser):
_include_tag(opt_str, tag, parser)
return _do_include_tag
def _set_tag_exclude(tag):
def _do_exclude_tag(opt_str, value, parser):
_exclude_tag(opt_str, tag, parser)
return _do_exclude_tag
def _exclude_tag(opt_str, value, parser):
exclude_tags.add(value.replace("-", "_"))
def _include_tag(opt_str, value, parser):
include_tags.add(value.replace("-", "_"))
pre_configure = []
post_configure = []
def pre(fn):
pre_configure.append(fn)
return fn
def post(fn):
post_configure.append(fn)
return fn
@pre
def _setup_options(opt, file_config):
global options
options = opt
@pre
def _register_sqlite_numeric_dialect(opt, file_config):
from sqlalchemy.dialects import registry
registry.register(
"sqlite.pysqlite_numeric",
"sqlalchemy.dialects.sqlite.pysqlite",
"_SQLiteDialect_pysqlite_numeric",
)
registry.register(
"sqlite.pysqlite_dollar",
"sqlalchemy.dialects.sqlite.pysqlite",
"_SQLiteDialect_pysqlite_dollar",
)
@post
def __ensure_cext(opt, file_config):
if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":
from sqlalchemy.util import has_compiled_ext
try:
has_compiled_ext(raise_=True)
except ImportError as err:
raise AssertionError(
"REQUIRE_SQLALCHEMY_CEXT is set but can't import the "
"cython extensions"
) from err
@post
def _init_symbols(options, file_config):
from sqlalchemy.testing import config
config._fixture_functions = _fixture_fn_class()
@pre
def _set_disable_asyncio(opt, file_config):
if opt.disable_asyncio:
asyncio.ENABLE_ASYNCIO = False
@post
def _engine_uri(options, file_config):
from sqlalchemy import testing
from sqlalchemy.testing import config
from sqlalchemy.testing import provision
from sqlalchemy.engine import url as sa_url
if options.dburi:
db_urls = list(options.dburi)
else:
db_urls = []
extra_drivers = options.dbdriver or []
if options.db:
for db_token in options.db:
for db in re.split(r"[,\s]+", db_token):
if db not in file_config.options("db"):
raise RuntimeError(
"Unknown URI specifier '%s'. "
"Specify --dbs for known uris." % db
)
else:
db_urls.append(file_config.get("db", db))
if not db_urls:
db_urls.append(file_config.get("db", "default"))
config._current = None
if options.write_idents and provision.FOLLOWER_IDENT:
for db_url in [sa_url.make_url(db_url) for db_url in db_urls]:
with open(options.write_idents, "a") as file_:
file_.write(
f"{provision.FOLLOWER_IDENT} "
f"{db_url.render_as_string(hide_password=False)}\n"
)
expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
for db_url in expanded_urls:
log.info("Adding database URL: %s", db_url)
cfg = provision.setup_config(
db_url, options, file_config, provision.FOLLOWER_IDENT
)
if not config._current:
cfg.set_as_current(cfg, testing)
@post
def _requirements(options, file_config):
requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
def _setup_requirements(argument):
from sqlalchemy.testing import config
from sqlalchemy import testing
modname, clsname = argument.split(":")
# importlib.import_module() only introduced in 2.7, a little
# late
mod = __import__(modname)
for component in modname.split(".")[1:]:
mod = getattr(mod, component)
req_cls = getattr(mod, clsname)
config.requirements = testing.requires = req_cls()
config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
@post
def _prep_testing_database(options, file_config):
from sqlalchemy.testing import config
if options.dropfirst:
from sqlalchemy.testing import provision
for cfg in config.Config.all_configs():
provision.drop_all_schema_objects(cfg, cfg.db)
@post
def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
config.options = options
config.file_config = file_config
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
profiling._profile_stats = profiling.ProfileStatsFile(
file_config.get("sqla_testing", "profile_file"),
sort=options.profilesort,
dump=options.profiledump,
)
def want_class(name, cls):
if not issubclass(cls, fixtures.TestBase):
return False
elif name.startswith("_"):
return False
else:
return True
def want_method(cls, fn):
if not fn.__name__.startswith("test_"):
return False
elif fn.__module__ is None:
return False
else:
return True
def generate_sub_tests(cls, module, markers):
if "backend" in markers or "sparse_backend" in markers:
sparse = "sparse_backend" in markers
for cfg in _possible_configs_for_cls(cls, sparse=sparse):
orig_name = cls.__name__
# we can have special chars in these names except for the
# pytest junit plugin, which is tripped up by the brackets
# and periods, so sanitize
alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
alpha_name = re.sub(r"_+$", "", alpha_name)
name = "%s_%s" % (cls.__name__, alpha_name)
subcls = type(
name,
(cls,),
{"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
)
setattr(module, name, subcls)
yield subcls
else:
yield cls
def start_test_class_outside_fixtures(cls):
_do_skips(cls)
_setup_engine(cls)
def stop_test_class(cls):
# close sessions, immediate connections, etc.
fixtures.stop_test_class_inside_fixtures(cls)
# close outstanding connection pool connections, dispose of
# additional engines
engines.testing_reaper.stop_test_class_inside_fixtures()
def stop_test_class_outside_fixtures(cls):
engines.testing_reaper.stop_test_class_outside_fixtures()
provision.stop_test_class_outside_fixtures(config, config.db, cls)
try:
if not options.low_connections:
assertions.global_cleanup_assertions()
finally:
_restore_engine()
def _restore_engine():
if config._current:
config._current.reset(testing)
def final_process_cleanup():
engines.testing_reaper.final_cleanup()
assertions.global_cleanup_assertions()
_restore_engine()
def _setup_engine(cls):
if getattr(cls, "__engine_options__", None):
opts = dict(cls.__engine_options__)
opts["scope"] = "class"
eng = engines.testing_engine(options=opts)
config._current.push_engine(eng, testing)
def before_test(test, test_module_name, test_class, test_name):
# format looks like:
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
profiling._start_current_test(id_)
def after_test(test):
fixtures.after_test()
engines.testing_reaper.after_test()
def after_test_fixtures(test):
engines.testing_reaper.after_test_outside_fixtures(test)
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
all_configs = set(config.Config.all_configs())
if cls.__unsupported_on__:
spec = exclusions.db_spec(*cls.__unsupported_on__)
for config_obj in list(all_configs):
if spec(config_obj):
all_configs.remove(config_obj)
if getattr(cls, "__only_on__", None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
if getattr(cls, "__only_on_config__", None):
all_configs.intersection_update([cls.__only_on_config__])
if hasattr(cls, "__requires__"):
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__requires__:
check = getattr(requirements, requirement)
skip_reasons = check.matching_config_reasons(config_obj)
if skip_reasons:
all_configs.remove(config_obj)
if reasons is not None:
reasons.extend(skip_reasons)
break
if hasattr(cls, "__prefer_requires__"):
non_preferred = set()
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__prefer_requires__:
check = getattr(requirements, requirement)
if not check.enabled_for_config(config_obj):
non_preferred.add(config_obj)
if all_configs.difference(non_preferred):
all_configs.difference_update(non_preferred)
if sparse:
# pick only one config from each base dialect
# sorted so we get the same backend each time selecting the highest
# server version info.
per_dialect = {}
for cfg in reversed(
sorted(
all_configs,
key=lambda cfg: (
cfg.db.name,
cfg.db.driver,
cfg.db.dialect.server_version_info,
),
)
):
db = cfg.db.name
if db not in per_dialect:
per_dialect[db] = cfg
return per_dialect.values()
return all_configs
def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
if getattr(cls, "__skip_if__", False):
for c in getattr(cls, "__skip_if__"):
if c():
config.skip_test(
"'%s' skipped by %s" % (cls.__name__, c.__name__)
)
if not all_configs:
msg = "'%s.%s' unsupported on any DB implementation %s%s" % (
cls.__module__,
cls.__name__,
", ".join(
"'%s(%s)+%s'"
% (
config_obj.db.name,
".".join(
str(dig)
for dig in exclusions._server_version(config_obj.db)
),
config_obj.db.driver,
)
for config_obj in config.Config.all_configs()
),
", ".join(reasons),
)
config.skip_test(msg)
elif hasattr(cls, "__prefer_backends__"):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
for config_obj in all_configs:
if not spec(config_obj):
non_preferred.add(config_obj)
if all_configs.difference(non_preferred):
all_configs.difference_update(non_preferred)
if config._current not in all_configs:
_setup_config(all_configs.pop(), cls)
def _setup_config(config_obj, ctx):
config._current.push(config_obj, testing)
class FixtureFunctions(abc.ABC):
@abc.abstractmethod
def skip_test_exception(self, *arg, **kw):
raise NotImplementedError()
@abc.abstractmethod
def combinations(self, *args, **kw):
raise NotImplementedError()
@abc.abstractmethod
def param_ident(self, *args, **kw):
raise NotImplementedError()
@abc.abstractmethod
def fixture(self, *arg, **kw):
raise NotImplementedError()
def get_current_test_name(self):
raise NotImplementedError()
@abc.abstractmethod
def mark_base_test_class(self) -> Any:
raise NotImplementedError()
@abc.abstractproperty
def add_to_marker(self):
raise NotImplementedError()
_fixture_fn_class = None
def set_fixture_functions(fixture_fn_class):
global _fixture_fn_class
_fixture_fn_class = fixture_fn_class

View File

@ -0,0 +1,868 @@
# testing/plugin/pytestplugin.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import argparse
import collections
from functools import update_wrapper
import inspect
import itertools
import operator
import os
import re
import sys
from typing import TYPE_CHECKING
import uuid
import pytest
try:
# installed by bootstrap.py
if not TYPE_CHECKING:
import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
def pytest_addoption(parser):
group = parser.getgroup("sqlalchemy")
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
class CallableAction(argparse.Action):
def __call__(
self, parser, namespace, values, option_string=None
):
callback_(option_string, values, parser)
kw["action"] = CallableAction
zeroarg_callback = kw.pop("zeroarg_callback", None)
if zeroarg_callback:
class CallableAction(argparse.Action):
def __init__(
self,
option_strings,
dest,
default=False,
required=False,
help=None, # noqa
):
super().__init__(
option_strings=option_strings,
dest=dest,
nargs=0,
const=True,
default=default,
required=required,
help=help,
)
def __call__(
self, parser, namespace, values, option_string=None
):
zeroarg_callback(option_string, values, parser)
kw["action"] = CallableAction
group.addoption(name, **kw)
plugin_base.setup_options(make_option)
def pytest_configure(config: pytest.Config):
plugin_base.read_config(config.rootpath)
if plugin_base.exclude_tags or plugin_base.include_tags:
new_expr = " and ".join(
list(plugin_base.include_tags)
+ [f"not {tag}" for tag in plugin_base.exclude_tags]
)
if config.option.markexpr:
config.option.markexpr += f" and {new_expr}"
else:
config.option.markexpr = new_expr
if config.pluginmanager.hasplugin("xdist"):
config.pluginmanager.register(XDistHooks())
if hasattr(config, "workerinput"):
plugin_base.restore_important_follower_config(config.workerinput)
plugin_base.configure_follower(config.workerinput["follower_ident"])
else:
if config.option.write_idents and os.path.exists(
config.option.write_idents
):
os.remove(config.option.write_idents)
plugin_base.pre_begin(config.option)
plugin_base.set_coverage_flag(
bool(getattr(config.option, "cov_source", False))
)
plugin_base.set_fixture_functions(PytestFixtureFunctions)
if config.option.dump_pyannotate:
global DUMP_PYANNOTATE
DUMP_PYANNOTATE = True
DUMP_PYANNOTATE = False
@pytest.fixture(autouse=True)
def collect_types_fixture():
if DUMP_PYANNOTATE:
from pyannotate_runtime import collect_types
collect_types.start()
yield
if DUMP_PYANNOTATE:
collect_types.stop()
def _log_sqlalchemy_info(session):
import sqlalchemy
from sqlalchemy import __version__
from sqlalchemy.util import has_compiled_ext
from sqlalchemy.util._has_cy import _CYEXTENSION_MSG
greet = "sqlalchemy installation"
site = "no user site" if sys.flags.no_user_site else "user site loaded"
msgs = [
f"SQLAlchemy {__version__} ({site})",
f"Path: {sqlalchemy.__file__}",
]
if has_compiled_ext():
from sqlalchemy.cyextension import util
msgs.append(f"compiled extension enabled, e.g. {util.__file__} ")
else:
msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}")
pm = session.config.pluginmanager.get_plugin("terminalreporter")
if pm:
pm.write_sep("=", greet)
for m in msgs:
pm.write_line(m)
else:
# fancy pants reporter not found, fallback to plain print
print("=" * 25, greet, "=" * 25)
for m in msgs:
print(m)
def pytest_sessionstart(session):
from sqlalchemy.testing import asyncio
_log_sqlalchemy_info(session)
asyncio._assume_async(plugin_base.post_begin)
def pytest_sessionfinish(session):
from sqlalchemy.testing import asyncio
asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
if session.config.option.dump_pyannotate:
from pyannotate_runtime import collect_types
collect_types.dump_stats(session.config.option.dump_pyannotate)
def pytest_unconfigure(config):
from sqlalchemy.testing import asyncio
asyncio._shutdown()
def pytest_collection_finish(session):
if session.config.option.dump_pyannotate:
from pyannotate_runtime import collect_types
lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
def _filter(filename):
filename = os.path.normpath(os.path.abspath(filename))
if "lib/sqlalchemy" not in os.path.commonpath(
[filename, lib_sqlalchemy]
):
return None
if "testing" in filename:
return None
return filename
collect_types.init_types_collection(filter_filename=_filter)
class XDistHooks:
def pytest_configure_node(self, node):
from sqlalchemy.testing import provision
from sqlalchemy.testing import asyncio
# the master for each node fills workerinput dictionary
# which pytest-xdist will transfer to the subprocess
plugin_base.memoize_important_follower_config(node.workerinput)
node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
asyncio._maybe_async_provisioning(
provision.create_follower_db, node.workerinput["follower_ident"]
)
def pytest_testnodedown(self, node, error):
from sqlalchemy.testing import provision
from sqlalchemy.testing import asyncio
asyncio._maybe_async_provisioning(
provision.drop_follower_db, node.workerinput["follower_ident"]
)
def pytest_collection_modifyitems(session, config, items):
# look for all those classes that specify __backend__ and
# expand them out into per-database test cases.
# this is much easier to do within pytest_pycollect_makeitem, however
# pytest is iterating through cls.__dict__ as makeitem is
# called which causes a "dictionary changed size" error on py3k.
# I'd submit a pullreq for them to turn it into a list first, but
# it's to suit the rather odd use case here which is that we are adding
# new classes to a module on the fly.
from sqlalchemy.testing import asyncio
rebuilt_items = collections.defaultdict(
lambda: collections.defaultdict(list)
)
items[:] = [
item
for item in items
if item.getparent(pytest.Class) is not None
and not item.getparent(pytest.Class).name.startswith("_")
]
test_classes = {item.getparent(pytest.Class) for item in items}
def collect(element):
for inst_or_fn in element.collect():
if isinstance(inst_or_fn, pytest.Collector):
yield from collect(inst_or_fn)
else:
yield inst_or_fn
def setup_test_classes():
for test_class in test_classes:
# transfer legacy __backend__ and __sparse_backend__ symbols
# to be markers
add_markers = set()
if getattr(test_class.cls, "__backend__", False) or getattr(
test_class.cls, "__only_on__", False
):
add_markers = {"backend"}
elif getattr(test_class.cls, "__sparse_backend__", False):
add_markers = {"sparse_backend"}
else:
add_markers = frozenset()
existing_markers = {
mark.name for mark in test_class.iter_markers()
}
add_markers = add_markers - existing_markers
all_markers = existing_markers.union(add_markers)
for marker in add_markers:
test_class.add_marker(marker)
for sub_cls in plugin_base.generate_sub_tests(
test_class.cls, test_class.module, all_markers
):
if sub_cls is not test_class.cls:
per_cls_dict = rebuilt_items[test_class.cls]
module = test_class.getparent(pytest.Module)
new_cls = pytest.Class.from_parent(
name=sub_cls.__name__, parent=module
)
for marker in add_markers:
new_cls.add_marker(marker)
for fn in collect(new_cls):
per_cls_dict[fn.name].append(fn)
# class requirements will sometimes need to access the DB to check
# capabilities, so need to do this for async
asyncio._maybe_async_provisioning(setup_test_classes)
newitems = []
for item in items:
cls_ = item.cls
if cls_ in rebuilt_items:
newitems.extend(rebuilt_items[cls_][item.name])
else:
newitems.append(item)
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
items[:] = sorted(
newitems,
key=lambda item: (
item.getparent(pytest.Module).name,
item.getparent(pytest.Class).name,
item.name,
),
)
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
from sqlalchemy.testing import config
if config.any_async:
obj = _apply_maybe_async(obj)
return [
pytest.Class.from_parent(
name=parametrize_cls.__name__, parent=collector
)
for parametrize_cls in _parametrize_cls(collector.module, obj)
]
elif (
inspect.isfunction(obj)
and collector.cls is not None
and plugin_base.want_method(collector.cls, obj)
):
# None means, fall back to default logic, which includes
# method-level parametrize
return None
else:
# empty list means skip this item
return []
def _is_wrapped_coroutine_function(fn):
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return inspect.iscoroutinefunction(fn)
def _apply_maybe_async(obj, recurse=True):
from sqlalchemy.testing import asyncio
for name, value in vars(obj).items():
if (
(callable(value) or isinstance(value, classmethod))
and not getattr(value, "_maybe_async_applied", False)
and (name.startswith("test_"))
and not _is_wrapped_coroutine_function(value)
):
is_classmethod = False
if isinstance(value, classmethod):
value = value.__func__
is_classmethod = True
@_pytest_fn_decorator
def make_async(fn, *args, **kwargs):
return asyncio._maybe_async(fn, *args, **kwargs)
do_async = make_async(value)
if is_classmethod:
do_async = classmethod(do_async)
do_async._maybe_async_applied = True
setattr(obj, name, do_async)
if recurse:
for cls in obj.mro()[1:]:
if cls != object:
_apply_maybe_async(cls, False)
return obj
def _parametrize_cls(module, cls):
"""implement a class-based version of pytest parametrize."""
if "_sa_parametrize" not in cls.__dict__:
return [cls]
_sa_parametrize = cls._sa_parametrize
classes = []
for full_param_set in itertools.product(
*[params for argname, params in _sa_parametrize]
):
cls_variables = {}
for argname, param in zip(
[_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
):
if not argname:
raise TypeError("need argnames for class-based combinations")
argname_split = re.split(r",\s*", argname)
for arg, val in zip(argname_split, param.values):
cls_variables[arg] = val
parametrized_name = "_".join(
re.sub(r"\W", "", token)
for param in full_param_set
for token in param.id.split("-")
)
name = "%s_%s" % (cls.__name__, parametrized_name)
newcls = type.__new__(type, name, (cls,), cls_variables)
setattr(module, name, newcls)
classes.append(newcls)
return classes
_current_class = None
def pytest_runtest_setup(item):
from sqlalchemy.testing import asyncio
# pytest_runtest_setup runs *before* pytest fixtures with scope="class".
# plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
# for the whole class and has to run things that are across all current
# databases, so we run this outside of the pytest fixture system altogether
# and ensure asyncio greenlet if any engines are async
global _current_class
if isinstance(item, pytest.Function) and _current_class is None:
asyncio._maybe_async_provisioning(
plugin_base.start_test_class_outside_fixtures,
item.cls,
)
_current_class = item.getparent(pytest.Class)
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item, nextitem):
# runs inside of pytest function fixture scope
# after test function runs
from sqlalchemy.testing import asyncio
asyncio._maybe_async(plugin_base.after_test, item)
yield
# this is now after all the fixture teardown have run, the class can be
# finalized. Since pytest v7 this finalizer can no longer be added in
# pytest_runtest_setup since the class has not yet been setup at that
# time.
# See https://github.com/pytest-dev/pytest/issues/9343
global _current_class, _current_report
if _current_class is not None and (
# last test or a new class
nextitem is None
or nextitem.getparent(pytest.Class) is not _current_class
):
_current_class = None
try:
asyncio._maybe_async_provisioning(
plugin_base.stop_test_class_outside_fixtures, item.cls
)
except Exception as e:
# in case of an exception during teardown attach the original
# error to the exception message, otherwise it will get lost
if _current_report.failed:
if not e.args:
e.args = (
"__Original test failure__:\n"
+ _current_report.longreprtext,
)
elif e.args[-1] and isinstance(e.args[-1], str):
args = list(e.args)
args[-1] += (
"\n__Original test failure__:\n"
+ _current_report.longreprtext
)
e.args = tuple(args)
else:
e.args += (
"__Original test failure__",
_current_report.longreprtext,
)
raise
finally:
_current_report = None
def pytest_runtest_call(item):
# runs inside of pytest function fixture scope
# before test function runs
from sqlalchemy.testing import asyncio
asyncio._maybe_async(
plugin_base.before_test,
item,
item.module.__name__,
item.cls,
item.name,
)
_current_report = None
def pytest_runtest_logreport(report):
global _current_report
if report.when == "call":
_current_report = report
@pytest.fixture(scope="class")
def setup_class_methods(request):
from sqlalchemy.testing import asyncio
cls = request.cls
if hasattr(cls, "setup_test_class"):
asyncio._maybe_async(cls.setup_test_class)
yield
if hasattr(cls, "teardown_test_class"):
asyncio._maybe_async(cls.teardown_test_class)
asyncio._maybe_async(plugin_base.stop_test_class, cls)
@pytest.fixture(scope="function")
def setup_test_methods(request):
from sqlalchemy.testing import asyncio
# called for each test
self = request.instance
# before this fixture runs:
# 1. function level "autouse" fixtures under py3k (examples: TablesTest
# define tables / data, MappedTest define tables / mappers / data)
# 2. was for p2k. no longer applies
# 3. run outer xdist-style setup
if hasattr(self, "setup_test"):
asyncio._maybe_async(self.setup_test)
# alembic test suite is using setUp and tearDown
# xdist methods; support these in the test suite
# for the near term
if hasattr(self, "setUp"):
asyncio._maybe_async(self.setUp)
# inside the yield:
# 4. function level fixtures defined on test functions themselves,
# e.g. "connection", "metadata" run next
# 5. pytest hook pytest_runtest_call then runs
# 6. test itself runs
yield
# yield finishes:
# 7. function level fixtures defined on test functions
# themselves, e.g. "connection" rolls back the transaction, "metadata"
# emits drop all
# 8. pytest hook pytest_runtest_teardown hook runs, this is associated
# with fixtures close all sessions, provisioning.stop_test_class(),
# engines.testing_reaper -> ensure all connection pool connections
# are returned, engines created by testing_engine that aren't the
# config engine are disposed
asyncio._maybe_async(plugin_base.after_test_fixtures, self)
# 10. run xdist-style teardown
if hasattr(self, "tearDown"):
asyncio._maybe_async(self.tearDown)
if hasattr(self, "teardown_test"):
asyncio._maybe_async(self.teardown_test)
# 11. was for p2k. no longer applies
# 12. function level "autouse" fixtures under py3k (examples: TablesTest /
# MappedTest delete table data, possibly drop tables and clear mappers
# depending on the flags defined by the test class)
def _pytest_fn_decorator(target):
"""Port of langhelpers.decorator with pytest-specific tricks."""
from sqlalchemy.util.langhelpers import format_argspec_plus
from sqlalchemy.util.compat import inspect_getfullargspec
def _exec_code_in_env(code, env, fn_name):
# note this is affected by "from __future__ import annotations" at
# the top; exec'ed code will use non-evaluated annotations
# which allows us to be more flexible with code rendering
# in format_argpsec_plus()
exec(code, env)
return env[fn_name]
def decorate(fn, add_positional_parameters=()):
spec = inspect_getfullargspec(fn)
if add_positional_parameters:
spec.args.extend(add_positional_parameters)
metadata = dict(
__target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
)
metadata.update(format_argspec_plus(spec, grouped=False))
code = (
"""\
def %(name)s%(grouped_args)s:
return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
"""
% metadata
)
decorated = _exec_code_in_env(
code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
)
if not add_positional_parameters:
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
decorated.__wrapped__ = fn
return update_wrapper(decorated, fn)
else:
# this is the pytest hacky part. don't do a full update wrapper
# because pytest is really being sneaky about finding the args
# for the wrapped function
decorated.__module__ = fn.__module__
decorated.__name__ = fn.__name__
if hasattr(fn, "pytestmark"):
decorated.pytestmark = fn.pytestmark
return decorated
return decorate
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def skip_test_exception(self, *arg, **kw):
return pytest.skip.Exception(*arg, **kw)
@property
def add_to_marker(self):
return pytest.mark
def mark_base_test_class(self):
return pytest.mark.usefixtures(
"setup_class_methods", "setup_test_methods"
)
_combination_id_fns = {
"i": lambda obj: obj,
"r": repr,
"s": str,
"n": lambda obj: (
obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
),
}
def combinations(self, *arg_sets, **kw):
"""Facade for pytest.mark.parametrize.
Automatically derives argument names from the callable which in our
case is always a method on a class with positional arguments.
ids for parameter sets are derived using an optional template.
"""
from sqlalchemy.testing import exclusions
if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
arg_sets = list(arg_sets[0])
argnames = kw.pop("argnames", None)
def _filter_exclusions(args):
result = []
gathered_exclusions = []
for a in args:
if isinstance(a, exclusions.compound):
gathered_exclusions.append(a)
else:
result.append(a)
return result, gathered_exclusions
id_ = kw.pop("id_", None)
tobuild_pytest_params = []
has_exclusions = False
if id_:
_combination_id_fns = self._combination_id_fns
# because itemgetter is not consistent for one argument vs.
# multiple, make it multiple in all cases and use a slice
# to omit the first argument
_arg_getter = operator.itemgetter(
0,
*[
idx
for idx, char in enumerate(id_)
if char in ("n", "r", "s", "a")
],
)
fns = [
(operator.itemgetter(idx), _combination_id_fns[char])
for idx, char in enumerate(id_)
if char in _combination_id_fns
]
for arg in arg_sets:
if not isinstance(arg, tuple):
arg = (arg,)
fn_params, param_exclusions = _filter_exclusions(arg)
parameters = _arg_getter(fn_params)[1:]
if param_exclusions:
has_exclusions = True
tobuild_pytest_params.append(
(
parameters,
param_exclusions,
"-".join(
comb_fn(getter(arg)) for getter, comb_fn in fns
),
)
)
else:
for arg in arg_sets:
if not isinstance(arg, tuple):
arg = (arg,)
fn_params, param_exclusions = _filter_exclusions(arg)
if param_exclusions:
has_exclusions = True
tobuild_pytest_params.append(
(fn_params, param_exclusions, None)
)
pytest_params = []
for parameters, param_exclusions, id_ in tobuild_pytest_params:
if has_exclusions:
parameters += (param_exclusions,)
param = pytest.param(*parameters, id=id_)
pytest_params.append(param)
def decorate(fn):
if inspect.isclass(fn):
if has_exclusions:
raise NotImplementedError(
"exclusions not supported for class level combinations"
)
if "_sa_parametrize" not in fn.__dict__:
fn._sa_parametrize = []
fn._sa_parametrize.append((argnames, pytest_params))
return fn
else:
_fn_argnames = inspect.getfullargspec(fn).args[1:]
if argnames is None:
_argnames = _fn_argnames
else:
_argnames = re.split(r", *", argnames)
if has_exclusions:
existing_exl = sum(
1 for n in _fn_argnames if n.startswith("_exclusions")
)
current_exclusion_name = f"_exclusions_{existing_exl}"
_argnames += [current_exclusion_name]
@_pytest_fn_decorator
def check_exclusions(fn, *args, **kw):
_exclusions = args[-1]
if _exclusions:
exlu = exclusions.compound().add(*_exclusions)
fn = exlu(fn)
return fn(*args[:-1], **kw)
fn = check_exclusions(
fn, add_positional_parameters=(current_exclusion_name,)
)
return pytest.mark.parametrize(_argnames, pytest_params)(fn)
return decorate
def param_ident(self, *parameters):
ident = parameters[0]
return pytest.param(*parameters[1:], id=ident)
def fixture(self, *arg, **kw):
from sqlalchemy.testing import config
from sqlalchemy.testing import asyncio
# wrapping pytest.fixture function. determine if
# decorator was called as @fixture or @fixture().
if len(arg) > 0 and callable(arg[0]):
# was called as @fixture(), we have the function to wrap.
fn = arg[0]
arg = arg[1:]
else:
# was called as @fixture, don't have the function yet.
fn = None
# create a pytest.fixture marker. because the fn is not being
# passed, this is always a pytest.FixtureFunctionMarker()
# object (or whatever pytest is calling it when you read this)
# that is waiting for a function.
fixture = pytest.fixture(*arg, **kw)
# now apply wrappers to the function, including fixture itself
def wrap(fn):
if config.any_async:
fn = asyncio._maybe_async_wrapper(fn)
# other wrappers may be added here
# now apply FixtureFunctionMarker
fn = fixture(fn)
return fn
if fn:
return wrap(fn)
else:
return wrap
def get_current_test_name(self):
return os.environ.get("PYTEST_CURRENT_TEST")
def async_test(self, fn):
from sqlalchemy.testing import asyncio
@_pytest_fn_decorator
def decorate(fn, *args, **kwargs):
asyncio._run_coroutine_function(fn, *args, **kwargs)
return decorate(fn)

View File

@ -0,0 +1,324 @@
# testing/profiling.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Profiling support for unit and performance tests.
These are special purpose profiling methods which operate
in a more fine-grained way than nose's profiling plugin.
"""
from __future__ import annotations
import collections
import contextlib
import os
import platform
import pstats
import re
import sys
from . import config
from .util import gc_collect
from ..util import has_compiled_ext
try:
import cProfile
except ImportError:
cProfile = None
_profile_stats = None
"""global ProfileStatsFileInstance.
plugin_base assigns this at the start of all tests.
"""
_current_test = None
"""String id of current test.
plugin_base assigns this at the start of each test using
_start_current_test.
"""
def _start_current_test(id_):
global _current_test
_current_test = id_
if _profile_stats.force_write:
_profile_stats.reset_count()
class ProfileStatsFile:
"""Store per-platform/fn profiling results in a file.
There was no json module available when this was written, but now
the file format which is very deterministically line oriented is kind of
handy in any case for diffs and merges.
"""
def __init__(self, filename, sort="cumulative", dump=None):
self.force_write = (
config.options is not None and config.options.force_write_profiles
)
self.write = self.force_write or (
config.options is not None and config.options.write_profiles
)
self.fname = os.path.abspath(filename)
self.short_fname = os.path.split(self.fname)[-1]
self.data = collections.defaultdict(
lambda: collections.defaultdict(dict)
)
self.dump = dump
self.sort = sort
self._read()
if self.write:
# rewrite for the case where features changed,
# etc.
self._write()
@property
def platform_key(self):
dbapi_key = config.db.name + "_" + config.db.driver
if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
config.db.url
):
dbapi_key += "_file"
# keep it at 2.7, 3.1, 3.2, etc. for now.
py_version = ".".join([str(v) for v in sys.version_info[0:2]])
platform_tokens = [
platform.machine(),
platform.system().lower(),
platform.python_implementation().lower(),
py_version,
dbapi_key,
]
platform_tokens.append("dbapiunicode")
_has_cext = has_compiled_ext()
platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
return "_".join(platform_tokens)
def has_stats(self):
test_key = _current_test
return (
test_key in self.data and self.platform_key in self.data[test_key]
)
def result(self, callcount):
test_key = _current_test
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
if "counts" not in per_platform:
per_platform["counts"] = counts = []
else:
counts = per_platform["counts"]
if "current_count" not in per_platform:
per_platform["current_count"] = current_count = 0
else:
current_count = per_platform["current_count"]
has_count = len(counts) > current_count
if not has_count:
counts.append(callcount)
if self.write:
self._write()
result = None
else:
result = per_platform["lineno"], counts[current_count]
per_platform["current_count"] += 1
return result
def reset_count(self):
test_key = _current_test
# since self.data is a defaultdict, don't access a key
# if we don't know it's there first.
if test_key not in self.data:
return
per_fn = self.data[test_key]
if self.platform_key not in per_fn:
return
per_platform = per_fn[self.platform_key]
if "counts" in per_platform:
per_platform["counts"][:] = []
def replace(self, callcount):
test_key = _current_test
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
counts = per_platform["counts"]
current_count = per_platform["current_count"]
if current_count < len(counts):
counts[current_count - 1] = callcount
else:
counts[-1] = callcount
if self.write:
self._write()
def _header(self):
return (
"# %s\n"
"# This file is written out on a per-environment basis.\n"
"# For each test in aaa_profiling, the corresponding "
"function and \n"
"# environment is located within this file. "
"If it doesn't exist,\n"
"# the test is skipped.\n"
"# If a callcount does exist, it is compared "
"to what we received. \n"
"# assertions are raised if the counts do not match.\n"
"# \n"
"# To add a new callcount test, apply the function_call_count \n"
"# decorator and re-run the tests using the --write-profiles \n"
"# option - this file will be rewritten including the new count.\n"
"# \n"
) % (self.fname)
def _read(self):
try:
profile_f = open(self.fname)
except OSError:
return
for lineno, line in enumerate(profile_f):
line = line.strip()
if not line or line.startswith("#"):
continue
test_key, platform_key, counts = line.split()
per_fn = self.data[test_key]
per_platform = per_fn[platform_key]
c = [int(count) for count in counts.split(",")]
per_platform["counts"] = c
per_platform["lineno"] = lineno + 1
per_platform["current_count"] = 0
profile_f.close()
def _write(self):
print("Writing profile file %s" % self.fname)
profile_f = open(self.fname, "w")
profile_f.write(self._header())
for test_key in sorted(self.data):
per_fn = self.data[test_key]
profile_f.write("\n# TEST: %s\n\n" % test_key)
for platform_key in sorted(per_fn):
per_platform = per_fn[platform_key]
c = ",".join(str(count) for count in per_platform["counts"])
profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
profile_f.close()
def function_call_count(variance=0.05, times=1, warmup=0):
"""Assert a target for a test case's function call count.
The main purpose of this assertion is to detect changes in
callcounts for various functions - the actual number is not as important.
Callcounts are stored in a file keyed to Python version and OS platform
information. This file is generated automatically for new tests,
and versioned so that unexpected changes in callcounts will be detected.
"""
# use signature-rewriting decorator function so that pytest fixtures
# still work on py27. In Py3, update_wrapper() alone is good enough,
# likely due to the introduction of __signature__.
from sqlalchemy.util import decorator
@decorator
def wrap(fn, *args, **kw):
for warm in range(warmup):
fn(*args, **kw)
timerange = range(times)
with count_functions(variance=variance):
for time in timerange:
rv = fn(*args, **kw)
return rv
return wrap
@contextlib.contextmanager
def count_functions(variance=0.05):
if cProfile is None:
raise config._skip_test_exception("cProfile is not installed")
if not _profile_stats.has_stats() and not _profile_stats.write:
config.skip_test(
"No profiling stats available on this "
"platform for this function. Run tests with "
"--write-profiles to add statistics to %s for "
"this platform." % _profile_stats.short_fname
)
gc_collect()
pr = cProfile.Profile()
pr.enable()
# began = time.time()
yield
# ended = time.time()
pr.disable()
# s = StringIO()
stats = pstats.Stats(pr, stream=sys.stdout)
# timespent = ended - began
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
if expected is None:
expected_count = None
else:
line_no, expected_count = expected
print("Pstats calls: %d Expected %s" % (callcount, expected_count))
stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort))
stats.print_stats()
if _profile_stats.dump:
base, ext = os.path.splitext(_profile_stats.dump)
test_name = _current_test.split(".")[-1]
dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile")
stats.dump_stats(dumpfile)
print("Dumped stats to file %s" % dumpfile)
# stats.print_callers()
if _profile_stats.force_write:
_profile_stats.replace(callcount)
elif expected_count:
deviance = int(callcount * variance)
failed = abs(callcount - expected_count) > deviance
if failed:
if _profile_stats.write:
_profile_stats.replace(callcount)
else:
raise AssertionError(
"Adjusted function call count %s not within %s%% "
"of expected %s, platform %s. Rerun with "
"--write-profiles to "
"regenerate this callcount."
% (
callcount,
(variance * 100),
expected_count,
_profile_stats.platform_key,
)
)

View File

@ -0,0 +1,502 @@
# testing/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import collections
import logging
from . import config
from . import engines
from . import util
from .. import exc
from .. import inspect
from ..engine import url as sa_url
from ..sql import ddl
from ..sql import schema
log = logging.getLogger(__name__)
FOLLOWER_IDENT = None
class register:
def __init__(self, decorator=None):
self.fns = {}
self.decorator = decorator
@classmethod
def init(cls, fn):
return register().for_db("*")(fn)
@classmethod
def init_decorator(cls, decorator):
return register(decorator).for_db("*")
def for_db(self, *dbnames):
def decorate(fn):
if self.decorator:
fn = self.decorator(fn)
for dbname in dbnames:
self.fns[dbname] = fn
return self
return decorate
def __call__(self, cfg, *arg, **kw):
if isinstance(cfg, str):
url = sa_url.make_url(cfg)
elif isinstance(cfg, sa_url.URL):
url = cfg
else:
url = cfg.db.url
backend = url.get_backend_name()
if backend in self.fns:
return self.fns[backend](cfg, *arg, **kw)
else:
return self.fns["*"](cfg, *arg, **kw)
def create_follower_db(follower_ident):
for cfg in _configs_for_db_operation():
log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
create_db(cfg, cfg.db, follower_ident)
def setup_config(db_url, options, file_config, follower_ident):
# load the dialect, which should also have it set up its provision
# hooks
dialect = sa_url.make_url(db_url).get_dialect()
dialect.load_provisioning()
if follower_ident:
db_url = follower_url_from_main(db_url, follower_ident)
db_opts = {}
update_db_opts(db_url, db_opts, options)
db_opts["scope"] = "global"
eng = engines.testing_engine(db_url, db_opts)
post_configure_engine(db_url, eng, follower_ident)
eng.connect().close()
cfg = config.Config.register(eng, db_opts, options, file_config)
# a symbolic name that tests can use if they need to disambiguate
# names across databases
if follower_ident:
config.ident = follower_ident
if follower_ident:
configure_follower(cfg, follower_ident)
return cfg
def drop_follower_db(follower_ident):
for cfg in _configs_for_db_operation():
log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
drop_db(cfg, cfg.db, follower_ident)
def generate_db_urls(db_urls, extra_drivers):
"""Generate a set of URLs to test given configured URLs plus additional
driver names.
Given:
.. sourcecode:: text
--dburi postgresql://db1 \
--dburi postgresql://db2 \
--dburi postgresql://db2 \
--dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
Noting that the default postgresql driver is psycopg2, the output
would be:
.. sourcecode:: text
postgresql+psycopg2://db1
postgresql+asyncpg://db1
postgresql+psycopg2://db2
postgresql+psycopg2://db3
That is, for the driver in a --dburi, we want to keep that and use that
driver for each URL it's part of . For a driver that is only
in --dbdrivers, we want to use it just once for one of the URLs.
for a driver that is both coming from --dburi as well as --dbdrivers,
we want to keep it in that dburi.
Driver specific query options can be specified by added them to the
driver name. For example, to enable the async fallback option for
asyncpg::
.. sourcecode:: text
--dburi postgresql://db1 \
--dbdriver=asyncpg?async_fallback=true
"""
urls = set()
backend_to_driver_we_already_have = collections.defaultdict(set)
urls_plus_dialects = [
(url_obj, url_obj.get_dialect())
for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
]
for url_obj, dialect in urls_plus_dialects:
# use get_driver_name instead of dialect.driver to account for
# "_async" virtual drivers like oracledb and psycopg
driver_name = url_obj.get_driver_name()
backend_to_driver_we_already_have[dialect.name].add(driver_name)
backend_to_driver_we_need = {}
for url_obj, dialect in urls_plus_dialects:
backend = dialect.name
dialect.load_provisioning()
if backend not in backend_to_driver_we_need:
backend_to_driver_we_need[backend] = extra_per_backend = set(
extra_drivers
).difference(backend_to_driver_we_already_have[backend])
else:
extra_per_backend = backend_to_driver_we_need[backend]
for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
if driver_url in urls:
continue
urls.add(driver_url)
yield driver_url
def _generate_driver_urls(url, extra_drivers):
main_driver = url.get_driver_name()
extra_drivers.discard(main_driver)
url = generate_driver_url(url, main_driver, "")
yield url
for drv in list(extra_drivers):
if "?" in drv:
driver_only, query_str = drv.split("?", 1)
else:
driver_only = drv
query_str = None
new_url = generate_driver_url(url, driver_only, query_str)
if new_url:
extra_drivers.remove(drv)
yield new_url
@register.init
def generate_driver_url(url, driver, query_str):
backend = url.get_backend_name()
new_url = url.set(
drivername="%s+%s" % (backend, driver),
)
if query_str:
new_url = new_url.update_query_string(query_str)
try:
new_url.get_dialect()
except exc.NoSuchModuleError:
return None
else:
return new_url
def _configs_for_db_operation():
hosts = set()
for cfg in config.Config.all_configs():
cfg.db.dispose()
for cfg in config.Config.all_configs():
url = cfg.db.url
backend = url.get_backend_name()
host_conf = (backend, url.username, url.host, url.database)
if host_conf not in hosts:
yield cfg
hosts.add(host_conf)
for cfg in config.Config.all_configs():
cfg.db.dispose()
@register.init
def drop_all_schema_objects_pre_tables(cfg, eng):
pass
@register.init
def drop_all_schema_objects_post_tables(cfg, eng):
pass
def drop_all_schema_objects(cfg, eng):
drop_all_schema_objects_pre_tables(cfg, eng)
drop_views(cfg, eng)
if config.requirements.materialized_views.enabled:
drop_materialized_views(cfg, eng)
inspector = inspect(eng)
consider_schemas = (None,)
if config.requirements.schemas.enabled_for_config(cfg):
consider_schemas += (cfg.test_schema, cfg.test_schema_2)
util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas)
drop_all_schema_objects_post_tables(cfg, eng)
if config.requirements.sequences.enabled_for_config(cfg):
with eng.begin() as conn:
for seq in inspector.get_sequence_names():
conn.execute(ddl.DropSequence(schema.Sequence(seq)))
if config.requirements.schemas.enabled_for_config(cfg):
for schema_name in [cfg.test_schema, cfg.test_schema_2]:
for seq in inspector.get_sequence_names(
schema=schema_name
):
conn.execute(
ddl.DropSequence(
schema.Sequence(seq, schema=schema_name)
)
)
def drop_views(cfg, eng):
inspector = inspect(eng)
try:
view_names = inspector.get_view_names()
except NotImplementedError:
pass
else:
with eng.begin() as conn:
for vname in view_names:
conn.execute(
ddl._DropView(schema.Table(vname, schema.MetaData()))
)
if config.requirements.schemas.enabled_for_config(cfg):
try:
view_names = inspector.get_view_names(schema=cfg.test_schema)
except NotImplementedError:
pass
else:
with eng.begin() as conn:
for vname in view_names:
conn.execute(
ddl._DropView(
schema.Table(
vname,
schema.MetaData(),
schema=cfg.test_schema,
)
)
)
def drop_materialized_views(cfg, eng):
inspector = inspect(eng)
mview_names = inspector.get_materialized_view_names()
with eng.begin() as conn:
for vname in mview_names:
conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}")
if config.requirements.schemas.enabled_for_config(cfg):
mview_names = inspector.get_materialized_view_names(
schema=cfg.test_schema
)
with eng.begin() as conn:
for vname in mview_names:
conn.exec_driver_sql(
f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}"
)
@register.init
def create_db(cfg, eng, ident):
"""Dynamically create a database for testing.
Used when a test run will employ multiple processes, e.g., when run
via `tox` or `pytest -n4`.
"""
raise NotImplementedError(
"no DB creation routine for cfg: %s" % (eng.url,)
)
@register.init
def drop_db(cfg, eng, ident):
"""Drop a database that we dynamically created for testing."""
raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))
def _adapt_update_db_opts(fn):
insp = util.inspect_getfullargspec(fn)
if len(insp.args) == 3:
return fn
else:
return lambda db_url, db_opts, _options: fn(db_url, db_opts)
@register.init_decorator(_adapt_update_db_opts)
def update_db_opts(db_url, db_opts, options):
"""Set database options (db_opts) for a test database that we created."""
@register.init
def post_configure_engine(url, engine, follower_ident):
"""Perform extra steps after configuring an engine for testing.
(For the internal dialects, currently only used by sqlite, oracle, mssql)
"""
@register.init
def follower_url_from_main(url, ident):
"""Create a connection URL for a dynamically-created test database.
:param url: the connection URL specified when the test run was invoked
:param ident: the pytest-xdist "worker identifier" to be used as the
database name
"""
url = sa_url.make_url(url)
return url.set(database=ident)
@register.init
def configure_follower(cfg, ident):
"""Create dialect-specific config settings for a follower database."""
pass
@register.init
def run_reap_dbs(url, ident):
"""Remove databases that were created during the test process, after the
process has ended.
This is an optional step that is invoked for certain backends that do not
reliably release locks on the database as long as a process is still in
use. For the internal dialects, this is currently only necessary for
mssql and oracle.
"""
def reap_dbs(idents_file):
log.info("Reaping databases...")
urls = collections.defaultdict(set)
idents = collections.defaultdict(set)
dialects = {}
with open(idents_file) as file_:
for line in file_:
line = line.strip()
db_name, db_url = line.split(" ")
url_obj = sa_url.make_url(db_url)
if db_name not in dialects:
dialects[db_name] = url_obj.get_dialect()
dialects[db_name].load_provisioning()
url_key = (url_obj.get_backend_name(), url_obj.host)
urls[url_key].add(db_url)
idents[url_key].add(db_name)
for url_key in urls:
url = list(urls[url_key])[0]
ident = idents[url_key]
run_reap_dbs(url, ident)
@register.init
def temp_table_keyword_args(cfg, eng):
"""Specify keyword arguments for creating a temporary Table.
Dialect-specific implementations of this method will return the
kwargs that are passed to the Table method when creating a temporary
table for testing, e.g., in the define_temp_tables method of the
ComponentReflectionTest class in suite/test_reflection.py
"""
raise NotImplementedError(
"no temp table keyword args routine for cfg: %s" % (eng.url,)
)
@register.init
def prepare_for_drop_tables(config, connection):
pass
@register.init
def stop_test_class_outside_fixtures(config, db, testcls):
pass
@register.init
def get_temp_table_name(cfg, eng, base_name):
"""Specify table name for creating a temporary Table.
Dialect-specific implementations of this method will return the
name to use when creating a temporary table for testing,
e.g., in the define_temp_tables method of the
ComponentReflectionTest class in suite/test_reflection.py
Default to just the base name since that's what most dialects will
use. The mssql dialect's implementation will need a "#" prepended.
"""
return base_name
@register.init
def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
raise NotImplementedError(
"backend does not implement a schema name set function: %s"
% (cfg.db.url,)
)
@register.init
def upsert(
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
):
"""return the backends insert..on conflict / on dupe etc. construct.
while we should add a backend-neutral upsert construct as well, such as
insert().upsert(), it's important that we continue to test the
backend-specific insert() constructs since if we do implement
insert().upsert(), that would be using a different codepath for the things
we need to test like insertmanyvalues, etc.
"""
raise NotImplementedError(
f"backend does not include an upsert implementation: {cfg.db.url}"
)
@register.init
def normalize_sequence(cfg, sequence):
"""Normalize sequence parameters for dialect that don't start with 1
by default.
The default implementation does nothing
"""
return sequence

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,224 @@
# testing/schema.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import sys
from . import config
from . import exclusions
from .. import event
from .. import schema
from .. import types as sqltypes
from ..orm import mapped_column as _orm_mapped_column
from ..util import OrderedDict
__all__ = ["Table", "Column"]
table_options = {}
def Table(*args, **kw) -> schema.Table:
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
kw.update(table_options)
if exclusions.against(config._current, "mysql"):
if (
"mysql_engine" not in kw
and "mysql_type" not in kw
and "autoload_with" not in kw
):
if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
kw["mysql_engine"] = "InnoDB"
else:
# there are in fact test fixtures that rely upon MyISAM,
# due to MySQL / MariaDB having poor FK behavior under innodb,
# such as a self-referential table can't be deleted from at
# once without attending to per-row dependencies. We'd need to
# add special steps to some fixtures if we want to not
# explicitly state MyISAM here
kw["mysql_engine"] = "MyISAM"
elif exclusions.against(config._current, "mariadb"):
if (
"mariadb_engine" not in kw
and "mariadb_type" not in kw
and "autoload_with" not in kw
):
if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
kw["mariadb_engine"] = "InnoDB"
else:
kw["mariadb_engine"] = "MyISAM"
return schema.Table(*args, **kw)
def mapped_column(*args, **kw):
"""An orm.mapped_column wrapper/hook for dialect-specific tweaks."""
return _schema_column(_orm_mapped_column, args, kw)
def Column(*args, **kw):
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
return _schema_column(schema.Column, args, kw)
def _schema_column(factory, args, kw):
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
construct = factory(*args, **kw)
if factory is schema.Column:
col = construct
else:
col = construct.column
if test_opts.get("test_needs_autoincrement", False) and kw.get(
"primary_key", False
):
if col.default is None and col.server_default is None:
col.autoincrement = True
# allow any test suite to pick up on this
col.info["test_needs_autoincrement"] = True
# hardcoded rule for oracle; this should
# be moved out
if exclusions.against(config._current, "oracle"):
def add_seq(c, tbl):
c._init_items(
schema.Sequence(
_truncate_name(
config.db.dialect, tbl.name + "_" + c.name + "_seq"
),
optional=True,
)
)
event.listen(col, "after_parent_attach", add_seq, propagate=True)
return construct
class eq_type_affinity:
"""Helper to compare types inside of datastructures based on affinity.
E.g.::
eq_(
inspect(connection).get_columns("foo"),
[
{
"name": "id",
"type": testing.eq_type_affinity(sqltypes.INTEGER),
"nullable": False,
"default": None,
"autoincrement": False,
},
{
"name": "data",
"type": testing.eq_type_affinity(sqltypes.NullType),
"nullable": True,
"default": None,
"autoincrement": False,
},
],
)
"""
def __init__(self, target):
self.target = sqltypes.to_instance(target)
def __eq__(self, other):
return self.target._type_affinity is other._type_affinity
def __ne__(self, other):
return self.target._type_affinity is not other._type_affinity
class eq_compile_type:
"""similar to eq_type_affinity but uses compile"""
def __init__(self, target):
self.target = target
def __eq__(self, other):
return self.target == other.compile()
def __ne__(self, other):
return self.target != other.compile()
class eq_clause_element:
"""Helper to compare SQL structures based on compare()"""
def __init__(self, target):
self.target = target
def __eq__(self, other):
return self.target.compare(other)
def __ne__(self, other):
return not self.target.compare(other)
def _truncate_name(dialect, name):
if len(name) > dialect.max_identifier_length:
return (
name[0 : max(dialect.max_identifier_length - 6, 0)]
+ "_"
+ hex(hash(name) % 64)[2:]
)
else:
return name
def pep435_enum(name):
# Implements PEP 435 in the minimal fashion needed by SQLAlchemy
__members__ = OrderedDict()
def __init__(self, name, value, alias=None):
self.name = name
self.value = value
self.__members__[name] = self
value_to_member[value] = self
setattr(self.__class__, name, self)
if alias:
self.__members__[alias] = self
setattr(self.__class__, alias, self)
value_to_member = {}
@classmethod
def get(cls, value):
return value_to_member[value]
someenum = type(
name,
(object,),
{"__members__": __members__, "__init__": __init__, "get": get},
)
# getframe() trick for pickling I don't understand courtesy
# Python namedtuple()
try:
module = sys._getframe(1).f_globals.get("__name__", "__main__")
except (AttributeError, ValueError):
pass
if module is not None:
someenum.__module__ = module
return someenum

View File

@ -0,0 +1,19 @@
# testing/suite/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .test_cte import * # noqa
from .test_ddl import * # noqa
from .test_deprecations import * # noqa
from .test_dialect import * # noqa
from .test_insert import * # noqa
from .test_reflection import * # noqa
from .test_results import * # noqa
from .test_rowcount import * # noqa
from .test_select import * # noqa
from .test_sequence import * # noqa
from .test_types import * # noqa
from .test_unicode_ddl import * # noqa
from .test_update_delete import * # noqa

View File

@ -0,0 +1,211 @@
# testing/suite/test_cte.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import ForeignKey
from ... import Integer
from ... import select
from ... import String
from ... import testing
class CTETest(fixtures.TablesTest):
__backend__ = True
__requires__ = ("ctes",)
run_inserts = "each"
run_deletes = "each"
@classmethod
def define_tables(cls, metadata):
Table(
"some_table",
metadata,
Column("id", Integer, primary_key=True),
Column("data", String(50)),
Column("parent_id", ForeignKey("some_table.id")),
)
Table(
"some_other_table",
metadata,
Column("id", Integer, primary_key=True),
Column("data", String(50)),
Column("parent_id", Integer),
)
@classmethod
def insert_data(cls, connection):
connection.execute(
cls.tables.some_table.insert(),
[
{"id": 1, "data": "d1", "parent_id": None},
{"id": 2, "data": "d2", "parent_id": 1},
{"id": 3, "data": "d3", "parent_id": 1},
{"id": 4, "data": "d4", "parent_id": 3},
{"id": 5, "data": "d5", "parent_id": 3},
],
)
def test_select_nonrecursive_round_trip(self, connection):
some_table = self.tables.some_table
cte = (
select(some_table)
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
.cte("some_cte")
)
result = connection.execute(
select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
)
eq_(result.fetchall(), [("d4",)])
def test_select_recursive_round_trip(self, connection):
some_table = self.tables.some_table
cte = (
select(some_table)
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
.cte("some_cte", recursive=True)
)
cte_alias = cte.alias("c1")
st1 = some_table.alias()
# note that SQL Server requires this to be UNION ALL,
# can't be UNION
cte = cte.union_all(
select(st1).where(st1.c.id == cte_alias.c.parent_id)
)
result = connection.execute(
select(cte.c.data)
.where(cte.c.data != "d2")
.order_by(cte.c.data.desc())
)
eq_(
result.fetchall(),
[("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
)
def test_insert_from_select_round_trip(self, connection):
some_table = self.tables.some_table
some_other_table = self.tables.some_other_table
cte = (
select(some_table)
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
.cte("some_cte")
)
connection.execute(
some_other_table.insert().from_select(
["id", "data", "parent_id"], select(cte)
)
)
eq_(
connection.execute(
select(some_other_table).order_by(some_other_table.c.id)
).fetchall(),
[(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
)
@testing.requires.ctes_with_update_delete
@testing.requires.update_from
def test_update_from_round_trip(self, connection):
some_table = self.tables.some_table
some_other_table = self.tables.some_other_table
connection.execute(
some_other_table.insert().from_select(
["id", "data", "parent_id"], select(some_table)
)
)
cte = (
select(some_table)
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
.cte("some_cte")
)
connection.execute(
some_other_table.update()
.values(parent_id=5)
.where(some_other_table.c.data == cte.c.data)
)
eq_(
connection.execute(
select(some_other_table).order_by(some_other_table.c.id)
).fetchall(),
[
(1, "d1", None),
(2, "d2", 5),
(3, "d3", 5),
(4, "d4", 5),
(5, "d5", 3),
],
)
@testing.requires.ctes_with_update_delete
@testing.requires.delete_from
def test_delete_from_round_trip(self, connection):
some_table = self.tables.some_table
some_other_table = self.tables.some_other_table
connection.execute(
some_other_table.insert().from_select(
["id", "data", "parent_id"], select(some_table)
)
)
cte = (
select(some_table)
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
.cte("some_cte")
)
connection.execute(
some_other_table.delete().where(
some_other_table.c.data == cte.c.data
)
)
eq_(
connection.execute(
select(some_other_table).order_by(some_other_table.c.id)
).fetchall(),
[(1, "d1", None), (5, "d5", 3)],
)
@testing.requires.ctes_with_update_delete
def test_delete_scalar_subq_round_trip(self, connection):
some_table = self.tables.some_table
some_other_table = self.tables.some_other_table
connection.execute(
some_other_table.insert().from_select(
["id", "data", "parent_id"], select(some_table)
)
)
cte = (
select(some_table)
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
.cte("some_cte")
)
connection.execute(
some_other_table.delete().where(
some_other_table.c.data
== select(cte.c.data)
.where(cte.c.id == some_other_table.c.id)
.scalar_subquery()
)
)
eq_(
connection.execute(
select(some_other_table).order_by(some_other_table.c.id)
).fetchall(),
[(1, "d1", None), (5, "d5", 3)],
)

View File

@ -0,0 +1,389 @@
# testing/suite/test_ddl.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import random
from . import testing
from .. import config
from .. import fixtures
from .. import util
from ..assertions import eq_
from ..assertions import is_false
from ..assertions import is_true
from ..config import requirements
from ..schema import Table
from ... import CheckConstraint
from ... import Column
from ... import ForeignKeyConstraint
from ... import Index
from ... import inspect
from ... import Integer
from ... import schema
from ... import String
from ... import UniqueConstraint
class TableDDLTest(fixtures.TestBase):
__backend__ = True
def _simple_fixture(self, schema=None):
return Table(
"test_table",
self.metadata,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(50)),
schema=schema,
)
def _underscore_fixture(self):
return Table(
"_test_table",
self.metadata,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("_data", String(50)),
)
def _table_index_fixture(self, schema=None):
table = self._simple_fixture(schema=schema)
idx = Index("test_index", table.c.data)
return table, idx
def _simple_roundtrip(self, table):
with config.db.begin() as conn:
conn.execute(table.insert().values((1, "some data")))
result = conn.execute(table.select())
eq_(result.first(), (1, "some data"))
@requirements.create_table
@util.provide_metadata
def test_create_table(self):
table = self._simple_fixture()
table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
@requirements.create_table
@requirements.schemas
@util.provide_metadata
def test_create_table_schema(self):
table = self._simple_fixture(schema=config.test_schema)
table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
@requirements.drop_table
@util.provide_metadata
def test_drop_table(self):
table = self._simple_fixture()
table.create(config.db, checkfirst=False)
table.drop(config.db, checkfirst=False)
@requirements.create_table
@util.provide_metadata
def test_underscore_names(self):
table = self._underscore_fixture()
table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
@requirements.comment_reflection
@util.provide_metadata
def test_add_table_comment(self, connection):
table = self._simple_fixture()
table.create(connection, checkfirst=False)
table.comment = "a comment"
connection.execute(schema.SetTableComment(table))
eq_(
inspect(connection).get_table_comment("test_table"),
{"text": "a comment"},
)
@requirements.comment_reflection
@util.provide_metadata
def test_drop_table_comment(self, connection):
table = self._simple_fixture()
table.create(connection, checkfirst=False)
table.comment = "a comment"
connection.execute(schema.SetTableComment(table))
connection.execute(schema.DropTableComment(table))
eq_(
inspect(connection).get_table_comment("test_table"), {"text": None}
)
@requirements.table_ddl_if_exists
@util.provide_metadata
def test_create_table_if_not_exists(self, connection):
table = self._simple_fixture()
connection.execute(schema.CreateTable(table, if_not_exists=True))
is_true(inspect(connection).has_table("test_table"))
connection.execute(schema.CreateTable(table, if_not_exists=True))
@requirements.index_ddl_if_exists
@util.provide_metadata
def test_create_index_if_not_exists(self, connection):
table, idx = self._table_index_fixture()
connection.execute(schema.CreateTable(table, if_not_exists=True))
is_true(inspect(connection).has_table("test_table"))
is_false(
"test_index"
in [
ix["name"]
for ix in inspect(connection).get_indexes("test_table")
]
)
connection.execute(schema.CreateIndex(idx, if_not_exists=True))
is_true(
"test_index"
in [
ix["name"]
for ix in inspect(connection).get_indexes("test_table")
]
)
connection.execute(schema.CreateIndex(idx, if_not_exists=True))
@requirements.table_ddl_if_exists
@util.provide_metadata
def test_drop_table_if_exists(self, connection):
table = self._simple_fixture()
table.create(connection)
is_true(inspect(connection).has_table("test_table"))
connection.execute(schema.DropTable(table, if_exists=True))
is_false(inspect(connection).has_table("test_table"))
connection.execute(schema.DropTable(table, if_exists=True))
@requirements.index_ddl_if_exists
@util.provide_metadata
def test_drop_index_if_exists(self, connection):
table, idx = self._table_index_fixture()
table.create(connection)
is_true(
"test_index"
in [
ix["name"]
for ix in inspect(connection).get_indexes("test_table")
]
)
connection.execute(schema.DropIndex(idx, if_exists=True))
is_false(
"test_index"
in [
ix["name"]
for ix in inspect(connection).get_indexes("test_table")
]
)
connection.execute(schema.DropIndex(idx, if_exists=True))
class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
pass
class LongNameBlowoutTest(fixtures.TestBase):
"""test the creation of a variety of DDL structures and ensure
label length limits pass on backends
"""
__backend__ = True
def fk(self, metadata, connection):
convention = {
"fk": "foreign_key_%(table_name)s_"
"%(column_0_N_name)s_"
"%(referred_table_name)s_"
+ (
"_".join(
"".join(random.choice("abcdef") for j in range(20))
for i in range(10)
)
),
}
metadata.naming_convention = convention
Table(
"a_things_with_stuff",
metadata,
Column("id_long_column_name", Integer, primary_key=True),
test_needs_fk=True,
)
cons = ForeignKeyConstraint(
["aid"], ["a_things_with_stuff.id_long_column_name"]
)
Table(
"b_related_things_of_value",
metadata,
Column(
"aid",
),
cons,
test_needs_fk=True,
)
actual_name = cons.name
metadata.create_all(connection)
if testing.requires.foreign_key_constraint_name_reflection.enabled:
insp = inspect(connection)
fks = insp.get_foreign_keys("b_related_things_of_value")
reflected_name = fks[0]["name"]
return actual_name, reflected_name
else:
return actual_name, None
def pk(self, metadata, connection):
convention = {
"pk": "primary_key_%(table_name)s_"
"%(column_0_N_name)s"
+ (
"_".join(
"".join(random.choice("abcdef") for j in range(30))
for i in range(10)
)
),
}
metadata.naming_convention = convention
a = Table(
"a_things_with_stuff",
metadata,
Column("id_long_column_name", Integer, primary_key=True),
Column("id_another_long_name", Integer, primary_key=True),
)
cons = a.primary_key
actual_name = cons.name
metadata.create_all(connection)
insp = inspect(connection)
pk = insp.get_pk_constraint("a_things_with_stuff")
reflected_name = pk["name"]
return actual_name, reflected_name
def ix(self, metadata, connection):
convention = {
"ix": "index_%(table_name)s_"
"%(column_0_N_name)s"
+ (
"_".join(
"".join(random.choice("abcdef") for j in range(30))
for i in range(10)
)
),
}
metadata.naming_convention = convention
a = Table(
"a_things_with_stuff",
metadata,
Column("id_long_column_name", Integer, primary_key=True),
Column("id_another_long_name", Integer),
)
cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
actual_name = cons.name
metadata.create_all(connection)
insp = inspect(connection)
ix = insp.get_indexes("a_things_with_stuff")
reflected_name = ix[0]["name"]
return actual_name, reflected_name
def uq(self, metadata, connection):
convention = {
"uq": "unique_constraint_%(table_name)s_"
"%(column_0_N_name)s"
+ (
"_".join(
"".join(random.choice("abcdef") for j in range(30))
for i in range(10)
)
),
}
metadata.naming_convention = convention
cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
Table(
"a_things_with_stuff",
metadata,
Column("id_long_column_name", Integer, primary_key=True),
Column("id_another_long_name", Integer),
cons,
)
actual_name = cons.name
metadata.create_all(connection)
insp = inspect(connection)
uq = insp.get_unique_constraints("a_things_with_stuff")
reflected_name = uq[0]["name"]
return actual_name, reflected_name
def ck(self, metadata, connection):
convention = {
"ck": "check_constraint_%(table_name)s"
+ (
"_".join(
"".join(random.choice("abcdef") for j in range(30))
for i in range(10)
)
),
}
metadata.naming_convention = convention
cons = CheckConstraint("some_long_column_name > 5")
Table(
"a_things_with_stuff",
metadata,
Column("id_long_column_name", Integer, primary_key=True),
Column("some_long_column_name", Integer),
cons,
)
actual_name = cons.name
metadata.create_all(connection)
insp = inspect(connection)
ck = insp.get_check_constraints("a_things_with_stuff")
reflected_name = ck[0]["name"]
return actual_name, reflected_name
@testing.combinations(
("fk",),
("pk",),
("ix",),
("ck", testing.requires.check_constraint_reflection.as_skips()),
("uq", testing.requires.unique_constraint_reflection.as_skips()),
argnames="type_",
)
def test_long_convention_name(self, type_, metadata, connection):
actual_name, reflected_name = getattr(self, type_)(
metadata, connection
)
assert len(actual_name) > 255
if reflected_name is not None:
overlap = actual_name[0 : len(reflected_name)]
if len(overlap) < len(actual_name):
eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
else:
eq_(overlap, reflected_name)
__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")

View File

@ -0,0 +1,153 @@
# testing/suite/test_deprecations.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import Integer
from ... import select
from ... import testing
from ... import union
class DeprecatedCompoundSelectTest(fixtures.TablesTest):
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"some_table",
metadata,
Column("id", Integer, primary_key=True),
Column("x", Integer),
Column("y", Integer),
)
@classmethod
def insert_data(cls, connection):
connection.execute(
cls.tables.some_table.insert(),
[
{"id": 1, "x": 1, "y": 2},
{"id": 2, "x": 2, "y": 3},
{"id": 3, "x": 3, "y": 4},
{"id": 4, "x": 4, "y": 5},
],
)
def _assert_result(self, conn, select, result, params=()):
eq_(conn.execute(select, params).fetchall(), result)
def test_plain_union(self, connection):
table = self.tables.some_table
s1 = select(table).where(table.c.id == 2)
s2 = select(table).where(table.c.id == 3)
u1 = union(s1, s2)
with testing.expect_deprecated(
"The SelectBase.c and SelectBase.columns "
"attributes are deprecated"
):
self._assert_result(
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
# note we've had to remove one use case entirely, which is this
# one. the Select gets its FROMS from the WHERE clause and the
# columns clause, but not the ORDER BY, which means the old ".c" system
# allowed you to "order_by(s.c.foo)" to get an unnamed column in the
# ORDER BY without adding the SELECT into the FROM and breaking the
# query. Users will have to adjust for this use case if they were doing
# it before.
def _dont_test_select_from_plain_union(self, connection):
table = self.tables.some_table
s1 = select(table).where(table.c.id == 2)
s2 = select(table).where(table.c.id == 3)
u1 = union(s1, s2).alias().select()
with testing.expect_deprecated(
"The SelectBase.c and SelectBase.columns "
"attributes are deprecated"
):
self._assert_result(
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
@testing.requires.order_by_col_from_union
@testing.requires.parens_in_union_contained_select_w_limit_offset
def test_limit_offset_selectable_in_unions(self, connection):
table = self.tables.some_table
s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
u1 = union(s1, s2).limit(2)
with testing.expect_deprecated(
"The SelectBase.c and SelectBase.columns "
"attributes are deprecated"
):
self._assert_result(
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
@testing.requires.parens_in_union_contained_select_wo_limit_offset
def test_order_by_selectable_in_unions(self, connection):
table = self.tables.some_table
s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
u1 = union(s1, s2).limit(2)
with testing.expect_deprecated(
"The SelectBase.c and SelectBase.columns "
"attributes are deprecated"
):
self._assert_result(
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
def test_distinct_selectable_in_unions(self, connection):
table = self.tables.some_table
s1 = select(table).where(table.c.id == 2).distinct()
s2 = select(table).where(table.c.id == 3).distinct()
u1 = union(s1, s2).limit(2)
with testing.expect_deprecated(
"The SelectBase.c and SelectBase.columns "
"attributes are deprecated"
):
self._assert_result(
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
def test_limit_offset_aliased_selectable_in_unions(self, connection):
table = self.tables.some_table
s1 = (
select(table)
.where(table.c.id == 2)
.limit(1)
.order_by(table.c.id)
.alias()
.select()
)
s2 = (
select(table)
.where(table.c.id == 3)
.limit(1)
.order_by(table.c.id)
.alias()
.select()
)
u1 = union(s1, s2).limit(2)
with testing.expect_deprecated(
"The SelectBase.c and SelectBase.columns "
"attributes are deprecated"
):
self._assert_result(
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)

View File

@ -0,0 +1,740 @@
# testing/suite/test_dialect.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import importlib
from . import testing
from .. import assert_raises
from .. import config
from .. import engines
from .. import eq_
from .. import fixtures
from .. import is_not_none
from .. import is_true
from .. import ne_
from .. import provide_metadata
from ..assertions import expect_raises
from ..assertions import expect_raises_message
from ..config import requirements
from ..provision import set_default_schema_on_connection
from ..schema import Column
from ..schema import Table
from ... import bindparam
from ... import dialects
from ... import event
from ... import exc
from ... import Integer
from ... import literal_column
from ... import select
from ... import String
from ...sql.compiler import Compiled
from ...util import inspect_getfullargspec
class PingTest(fixtures.TestBase):
__backend__ = True
def test_do_ping(self):
with testing.db.connect() as conn:
is_true(
testing.db.dialect.do_ping(conn.connection.dbapi_connection)
)
class ArgSignatureTest(fixtures.TestBase):
"""test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have
``**kw``, for #8988.
This test uses runtime code inspection. Does not need to be a
``__backend__`` test as it only needs to run once provided all target
dialects have been imported.
For third party dialects, the suite would be run with that third
party as a "--dburi", which means its compiler classes will have been
imported by the time this test runs.
"""
def _all_subclasses(): # type: ignore # noqa
for d in dialects.__all__:
if not d.startswith("_"):
importlib.import_module("sqlalchemy.dialects.%s" % d)
stack = [Compiled]
while stack:
cls = stack.pop(0)
stack.extend(cls.__subclasses__())
yield cls
@testing.fixture(params=list(_all_subclasses()))
def all_subclasses(self, request):
yield request.param
def test_all_visit_methods_accept_kw(self, all_subclasses):
cls = all_subclasses
for k in cls.__dict__:
if k.startswith("visit_"):
meth = getattr(cls, k)
insp = inspect_getfullargspec(meth)
is_not_none(
insp.varkw,
f"Compiler visit method {cls.__name__}.{k}() does "
"not accommodate for **kw in its argument signature",
)
class ExceptionTest(fixtures.TablesTest):
"""Test basic exception wrapping.
DBAPIs vary a lot in exception behavior so to actually anticipate
specific exceptions from real round trips, we need to be conservative.
"""
run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"manual_pk",
metadata,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(50)),
)
@requirements.duplicate_key_raises_integrity_error
def test_integrity_error(self):
with config.db.connect() as conn:
trans = conn.begin()
conn.execute(
self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
)
assert_raises(
exc.IntegrityError,
conn.execute,
self.tables.manual_pk.insert(),
{"id": 1, "data": "d1"},
)
trans.rollback()
def test_exception_with_non_ascii(self):
with config.db.connect() as conn:
try:
# try to create an error message that likely has non-ascii
# characters in the DBAPI's message string. unfortunately
# there's no way to make this happen with some drivers like
# mysqlclient, pymysql. this at least does produce a non-
# ascii error message for cx_oracle, psycopg2
conn.execute(select(literal_column("méil")))
assert False
except exc.DBAPIError as err:
err_str = str(err)
assert str(err.orig) in str(err)
assert isinstance(err_str, str)
class IsolationLevelTest(fixtures.TestBase):
__backend__ = True
__requires__ = ("isolation_level",)
def _get_non_default_isolation_level(self):
levels = requirements.get_isolation_levels(config)
default = levels["default"]
supported = levels["supported"]
s = set(supported).difference(["AUTOCOMMIT", default])
if s:
return s.pop()
else:
config.skip_test("no non-default isolation level available")
def test_default_isolation_level(self):
eq_(
config.db.dialect.default_isolation_level,
requirements.get_isolation_levels(config)["default"],
)
def test_non_default_isolation_level(self):
non_default = self._get_non_default_isolation_level()
with config.db.connect() as conn:
existing = conn.get_isolation_level()
ne_(existing, non_default)
conn.execution_options(isolation_level=non_default)
eq_(conn.get_isolation_level(), non_default)
conn.dialect.reset_isolation_level(
conn.connection.dbapi_connection
)
eq_(conn.get_isolation_level(), existing)
def test_all_levels(self):
levels = requirements.get_isolation_levels(config)
all_levels = levels["supported"]
for level in set(all_levels).difference(["AUTOCOMMIT"]):
with config.db.connect() as conn:
conn.execution_options(isolation_level=level)
eq_(conn.get_isolation_level(), level)
trans = conn.begin()
trans.rollback()
eq_(conn.get_isolation_level(), level)
with config.db.connect() as conn:
eq_(
conn.get_isolation_level(),
levels["default"],
)
@testing.requires.get_isolation_level_values
def test_invalid_level_execution_option(self, connection_no_trans):
"""test for the new get_isolation_level_values() method"""
connection = connection_no_trans
with expect_raises_message(
exc.ArgumentError,
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for '%s' are %s"
% (
"FOO",
connection.dialect.name,
", ".join(
requirements.get_isolation_levels(config)["supported"]
),
),
):
connection.execution_options(isolation_level="FOO")
@testing.requires.get_isolation_level_values
@testing.requires.dialect_level_isolation_level_param
def test_invalid_level_engine_param(self, testing_engine):
"""test for the new get_isolation_level_values() method
and support for the dialect-level 'isolation_level' parameter.
"""
eng = testing_engine(options=dict(isolation_level="FOO"))
with expect_raises_message(
exc.ArgumentError,
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for '%s' are %s"
% (
"FOO",
eng.dialect.name,
", ".join(
requirements.get_isolation_levels(config)["supported"]
),
),
):
eng.connect()
@testing.requires.independent_readonly_connections
def test_dialect_user_setting_is_restored(self, testing_engine):
levels = requirements.get_isolation_levels(config)
default = levels["default"]
supported = (
sorted(
set(levels["supported"]).difference([default, "AUTOCOMMIT"])
)
)[0]
e = testing_engine(options={"isolation_level": supported})
with e.connect() as conn:
eq_(conn.get_isolation_level(), supported)
with e.connect() as conn:
conn.execution_options(isolation_level=default)
eq_(conn.get_isolation_level(), default)
with e.connect() as conn:
eq_(conn.get_isolation_level(), supported)
class AutocommitIsolationTest(fixtures.TablesTest):
run_deletes = "each"
__requires__ = ("autocommit",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"some_table",
metadata,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(50)),
test_needs_acid=True,
)
def _test_conn_autocommits(self, conn, autocommit):
trans = conn.begin()
conn.execute(
self.tables.some_table.insert(), {"id": 1, "data": "some data"}
)
trans.rollback()
eq_(
conn.scalar(select(self.tables.some_table.c.id)),
1 if autocommit else None,
)
conn.rollback()
with conn.begin():
conn.execute(self.tables.some_table.delete())
def test_autocommit_on(self, connection_no_trans):
conn = connection_no_trans
c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(c2, True)
c2.dialect.reset_isolation_level(c2.connection.dbapi_connection)
self._test_conn_autocommits(conn, False)
def test_autocommit_off(self, connection_no_trans):
conn = connection_no_trans
self._test_conn_autocommits(conn, False)
def test_turn_autocommit_off_via_default_iso_level(
self, connection_no_trans
):
conn = connection_no_trans
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(conn, True)
conn.execution_options(
isolation_level=requirements.get_isolation_levels(config)[
"default"
]
)
self._test_conn_autocommits(conn, False)
@testing.requires.independent_readonly_connections
@testing.variation("use_dialect_setting", [True, False])
def test_dialect_autocommit_is_restored(
self, testing_engine, use_dialect_setting
):
"""test #10147"""
if use_dialect_setting:
e = testing_engine(options={"isolation_level": "AUTOCOMMIT"})
else:
e = testing_engine().execution_options(
isolation_level="AUTOCOMMIT"
)
levels = requirements.get_isolation_levels(config)
default = levels["default"]
with e.connect() as conn:
self._test_conn_autocommits(conn, True)
with e.connect() as conn:
conn.execution_options(isolation_level=default)
self._test_conn_autocommits(conn, False)
with e.connect() as conn:
self._test_conn_autocommits(conn, True)
class EscapingTest(fixtures.TestBase):
@provide_metadata
def test_percent_sign_round_trip(self):
"""test that the DBAPI accommodates for escaped / nonescaped
percent signs in a way that matches the compiler
"""
m = self.metadata
t = Table("t", m, Column("data", String(50)))
t.create(config.db)
with config.db.begin() as conn:
conn.execute(t.insert(), dict(data="some % value"))
conn.execute(t.insert(), dict(data="some %% other value"))
eq_(
conn.scalar(
select(t.c.data).where(
t.c.data == literal_column("'some % value'")
)
),
"some % value",
)
eq_(
conn.scalar(
select(t.c.data).where(
t.c.data == literal_column("'some %% other value'")
)
),
"some %% other value",
)
class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
__backend__ = True
__requires__ = ("default_schema_name_switch",)
def test_control_case(self):
default_schema_name = config.db.dialect.default_schema_name
eng = engines.testing_engine()
with eng.connect():
pass
eq_(eng.dialect.default_schema_name, default_schema_name)
def test_wont_work_wo_insert(self):
default_schema_name = config.db.dialect.default_schema_name
eng = engines.testing_engine()
@event.listens_for(eng, "connect")
def on_connect(dbapi_connection, connection_record):
set_default_schema_on_connection(
config, dbapi_connection, config.test_schema
)
with eng.connect() as conn:
what_it_should_be = eng.dialect._get_default_schema_name(conn)
eq_(what_it_should_be, config.test_schema)
eq_(eng.dialect.default_schema_name, default_schema_name)
def test_schema_change_on_connect(self):
eng = engines.testing_engine()
@event.listens_for(eng, "connect", insert=True)
def on_connect(dbapi_connection, connection_record):
set_default_schema_on_connection(
config, dbapi_connection, config.test_schema
)
with eng.connect() as conn:
what_it_should_be = eng.dialect._get_default_schema_name(conn)
eq_(what_it_should_be, config.test_schema)
eq_(eng.dialect.default_schema_name, config.test_schema)
def test_schema_change_works_w_transactions(self):
eng = engines.testing_engine()
@event.listens_for(eng, "connect", insert=True)
def on_connect(dbapi_connection, *arg):
set_default_schema_on_connection(
config, dbapi_connection, config.test_schema
)
with eng.connect() as conn:
trans = conn.begin()
what_it_should_be = eng.dialect._get_default_schema_name(conn)
eq_(what_it_should_be, config.test_schema)
trans.rollback()
what_it_should_be = eng.dialect._get_default_schema_name(conn)
eq_(what_it_should_be, config.test_schema)
eq_(eng.dialect.default_schema_name, config.test_schema)
class FutureWeCanSetDefaultSchemaWEventsTest(
fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
):
pass
class DifficultParametersTest(fixtures.TestBase):
__backend__ = True
tough_parameters = testing.combinations(
("boring",),
("per cent",),
("per % cent",),
("%percent",),
("par(ens)",),
("percent%(ens)yah",),
("col:ons",),
("_starts_with_underscore",),
("dot.s",),
("more :: %colons%",),
("_name",),
("___name",),
("[BracketsAndCase]",),
("42numbers",),
("percent%signs",),
("has spaces",),
("/slashes/",),
("more/slashes",),
("q?marks",),
("1param",),
("1col:on",),
argnames="paramname",
)
@tough_parameters
@config.requirements.unusual_column_name_characters
def test_round_trip_same_named_column(
self, paramname, connection, metadata
):
name = paramname
t = Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column(name, String(50), nullable=False),
)
# table is created
t.create(connection)
# automatic param generated by insert
connection.execute(t.insert().values({"id": 1, name: "some name"}))
# automatic param generated by criteria, plus selecting the column
stmt = select(t.c[name]).where(t.c[name] == "some name")
eq_(connection.scalar(stmt), "some name")
# use the name in a param explicitly
stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
row = connection.execute(stmt, {name: "some name"}).first()
# name works as the key from cursor.description
eq_(row._mapping[name], "some name")
# use expanding IN
stmt = select(t.c[name]).where(
t.c[name].in_(["some name", "some other_name"])
)
row = connection.execute(stmt).first()
@testing.fixture
def multirow_fixture(self, metadata, connection):
mytable = Table(
"mytable",
metadata,
Column("myid", Integer),
Column("name", String(50)),
Column("desc", String(50)),
)
mytable.create(connection)
connection.execute(
mytable.insert(),
[
{"myid": 1, "name": "a", "desc": "a_desc"},
{"myid": 2, "name": "b", "desc": "b_desc"},
{"myid": 3, "name": "c", "desc": "c_desc"},
{"myid": 4, "name": "d", "desc": "d_desc"},
],
)
yield mytable
@tough_parameters
def test_standalone_bindparam_escape(
self, paramname, connection, multirow_fixture
):
tbl1 = multirow_fixture
stmt = select(tbl1.c.myid).where(
tbl1.c.name == bindparam(paramname, value="x")
)
res = connection.scalar(stmt, {paramname: "c"})
eq_(res, 3)
@tough_parameters
def test_standalone_bindparam_escape_expanding(
self, paramname, connection, multirow_fixture
):
tbl1 = multirow_fixture
stmt = (
select(tbl1.c.myid)
.where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"])))
.order_by(tbl1.c.myid)
)
res = connection.scalars(stmt, {paramname: ["d", "a"]}).all()
eq_(res, [1, 4])
class ReturningGuardsTest(fixtures.TablesTest):
"""test that the various 'returning' flags are set appropriately"""
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"t",
metadata,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(50)),
)
@testing.fixture
def run_stmt(self, connection):
t = self.tables.t
def go(stmt, executemany, id_param_name, expect_success):
stmt = stmt.returning(t.c.id)
if executemany:
if not expect_success:
# for RETURNING executemany(), we raise our own
# error as this is independent of general RETURNING
# support
with expect_raises_message(
exc.StatementError,
rf"Dialect {connection.dialect.name}\+"
f"{connection.dialect.driver} with "
f"current server capabilities does not support "
f".*RETURNING when executemany is used",
):
result = connection.execute(
stmt,
[
{id_param_name: 1, "data": "d1"},
{id_param_name: 2, "data": "d2"},
{id_param_name: 3, "data": "d3"},
],
)
else:
result = connection.execute(
stmt,
[
{id_param_name: 1, "data": "d1"},
{id_param_name: 2, "data": "d2"},
{id_param_name: 3, "data": "d3"},
],
)
eq_(result.all(), [(1,), (2,), (3,)])
else:
if not expect_success:
# for RETURNING execute(), we pass all the way to the DB
# and let it fail
with expect_raises(exc.DBAPIError):
connection.execute(
stmt, {id_param_name: 1, "data": "d1"}
)
else:
result = connection.execute(
stmt, {id_param_name: 1, "data": "d1"}
)
eq_(result.all(), [(1,)])
return go
def test_insert_single(self, connection, run_stmt):
t = self.tables.t
stmt = t.insert()
run_stmt(stmt, False, "id", connection.dialect.insert_returning)
def test_insert_many(self, connection, run_stmt):
t = self.tables.t
stmt = t.insert()
run_stmt(
stmt, True, "id", connection.dialect.insert_executemany_returning
)
def test_update_single(self, connection, run_stmt):
t = self.tables.t
connection.execute(
t.insert(),
[
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
],
)
stmt = t.update().where(t.c.id == bindparam("b_id"))
run_stmt(stmt, False, "b_id", connection.dialect.update_returning)
def test_update_many(self, connection, run_stmt):
t = self.tables.t
connection.execute(
t.insert(),
[
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
],
)
stmt = t.update().where(t.c.id == bindparam("b_id"))
run_stmt(
stmt, True, "b_id", connection.dialect.update_executemany_returning
)
def test_delete_single(self, connection, run_stmt):
t = self.tables.t
connection.execute(
t.insert(),
[
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
],
)
stmt = t.delete().where(t.c.id == bindparam("b_id"))
run_stmt(stmt, False, "b_id", connection.dialect.delete_returning)
def test_delete_many(self, connection, run_stmt):
t = self.tables.t
connection.execute(
t.insert(),
[
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
],
)
stmt = t.delete().where(t.c.id == bindparam("b_id"))
run_stmt(
stmt, True, "b_id", connection.dialect.delete_executemany_returning
)

View File

@ -0,0 +1,630 @@
# testing/suite/test_insert.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from decimal import Decimal
import uuid
from . import testing
from .. import fixtures
from ..assertions import eq_
from ..config import requirements
from ..schema import Column
from ..schema import Table
from ... import Double
from ... import Float
from ... import Identity
from ... import Integer
from ... import literal
from ... import literal_column
from ... import Numeric
from ... import select
from ... import String
from ...types import LargeBinary
from ...types import UUID
from ...types import Uuid
class LastrowidTest(fixtures.TablesTest):
run_deletes = "each"
__backend__ = True
__requires__ = "implements_get_lastrowid", "autoincrement_insert"
@classmethod
def define_tables(cls, metadata):
Table(
"autoinc_pk",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
implicit_returning=False,
)
Table(
"manual_pk",
metadata,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(50)),
implicit_returning=False,
)
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
eq_(
row,
(
conn.dialect.default_sequence_base,
"some data",
),
)
def test_autoincrement_on_insert(self, connection):
connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
self._assert_round_trip(self.tables.autoinc_pk, connection)
def test_last_inserted_id(self, connection):
r = connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
eq_(r.inserted_primary_key, (pk,))
@requirements.dbapi_lastrowid
def test_native_lastrowid_autoinc(self, connection):
r = connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
lastrowid = r.lastrowid
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
eq_(lastrowid, pk)
class InsertBehaviorTest(fixtures.TablesTest):
run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"autoinc_pk",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
)
Table(
"manual_pk",
metadata,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(50)),
)
Table(
"no_implicit_returning",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
implicit_returning=False,
)
Table(
"includes_defaults",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
Column("x", Integer, default=5),
Column(
"y",
Integer,
default=literal_column("2", type_=Integer) + literal(2),
),
)
@testing.variation("style", ["plain", "return_defaults"])
@testing.variation("executemany", [True, False])
def test_no_results_for_non_returning_insert(
self, connection, style, executemany
):
"""test another INSERT issue found during #10453"""
table = self.tables.no_implicit_returning
stmt = table.insert()
if style.return_defaults:
stmt = stmt.return_defaults()
if executemany:
data = [
{"data": "d1"},
{"data": "d2"},
{"data": "d3"},
{"data": "d4"},
{"data": "d5"},
]
else:
data = {"data": "d1"}
r = connection.execute(stmt, data)
assert not r.returns_rows
@requirements.autoincrement_insert
def test_autoclose_on_insert(self, connection):
r = connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
assert r._soft_closed
assert not r.closed
assert r.is_insert
# new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
# an insert where the PK was taken from a row that the dialect
# selected, as is the case for mssql/pyodbc, will still report
# returns_rows as true because there's a cursor description. in that
# case, the row had to have been consumed at least.
assert not r.returns_rows or r.fetchone() is None
@requirements.insert_returning
def test_autoclose_on_insert_implicit_returning(self, connection):
r = connection.execute(
# return_defaults() ensures RETURNING will be used,
# new in 2.0 as sqlite/mariadb offer both RETURNING and
# cursor.lastrowid
self.tables.autoinc_pk.insert().return_defaults(),
dict(data="some data"),
)
assert r._soft_closed
assert not r.closed
assert r.is_insert
# note we are experimenting with having this be True
# as of I8091919d45421e3f53029b8660427f844fee0228 .
# implicit returning has fetched the row, but it still is a
# "returns rows"
assert r.returns_rows
# and we should be able to fetchone() on it, we just get no row
eq_(r.fetchone(), None)
# and the keys, etc.
eq_(r.keys(), ["id"])
# but the dialect took in the row already. not really sure
# what the best behavior is.
@requirements.empty_inserts
def test_empty_insert(self, connection):
r = connection.execute(self.tables.autoinc_pk.insert())
assert r._soft_closed
assert not r.closed
r = connection.execute(
self.tables.autoinc_pk.select().where(
self.tables.autoinc_pk.c.id != None
)
)
eq_(len(r.all()), 1)
@requirements.empty_inserts_executemany
def test_empty_insert_multiple(self, connection):
r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
assert r._soft_closed
assert not r.closed
r = connection.execute(
self.tables.autoinc_pk.select().where(
self.tables.autoinc_pk.c.id != None
)
)
eq_(len(r.all()), 3)
@requirements.insert_from_select
def test_insert_from_select_autoinc(self, connection):
src_table = self.tables.manual_pk
dest_table = self.tables.autoinc_pk
connection.execute(
src_table.insert(),
[
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
],
)
result = connection.execute(
dest_table.insert().from_select(
("data",),
select(src_table.c.data).where(
src_table.c.data.in_(["data2", "data3"])
),
)
)
eq_(result.inserted_primary_key, (None,))
result = connection.execute(
select(dest_table.c.data).order_by(dest_table.c.data)
)
eq_(result.fetchall(), [("data2",), ("data3",)])
@requirements.insert_from_select
def test_insert_from_select_autoinc_no_rows(self, connection):
src_table = self.tables.manual_pk
dest_table = self.tables.autoinc_pk
result = connection.execute(
dest_table.insert().from_select(
("data",),
select(src_table.c.data).where(
src_table.c.data.in_(["data2", "data3"])
),
)
)
eq_(result.inserted_primary_key, (None,))
result = connection.execute(
select(dest_table.c.data).order_by(dest_table.c.data)
)
eq_(result.fetchall(), [])
@requirements.insert_from_select
def test_insert_from_select(self, connection):
table = self.tables.manual_pk
connection.execute(
table.insert(),
[
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
],
)
connection.execute(
table.insert()
.inline()
.from_select(
("id", "data"),
select(table.c.id + 5, table.c.data).where(
table.c.data.in_(["data2", "data3"])
),
)
)
eq_(
connection.execute(
select(table.c.data).order_by(table.c.data)
).fetchall(),
[("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
)
@requirements.insert_from_select
def test_insert_from_select_with_defaults(self, connection):
table = self.tables.includes_defaults
connection.execute(
table.insert(),
[
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
],
)
connection.execute(
table.insert()
.inline()
.from_select(
("id", "data"),
select(table.c.id + 5, table.c.data).where(
table.c.data.in_(["data2", "data3"])
),
)
)
eq_(
connection.execute(
select(table).order_by(table.c.data, table.c.id)
).fetchall(),
[
(1, "data1", 5, 4),
(2, "data2", 5, 4),
(7, "data2", 5, 4),
(3, "data3", 5, 4),
(8, "data3", 5, 4),
],
)
class ReturningTest(fixtures.TablesTest):
run_create_tables = "each"
__requires__ = "insert_returning", "autoincrement_insert"
__backend__ = True
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
eq_(
row,
(
conn.dialect.default_sequence_base,
"some data",
),
)
@classmethod
def define_tables(cls, metadata):
Table(
"autoinc_pk",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
)
@requirements.fetch_rows_post_commit
def test_explicit_returning_pk_autocommit(self, connection):
table = self.tables.autoinc_pk
r = connection.execute(
table.insert().returning(table.c.id), dict(data="some data")
)
pk = r.first()[0]
fetched_pk = connection.scalar(select(table.c.id))
eq_(fetched_pk, pk)
def test_explicit_returning_pk_no_autocommit(self, connection):
table = self.tables.autoinc_pk
r = connection.execute(
table.insert().returning(table.c.id), dict(data="some data")
)
pk = r.first()[0]
fetched_pk = connection.scalar(select(table.c.id))
eq_(fetched_pk, pk)
def test_autoincrement_on_insert_implicit_returning(self, connection):
connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
self._assert_round_trip(self.tables.autoinc_pk, connection)
def test_last_inserted_id_implicit_returning(self, connection):
r = connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
eq_(r.inserted_primary_key, (pk,))
@requirements.insert_executemany_returning
def test_insertmanyvalues_returning(self, connection):
r = connection.execute(
self.tables.autoinc_pk.insert().returning(
self.tables.autoinc_pk.c.id
),
[
{"data": "d1"},
{"data": "d2"},
{"data": "d3"},
{"data": "d4"},
{"data": "d5"},
],
)
rall = r.all()
pks = connection.execute(select(self.tables.autoinc_pk.c.id))
eq_(rall, pks.all())
@testing.combinations(
(Double(), 8.5514716, True),
(
Double(53),
8.5514716,
True,
testing.requires.float_or_double_precision_behaves_generically,
),
(Float(), 8.5514, True),
(
Float(8),
8.5514,
True,
testing.requires.float_or_double_precision_behaves_generically,
),
(
Numeric(precision=15, scale=12, asdecimal=False),
8.5514716,
True,
testing.requires.literal_float_coercion,
),
(
Numeric(precision=15, scale=12, asdecimal=True),
Decimal("8.5514716"),
False,
),
argnames="type_,value,do_rounding",
)
@testing.variation("sort_by_parameter_order", [True, False])
@testing.variation("multiple_rows", [True, False])
def test_insert_w_floats(
self,
connection,
metadata,
sort_by_parameter_order,
type_,
value,
do_rounding,
multiple_rows,
):
"""test #9701.
this tests insertmanyvalues as well as decimal / floating point
RETURNING types
"""
t = Table(
# Oracle backends seems to be getting confused if
# this table is named the same as the one
# in test_imv_returning_datatypes. use a different name
"f_t",
metadata,
Column("id", Integer, Identity(), primary_key=True),
Column("value", type_),
)
t.create(connection)
result = connection.execute(
t.insert().returning(
t.c.id,
t.c.value,
sort_by_parameter_order=bool(sort_by_parameter_order),
),
(
[{"value": value} for i in range(10)]
if multiple_rows
else {"value": value}
),
)
if multiple_rows:
i_range = range(1, 11)
else:
i_range = range(1, 2)
# we want to test only that we are getting floating points back
# with some degree of the original value maintained, that it is not
# being truncated to an integer. there's too much variation in how
# drivers return floats, which should not be relied upon to be
# exact, for us to just compare as is (works for PG drivers but not
# others) so we use rounding here. There's precedent for this
# in suite/test_types.py::NumericTest as well
if do_rounding:
eq_(
{(id_, round(val_, 5)) for id_, val_ in result},
{(id_, round(value, 5)) for id_ in i_range},
)
eq_(
{
round(val_, 5)
for val_ in connection.scalars(select(t.c.value))
},
{round(value, 5)},
)
else:
eq_(
set(result),
{(id_, value) for id_ in i_range},
)
eq_(
set(connection.scalars(select(t.c.value))),
{value},
)
@testing.combinations(
(
"non_native_uuid",
Uuid(native_uuid=False),
uuid.uuid4(),
),
(
"non_native_uuid_str",
Uuid(as_uuid=False, native_uuid=False),
str(uuid.uuid4()),
),
(
"generic_native_uuid",
Uuid(native_uuid=True),
uuid.uuid4(),
testing.requires.uuid_data_type,
),
(
"generic_native_uuid_str",
Uuid(as_uuid=False, native_uuid=True),
str(uuid.uuid4()),
testing.requires.uuid_data_type,
),
("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type),
(
"LargeBinary1",
LargeBinary(),
b"this is binary",
),
("LargeBinary2", LargeBinary(), b"7\xe7\x9f"),
argnames="type_,value",
id_="iaa",
)
@testing.variation("sort_by_parameter_order", [True, False])
@testing.variation("multiple_rows", [True, False])
@testing.requires.insert_returning
def test_imv_returning_datatypes(
self,
connection,
metadata,
sort_by_parameter_order,
type_,
value,
multiple_rows,
):
"""test #9739, #9808 (similar to #9701).
this tests insertmanyvalues in conjunction with various datatypes.
These tests are particularly for the asyncpg driver which needs
most types to be explicitly cast for the new IMV format
"""
t = Table(
"d_t",
metadata,
Column("id", Integer, Identity(), primary_key=True),
Column("value", type_),
)
t.create(connection)
result = connection.execute(
t.insert().returning(
t.c.id,
t.c.value,
sort_by_parameter_order=bool(sort_by_parameter_order),
),
(
[{"value": value} for i in range(10)]
if multiple_rows
else {"value": value}
),
)
if multiple_rows:
i_range = range(1, 11)
else:
i_range = range(1, 2)
eq_(
set(result),
{(id_, value) for id_ in i_range},
)
eq_(
set(connection.scalars(select(t.c.value))),
{value},
)
__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,504 @@
# testing/suite/test_results.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import datetime
import re
from .. import engines
from .. import fixtures
from ..assertions import eq_
from ..config import requirements
from ..schema import Column
from ..schema import Table
from ... import DateTime
from ... import func
from ... import Integer
from ... import select
from ... import sql
from ... import String
from ... import testing
from ... import text
class RowFetchTest(fixtures.TablesTest):
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"plain_pk",
metadata,
Column("id", Integer, primary_key=True),
Column("data", String(50)),
)
Table(
"has_dates",
metadata,
Column("id", Integer, primary_key=True),
Column("today", DateTime),
)
@classmethod
def insert_data(cls, connection):
connection.execute(
cls.tables.plain_pk.insert(),
[
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
],
)
connection.execute(
cls.tables.has_dates.insert(),
[{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
)
def test_via_attr(self, connection):
row = connection.execute(
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
eq_(row.id, 1)
eq_(row.data, "d1")
def test_via_string(self, connection):
row = connection.execute(
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
eq_(row._mapping["id"], 1)
eq_(row._mapping["data"], "d1")
def test_via_int(self, connection):
row = connection.execute(
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
eq_(row[0], 1)
eq_(row[1], "d1")
def test_via_col_object(self, connection):
row = connection.execute(
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
eq_(row._mapping[self.tables.plain_pk.c.id], 1)
eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
@requirements.duplicate_names_in_cursor_description
def test_row_with_dupe_names(self, connection):
result = connection.execute(
select(
self.tables.plain_pk.c.data,
self.tables.plain_pk.c.data.label("data"),
).order_by(self.tables.plain_pk.c.id)
)
row = result.first()
eq_(result.keys(), ["data", "data"])
eq_(row, ("d1", "d1"))
def test_row_w_scalar_select(self, connection):
"""test that a scalar select as a column is returned as such
and that type conversion works OK.
(this is half a SQLAlchemy Core test and half to catch database
backends that may have unusual behavior with scalar selects.)
"""
datetable = self.tables.has_dates
s = select(datetable.alias("x").c.today).scalar_subquery()
s2 = select(datetable.c.id, s.label("somelabel"))
row = connection.execute(s2).first()
eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
class PercentSchemaNamesTest(fixtures.TablesTest):
"""tests using percent signs, spaces in table and column names.
This didn't work for PostgreSQL / MySQL drivers for a long time
but is now supported.
"""
__requires__ = ("percent_schema_names",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
cls.tables.percent_table = Table(
"percent%table",
metadata,
Column("percent%", Integer),
Column("spaces % more spaces", Integer),
)
cls.tables.lightweight_percent_table = sql.table(
"percent%table",
sql.column("percent%"),
sql.column("spaces % more spaces"),
)
def test_single_roundtrip(self, connection):
percent_table = self.tables.percent_table
for params in [
{"percent%": 5, "spaces % more spaces": 12},
{"percent%": 7, "spaces % more spaces": 11},
{"percent%": 9, "spaces % more spaces": 10},
{"percent%": 11, "spaces % more spaces": 9},
]:
connection.execute(percent_table.insert(), params)
self._assert_table(connection)
def test_executemany_roundtrip(self, connection):
percent_table = self.tables.percent_table
connection.execute(
percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
)
connection.execute(
percent_table.insert(),
[
{"percent%": 7, "spaces % more spaces": 11},
{"percent%": 9, "spaces % more spaces": 10},
{"percent%": 11, "spaces % more spaces": 9},
],
)
self._assert_table(connection)
@requirements.insert_executemany_returning
def test_executemany_returning_roundtrip(self, connection):
percent_table = self.tables.percent_table
connection.execute(
percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
)
result = connection.execute(
percent_table.insert().returning(
percent_table.c["percent%"],
percent_table.c["spaces % more spaces"],
),
[
{"percent%": 7, "spaces % more spaces": 11},
{"percent%": 9, "spaces % more spaces": 10},
{"percent%": 11, "spaces % more spaces": 9},
],
)
eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
self._assert_table(connection)
def _assert_table(self, conn):
percent_table = self.tables.percent_table
lightweight_percent_table = self.tables.lightweight_percent_table
for table in (
percent_table,
percent_table.alias(),
lightweight_percent_table,
lightweight_percent_table.alias(),
):
eq_(
list(
conn.execute(table.select().order_by(table.c["percent%"]))
),
[(5, 12), (7, 11), (9, 10), (11, 9)],
)
eq_(
list(
conn.execute(
table.select()
.where(table.c["spaces % more spaces"].in_([9, 10]))
.order_by(table.c["percent%"])
)
),
[(9, 10), (11, 9)],
)
row = conn.execute(
table.select().order_by(table.c["percent%"])
).first()
eq_(row._mapping["percent%"], 5)
eq_(row._mapping["spaces % more spaces"], 12)
eq_(row._mapping[table.c["percent%"]], 5)
eq_(row._mapping[table.c["spaces % more spaces"]], 12)
conn.execute(
percent_table.update().values(
{percent_table.c["spaces % more spaces"]: 15}
)
)
eq_(
list(
conn.execute(
percent_table.select().order_by(
percent_table.c["percent%"]
)
)
),
[(5, 15), (7, 15), (9, 15), (11, 15)],
)
class ServerSideCursorsTest(
fixtures.TestBase, testing.AssertsExecutionResults
):
__requires__ = ("server_side_cursors",)
__backend__ = True
def _is_server_side(self, cursor):
# TODO: this is a huge issue as it prevents these tests from being
# usable by third party dialects.
if self.engine.dialect.driver == "psycopg2":
return bool(cursor.name)
elif self.engine.dialect.driver == "pymysql":
sscursor = __import__("pymysql.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"):
return cursor.server_side
elif self.engine.dialect.driver == "mysqldb":
sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
elif self.engine.dialect.driver == "mariadbconnector":
return not cursor.buffered
elif self.engine.dialect.driver == "mysqlconnector":
return "buffered" not in type(cursor).__name__.lower()
elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
return cursor.server_side
elif self.engine.dialect.driver == "pg8000":
return getattr(cursor, "server_side", False)
elif self.engine.dialect.driver == "psycopg":
return bool(getattr(cursor, "name", False))
elif self.engine.dialect.driver == "oracledb":
return getattr(cursor, "server_side", False)
else:
return False
def _fixture(self, server_side_cursors):
if server_side_cursors:
with testing.expect_deprecated(
"The create_engine.server_side_cursors parameter is "
"deprecated and will be removed in a future release. "
"Please use the Connection.execution_options.stream_results "
"parameter."
):
self.engine = engines.testing_engine(
options={"server_side_cursors": server_side_cursors}
)
else:
self.engine = engines.testing_engine(
options={"server_side_cursors": server_side_cursors}
)
return self.engine
def stringify(self, str_):
return re.compile(r"SELECT (\d+)", re.I).sub(
lambda m: str(select(int(m.group(1))).compile(testing.db)), str_
)
@testing.combinations(
("global_string", True, lambda stringify: stringify("select 1"), True),
(
"global_text",
True,
lambda stringify: text(stringify("select 1")),
True,
),
("global_expr", True, select(1), True),
(
"global_off_explicit",
False,
lambda stringify: text(stringify("select 1")),
False,
),
(
"stmt_option",
False,
select(1).execution_options(stream_results=True),
True,
),
(
"stmt_option_disabled",
True,
select(1).execution_options(stream_results=False),
False,
),
("for_update_expr", True, select(1).with_for_update(), True),
# TODO: need a real requirement for this, or dont use this test
(
"for_update_string",
True,
lambda stringify: stringify("SELECT 1 FOR UPDATE"),
True,
testing.skip_if(["sqlite", "mssql"]),
),
(
"text_no_ss",
False,
lambda stringify: text(stringify("select 42")),
False,
),
(
"text_ss_option",
False,
lambda stringify: text(stringify("select 42")).execution_options(
stream_results=True
),
True,
),
id_="iaaa",
argnames="engine_ss_arg, statement, cursor_ss_status",
)
def test_ss_cursor_status(
self, engine_ss_arg, statement, cursor_ss_status
):
engine = self._fixture(engine_ss_arg)
with engine.begin() as conn:
if callable(statement):
statement = testing.resolve_lambda(
statement, stringify=self.stringify
)
if isinstance(statement, str):
result = conn.exec_driver_sql(statement)
else:
result = conn.execute(statement)
eq_(self._is_server_side(result.cursor), cursor_ss_status)
result.close()
def test_conn_option(self):
engine = self._fixture(False)
with engine.connect() as conn:
# should be enabled for this one
result = conn.execution_options(
stream_results=True
).exec_driver_sql(self.stringify("select 1"))
assert self._is_server_side(result.cursor)
# the connection has autobegun, which means at the end of the
# block, we will roll back, which on MySQL at least will fail
# with "Commands out of sync" if the result set
# is not closed, so we close it first.
#
# fun fact! why did we not have this result.close() in this test
# before 2.0? don't we roll back in the connection pool
# unconditionally? yes! and in fact if you run this test in 1.4
# with stdout shown, there is in fact "Exception during reset or
# similar" with "Commands out sync" emitted a warning! 2.0's
# architecture finds and fixes what was previously an expensive
# silent error condition.
result.close()
def test_stmt_enabled_conn_option_disabled(self):
engine = self._fixture(False)
s = select(1).execution_options(stream_results=True)
with engine.connect() as conn:
# not this one
result = conn.execution_options(stream_results=False).execute(s)
assert not self._is_server_side(result.cursor)
def test_aliases_and_ss(self):
engine = self._fixture(False)
s1 = (
select(sql.literal_column("1").label("x"))
.execution_options(stream_results=True)
.subquery()
)
# options don't propagate out when subquery is used as a FROM clause
with engine.begin() as conn:
result = conn.execute(s1.select())
assert not self._is_server_side(result.cursor)
result.close()
s2 = select(1).select_from(s1)
with engine.begin() as conn:
result = conn.execute(s2)
assert not self._is_server_side(result.cursor)
result.close()
def test_roundtrip_fetchall(self, metadata):
md = self.metadata
engine = self._fixture(True)
test_table = Table(
"test_table",
md,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
)
with engine.begin() as connection:
test_table.create(connection, checkfirst=True)
connection.execute(test_table.insert(), dict(data="data1"))
connection.execute(test_table.insert(), dict(data="data2"))
eq_(
connection.execute(
test_table.select().order_by(test_table.c.id)
).fetchall(),
[(1, "data1"), (2, "data2")],
)
connection.execute(
test_table.update()
.where(test_table.c.id == 2)
.values(data=test_table.c.data + " updated")
)
eq_(
connection.execute(
test_table.select().order_by(test_table.c.id)
).fetchall(),
[(1, "data1"), (2, "data2 updated")],
)
connection.execute(test_table.delete())
eq_(
connection.scalar(
select(func.count("*")).select_from(test_table)
),
0,
)
def test_roundtrip_fetchmany(self, metadata):
md = self.metadata
engine = self._fixture(True)
test_table = Table(
"test_table",
md,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("data", String(50)),
)
with engine.begin() as connection:
test_table.create(connection, checkfirst=True)
connection.execute(
test_table.insert(),
[dict(data="data%d" % i) for i in range(1, 20)],
)
result = connection.execute(
test_table.select().order_by(test_table.c.id)
)
eq_(
result.fetchmany(5),
[(i, "data%d" % i) for i in range(1, 6)],
)
eq_(
result.fetchmany(10),
[(i, "data%d" % i) for i in range(6, 16)],
)
eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])

View File

@ -0,0 +1,258 @@
# testing/suite/test_rowcount.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
class RowCountTest(fixtures.TablesTest):
"""test rowcount functionality"""
__requires__ = ("sane_rowcount",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"employees",
metadata,
Column(
"employee_id",
Integer,
autoincrement=False,
primary_key=True,
),
Column("name", String(50)),
Column("department", String(1)),
)
@classmethod
def insert_data(cls, connection):
cls.data = data = [
("Angela", "A"),
("Andrew", "A"),
("Anand", "A"),
("Bob", "B"),
("Bobette", "B"),
("Buffy", "B"),
("Charlie", "C"),
("Cynthia", "C"),
("Chris", "C"),
]
employees_table = cls.tables.employees
connection.execute(
employees_table.insert(),
[
{"employee_id": i, "name": n, "department": d}
for i, (n, d) in enumerate(data)
],
)
def test_basic(self, connection):
employees_table = self.tables.employees
s = select(
employees_table.c.name, employees_table.c.department
).order_by(employees_table.c.employee_id)
rows = connection.execute(s).fetchall()
eq_(rows, self.data)
@testing.variation("statement", ["update", "delete", "insert", "select"])
@testing.variation("close_first", [True, False])
def test_non_rowcount_scenarios_no_raise(
self, connection, statement, close_first
):
employees_table = self.tables.employees
# WHERE matches 3, 3 rows changed
department = employees_table.c.department
if statement.update:
r = connection.execute(
employees_table.update().where(department == "C"),
{"department": "Z"},
)
elif statement.delete:
r = connection.execute(
employees_table.delete().where(department == "C"),
{"department": "Z"},
)
elif statement.insert:
r = connection.execute(
employees_table.insert(),
[
{"employee_id": 25, "name": "none 1", "department": "X"},
{"employee_id": 26, "name": "none 2", "department": "Z"},
{"employee_id": 27, "name": "none 3", "department": "Z"},
],
)
elif statement.select:
s = select(
employees_table.c.name, employees_table.c.department
).where(employees_table.c.department == "C")
r = connection.execute(s)
r.all()
else:
statement.fail()
if close_first:
r.close()
assert r.rowcount in (-1, 3)
def test_update_rowcount1(self, connection):
employees_table = self.tables.employees
# WHERE matches 3, 3 rows changed
department = employees_table.c.department
r = connection.execute(
employees_table.update().where(department == "C"),
{"department": "Z"},
)
assert r.rowcount == 3
def test_update_rowcount2(self, connection):
employees_table = self.tables.employees
# WHERE matches 3, 0 rows changed
department = employees_table.c.department
r = connection.execute(
employees_table.update().where(department == "C"),
{"department": "C"},
)
eq_(r.rowcount, 3)
@testing.variation("implicit_returning", [True, False])
@testing.variation(
"dml",
[
("update", testing.requires.update_returning),
("delete", testing.requires.delete_returning),
],
)
def test_update_delete_rowcount_return_defaults(
self, connection, implicit_returning, dml
):
"""note this test should succeed for all RETURNING backends
as of 2.0. In
Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
len(rows) when we have implicit returning
"""
if implicit_returning:
employees_table = self.tables.employees
else:
employees_table = Table(
"employees",
MetaData(),
Column(
"employee_id",
Integer,
autoincrement=False,
primary_key=True,
),
Column("name", String(50)),
Column("department", String(1)),
implicit_returning=False,
)
department = employees_table.c.department
if dml.update:
stmt = (
employees_table.update()
.where(department == "C")
.values(name=employees_table.c.department + "Z")
.return_defaults()
)
elif dml.delete:
stmt = (
employees_table.delete()
.where(department == "C")
.return_defaults()
)
else:
dml.fail()
r = connection.execute(stmt)
eq_(r.rowcount, 3)
def test_raw_sql_rowcount(self, connection):
# test issue #3622, make sure eager rowcount is called for text
result = connection.exec_driver_sql(
"update employees set department='Z' where department='C'"
)
eq_(result.rowcount, 3)
def test_text_rowcount(self, connection):
# test issue #3622, make sure eager rowcount is called for text
result = connection.execute(
text("update employees set department='Z' where department='C'")
)
eq_(result.rowcount, 3)
def test_delete_rowcount(self, connection):
employees_table = self.tables.employees
# WHERE matches 3, 3 rows deleted
department = employees_table.c.department
r = connection.execute(
employees_table.delete().where(department == "C")
)
eq_(r.rowcount, 3)
@testing.requires.sane_multi_rowcount
def test_multi_update_rowcount(self, connection):
employees_table = self.tables.employees
stmt = (
employees_table.update()
.where(employees_table.c.name == bindparam("emp_name"))
.values(department="C")
)
r = connection.execute(
stmt,
[
{"emp_name": "Bob"},
{"emp_name": "Cynthia"},
{"emp_name": "nonexistent"},
],
)
eq_(r.rowcount, 2)
@testing.requires.sane_multi_rowcount
def test_multi_delete_rowcount(self, connection):
employees_table = self.tables.employees
stmt = employees_table.delete().where(
employees_table.c.name == bindparam("emp_name")
)
r = connection.execute(
stmt,
[
{"emp_name": "Bob"},
{"emp_name": "Cynthia"},
{"emp_name": "nonexistent"},
],
)
eq_(r.rowcount, 2)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,317 @@
# testing/suite/test_sequence.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import config
from .. import fixtures
from ..assertions import eq_
from ..assertions import is_true
from ..config import requirements
from ..provision import normalize_sequence
from ..schema import Column
from ..schema import Table
from ... import inspect
from ... import Integer
from ... import MetaData
from ... import Sequence
from ... import String
from ... import testing
class SequenceTest(fixtures.TablesTest):
__requires__ = ("sequences",)
__backend__ = True
run_create_tables = "each"
@classmethod
def define_tables(cls, metadata):
Table(
"seq_pk",
metadata,
Column(
"id",
Integer,
normalize_sequence(config, Sequence("tab_id_seq")),
primary_key=True,
),
Column("data", String(50)),
)
Table(
"seq_opt_pk",
metadata,
Column(
"id",
Integer,
normalize_sequence(
config,
Sequence("tab_id_seq", data_type=Integer, optional=True),
),
primary_key=True,
),
Column("data", String(50)),
)
Table(
"seq_no_returning",
metadata,
Column(
"id",
Integer,
normalize_sequence(config, Sequence("noret_id_seq")),
primary_key=True,
),
Column("data", String(50)),
implicit_returning=False,
)
if testing.requires.schemas.enabled:
Table(
"seq_no_returning_sch",
metadata,
Column(
"id",
Integer,
normalize_sequence(
config,
Sequence(
"noret_sch_id_seq", schema=config.test_schema
),
),
primary_key=True,
),
Column("data", String(50)),
implicit_returning=False,
schema=config.test_schema,
)
def test_insert_roundtrip(self, connection):
connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
self._assert_round_trip(self.tables.seq_pk, connection)
def test_insert_lastrowid(self, connection):
r = connection.execute(
self.tables.seq_pk.insert(), dict(data="some data")
)
eq_(
r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
)
def test_nextval_direct(self, connection):
r = connection.scalar(self.tables.seq_pk.c.id.default)
eq_(r, testing.db.dialect.default_sequence_base)
@requirements.sequences_optional
def test_optional_seq(self, connection):
r = connection.execute(
self.tables.seq_opt_pk.insert(), dict(data="some data")
)
eq_(r.inserted_primary_key, (1,))
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
def test_insert_roundtrip_no_implicit_returning(self, connection):
connection.execute(
self.tables.seq_no_returning.insert(), dict(data="some data")
)
self._assert_round_trip(self.tables.seq_no_returning, connection)
@testing.combinations((True,), (False,), argnames="implicit_returning")
@testing.requires.schemas
def test_insert_roundtrip_translate(self, connection, implicit_returning):
seq_no_returning = Table(
"seq_no_returning_sch",
MetaData(),
Column(
"id",
Integer,
normalize_sequence(
config, Sequence("noret_sch_id_seq", schema="alt_schema")
),
primary_key=True,
),
Column("data", String(50)),
implicit_returning=implicit_returning,
schema="alt_schema",
)
connection = connection.execution_options(
schema_translate_map={"alt_schema": config.test_schema}
)
connection.execute(seq_no_returning.insert(), dict(data="some data"))
self._assert_round_trip(seq_no_returning, connection)
@testing.requires.schemas
def test_nextval_direct_schema_translate(self, connection):
seq = normalize_sequence(
config, Sequence("noret_sch_id_seq", schema="alt_schema")
)
connection = connection.execution_options(
schema_translate_map={"alt_schema": config.test_schema}
)
r = connection.scalar(seq)
eq_(r, testing.db.dialect.default_sequence_base)
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
__requires__ = ("sequences",)
__backend__ = True
def test_literal_binds_inline_compile(self, connection):
table = Table(
"x",
MetaData(),
Column(
"y", Integer, normalize_sequence(config, Sequence("y_seq"))
),
Column("q", Integer),
)
stmt = table.insert().values(q=5)
seq_nextval = connection.dialect.statement_compiler(
statement=None, dialect=connection.dialect
).visit_sequence(normalize_sequence(config, Sequence("y_seq")))
self.assert_compile(
stmt,
"INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
literal_binds=True,
dialect=connection.dialect,
)
class HasSequenceTest(fixtures.TablesTest):
run_deletes = None
__requires__ = ("sequences",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
normalize_sequence(config, Sequence("user_id_seq", metadata=metadata))
normalize_sequence(
config,
Sequence(
"other_seq",
metadata=metadata,
nomaxvalue=True,
nominvalue=True,
),
)
if testing.requires.schemas.enabled:
normalize_sequence(
config,
Sequence(
"user_id_seq", schema=config.test_schema, metadata=metadata
),
)
normalize_sequence(
config,
Sequence(
"schema_seq", schema=config.test_schema, metadata=metadata
),
)
Table(
"user_id_table",
metadata,
Column("id", Integer, primary_key=True),
)
def test_has_sequence(self, connection):
eq_(inspect(connection).has_sequence("user_id_seq"), True)
def test_has_sequence_cache(self, connection, metadata):
insp = inspect(connection)
eq_(insp.has_sequence("user_id_seq"), True)
ss = normalize_sequence(config, Sequence("new_seq", metadata=metadata))
eq_(insp.has_sequence("new_seq"), False)
ss.create(connection)
try:
eq_(insp.has_sequence("new_seq"), False)
insp.clear_cache()
eq_(insp.has_sequence("new_seq"), True)
finally:
ss.drop(connection)
def test_has_sequence_other_object(self, connection):
eq_(inspect(connection).has_sequence("user_id_table"), False)
@testing.requires.schemas
def test_has_sequence_schema(self, connection):
eq_(
inspect(connection).has_sequence(
"user_id_seq", schema=config.test_schema
),
True,
)
def test_has_sequence_neg(self, connection):
eq_(inspect(connection).has_sequence("some_sequence"), False)
@testing.requires.schemas
def test_has_sequence_schemas_neg(self, connection):
eq_(
inspect(connection).has_sequence(
"some_sequence", schema=config.test_schema
),
False,
)
@testing.requires.schemas
def test_has_sequence_default_not_in_remote(self, connection):
eq_(
inspect(connection).has_sequence(
"other_sequence", schema=config.test_schema
),
False,
)
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self, connection):
eq_(inspect(connection).has_sequence("schema_seq"), False)
def test_get_sequence_names(self, connection):
exp = {"other_seq", "user_id_seq"}
res = set(inspect(connection).get_sequence_names())
is_true(res.intersection(exp) == exp)
is_true("schema_seq" not in res)
@testing.requires.schemas
def test_get_sequence_names_no_sequence_schema(self, connection):
eq_(
inspect(connection).get_sequence_names(
schema=config.test_schema_2
),
[],
)
@testing.requires.schemas
def test_get_sequence_names_sequences_schema(self, connection):
eq_(
sorted(
inspect(connection).get_sequence_names(
schema=config.test_schema
)
),
["schema_seq", "user_id_seq"],
)
class HasSequenceTestEmpty(fixtures.TestBase):
__requires__ = ("sequences",)
__backend__ = True
def test_get_sequence_names_no_sequence(self, connection):
eq_(
inspect(connection).get_sequence_names(),
[],
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,189 @@
# testing/suite/test_unicode_ddl.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from sqlalchemy import desc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import testing
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
class UnicodeSchemaTest(fixtures.TablesTest):
__requires__ = ("unicode_ddl",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
global t1, t2, t3
t1 = Table(
"unitable1",
metadata,
Column("méil", Integer, primary_key=True),
Column("\u6e2c\u8a66", Integer),
test_needs_fk=True,
)
t2 = Table(
"Unitéble2",
metadata,
Column("méil", Integer, primary_key=True, key="a"),
Column(
"\u6e2c\u8a66",
Integer,
ForeignKey("unitable1.méil"),
key="b",
),
test_needs_fk=True,
)
# Few DBs support Unicode foreign keys
if testing.against("sqlite"):
t3 = Table(
"\u6e2c\u8a66",
metadata,
Column(
"\u6e2c\u8a66_id",
Integer,
primary_key=True,
autoincrement=False,
),
Column(
"unitable1_\u6e2c\u8a66",
Integer,
ForeignKey("unitable1.\u6e2c\u8a66"),
),
Column("Unitéble2_b", Integer, ForeignKey("Unitéble2.b")),
Column(
"\u6e2c\u8a66_self",
Integer,
ForeignKey("\u6e2c\u8a66.\u6e2c\u8a66_id"),
),
test_needs_fk=True,
)
else:
t3 = Table(
"\u6e2c\u8a66",
metadata,
Column(
"\u6e2c\u8a66_id",
Integer,
primary_key=True,
autoincrement=False,
),
Column("unitable1_\u6e2c\u8a66", Integer),
Column("Unitéble2_b", Integer),
Column("\u6e2c\u8a66_self", Integer),
test_needs_fk=True,
)
def test_insert(self, connection):
connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
connection.execute(t2.insert(), {"a": 1, "b": 1})
connection.execute(
t3.insert(),
{
"\u6e2c\u8a66_id": 1,
"unitable1_\u6e2c\u8a66": 5,
"Unitéble2_b": 1,
"\u6e2c\u8a66_self": 1,
},
)
eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])
def test_col_targeting(self, connection):
connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
connection.execute(t2.insert(), {"a": 1, "b": 1})
connection.execute(
t3.insert(),
{
"\u6e2c\u8a66_id": 1,
"unitable1_\u6e2c\u8a66": 5,
"Unitéble2_b": 1,
"\u6e2c\u8a66_self": 1,
},
)
row = connection.execute(t1.select()).first()
eq_(row._mapping[t1.c["méil"]], 1)
eq_(row._mapping[t1.c["\u6e2c\u8a66"]], 5)
row = connection.execute(t2.select()).first()
eq_(row._mapping[t2.c["a"]], 1)
eq_(row._mapping[t2.c["b"]], 1)
row = connection.execute(t3.select()).first()
eq_(row._mapping[t3.c["\u6e2c\u8a66_id"]], 1)
eq_(row._mapping[t3.c["unitable1_\u6e2c\u8a66"]], 5)
eq_(row._mapping[t3.c["Unitéble2_b"]], 1)
eq_(row._mapping[t3.c["\u6e2c\u8a66_self"]], 1)
def test_reflect(self, connection):
connection.execute(t1.insert(), {"méil": 2, "\u6e2c\u8a66": 7})
connection.execute(t2.insert(), {"a": 2, "b": 2})
connection.execute(
t3.insert(),
{
"\u6e2c\u8a66_id": 2,
"unitable1_\u6e2c\u8a66": 7,
"Unitéble2_b": 2,
"\u6e2c\u8a66_self": 2,
},
)
meta = MetaData()
tt1 = Table(t1.name, meta, autoload_with=connection)
tt2 = Table(t2.name, meta, autoload_with=connection)
tt3 = Table(t3.name, meta, autoload_with=connection)
connection.execute(tt1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
connection.execute(tt2.insert(), {"méil": 1, "\u6e2c\u8a66": 1})
connection.execute(
tt3.insert(),
{
"\u6e2c\u8a66_id": 1,
"unitable1_\u6e2c\u8a66": 5,
"Unitéble2_b": 1,
"\u6e2c\u8a66_self": 1,
},
)
eq_(
connection.execute(tt1.select().order_by(desc("méil"))).fetchall(),
[(2, 7), (1, 5)],
)
eq_(
connection.execute(tt2.select().order_by(desc("méil"))).fetchall(),
[(2, 2), (1, 1)],
)
eq_(
connection.execute(
tt3.select().order_by(desc("\u6e2c\u8a66_id"))
).fetchall(),
[(2, 7, 2, 2), (1, 5, 1, 1)],
)
def test_repr(self):
meta = MetaData()
t = Table("\u6e2c\u8a66", meta, Column("\u6e2c\u8a66_id", Integer))
eq_(
repr(t),
(
"Table('測試', MetaData(), "
"Column('測試_id', Integer(), "
"table=<測試>), "
"schema=None)"
),
)

View File

@ -0,0 +1,139 @@
# testing/suite/test_update_delete.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import Integer
from ... import String
from ... import testing
class SimpleUpdateDeleteTest(fixtures.TablesTest):
run_deletes = "each"
__requires__ = ("sane_rowcount",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
"plain_pk",
metadata,
Column("id", Integer, primary_key=True),
Column("data", String(50)),
)
@classmethod
def insert_data(cls, connection):
connection.execute(
cls.tables.plain_pk.insert(),
[
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
],
)
def test_update(self, connection):
t = self.tables.plain_pk
r = connection.execute(
t.update().where(t.c.id == 2), dict(data="d2_new")
)
assert not r.is_insert
assert not r.returns_rows
assert r.rowcount == 1
eq_(
connection.execute(t.select().order_by(t.c.id)).fetchall(),
[(1, "d1"), (2, "d2_new"), (3, "d3")],
)
def test_delete(self, connection):
t = self.tables.plain_pk
r = connection.execute(t.delete().where(t.c.id == 2))
assert not r.is_insert
assert not r.returns_rows
assert r.rowcount == 1
eq_(
connection.execute(t.select().order_by(t.c.id)).fetchall(),
[(1, "d1"), (3, "d3")],
)
@testing.variation("criteria", ["rows", "norows", "emptyin"])
@testing.requires.update_returning
def test_update_returning(self, connection, criteria):
t = self.tables.plain_pk
stmt = t.update().returning(t.c.id, t.c.data)
if criteria.norows:
stmt = stmt.where(t.c.id == 10)
elif criteria.rows:
stmt = stmt.where(t.c.id == 2)
elif criteria.emptyin:
stmt = stmt.where(t.c.id.in_([]))
else:
criteria.fail()
r = connection.execute(stmt, dict(data="d2_new"))
assert not r.is_insert
assert r.returns_rows
eq_(r.keys(), ["id", "data"])
if criteria.rows:
eq_(r.all(), [(2, "d2_new")])
else:
eq_(r.all(), [])
eq_(
connection.execute(t.select().order_by(t.c.id)).fetchall(),
(
[(1, "d1"), (2, "d2_new"), (3, "d3")]
if criteria.rows
else [(1, "d1"), (2, "d2"), (3, "d3")]
),
)
@testing.variation("criteria", ["rows", "norows", "emptyin"])
@testing.requires.delete_returning
def test_delete_returning(self, connection, criteria):
t = self.tables.plain_pk
stmt = t.delete().returning(t.c.id, t.c.data)
if criteria.norows:
stmt = stmt.where(t.c.id == 10)
elif criteria.rows:
stmt = stmt.where(t.c.id == 2)
elif criteria.emptyin:
stmt = stmt.where(t.c.id.in_([]))
else:
criteria.fail()
r = connection.execute(stmt)
assert not r.is_insert
assert r.returns_rows
eq_(r.keys(), ["id", "data"])
if criteria.rows:
eq_(r.all(), [(2, "d2")])
else:
eq_(r.all(), [])
eq_(
connection.execute(t.select().order_by(t.c.id)).fetchall(),
(
[(1, "d1"), (3, "d3")]
if criteria.rows
else [(1, "d1"), (2, "d2"), (3, "d3")]
),
)
__all__ = ("SimpleUpdateDeleteTest",)

View File

@ -0,0 +1,538 @@
# testing/util.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from collections import deque
import contextlib
import decimal
import gc
from itertools import chain
import random
import sys
from sys import getsizeof
import time
import types
from typing import Any
from . import config
from . import mock
from .. import inspect
from ..engine import Connection
from ..schema import Column
from ..schema import DropConstraint
from ..schema import DropTable
from ..schema import ForeignKeyConstraint
from ..schema import MetaData
from ..schema import Table
from ..sql import schema
from ..sql.sqltypes import Integer
from ..util import decorator
from ..util import defaultdict
from ..util import has_refcount_gc
from ..util import inspect_getfullargspec
if not has_refcount_gc:
def non_refcount_gc_collect(*args):
gc.collect()
gc.collect()
gc_collect = lazy_gc = non_refcount_gc_collect
else:
# assume CPython - straight gc.collect, lazy_gc() is a pass
gc_collect = gc.collect
def lazy_gc():
pass
def picklers():
picklers = set()
import pickle
picklers.add(pickle)
# yes, this thing needs this much testing
for pickle_ in picklers:
for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1):
yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
def random_choices(population, k=1):
return random.choices(population, k=k)
def round_decimal(value, prec):
if isinstance(value, float):
return round(value, prec)
# can also use shift() here but that is 2.6 only
return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
decimal.ROUND_FLOOR
) / pow(10, prec)
class RandomSet(set):
def __iter__(self):
l = list(set.__iter__(self))
random.shuffle(l)
return iter(l)
def pop(self):
index = random.randint(0, len(self) - 1)
item = list(set.__iter__(self))[index]
self.remove(item)
return item
def union(self, other):
return RandomSet(set.union(self, other))
def difference(self, other):
return RandomSet(set.difference(self, other))
def intersection(self, other):
return RandomSet(set.intersection(self, other))
def copy(self):
return RandomSet(self)
def conforms_partial_ordering(tuples, sorted_elements):
"""True if the given sorting conforms to the given partial ordering."""
deps = defaultdict(set)
for parent, child in tuples:
deps[parent].add(child)
for i, node in enumerate(sorted_elements):
for n in sorted_elements[i:]:
if node in deps[n]:
return False
else:
return True
def all_partial_orderings(tuples, elements):
edges = defaultdict(set)
for parent, child in tuples:
edges[child].add(parent)
def _all_orderings(elements):
if len(elements) == 1:
yield list(elements)
else:
for elem in elements:
subset = set(elements).difference([elem])
if not subset.intersection(edges[elem]):
for sub_ordering in _all_orderings(subset):
yield [elem] + sub_ordering
return iter(_all_orderings(elements))
def function_named(fn, name):
"""Return a function with a given __name__.
Will assign to __name__ and return the original function if possible on
the Python implementation, otherwise a new function will be constructed.
This function should be phased out as much as possible
in favor of @decorator. Tests that "generate" many named tests
should be modernized.
"""
try:
fn.__name__ = name
except TypeError:
fn = types.FunctionType(
fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
)
return fn
def run_as_contextmanager(ctx, fn, *arg, **kw):
"""Run the given function under the given contextmanager,
simulating the behavior of 'with' to support older
Python versions.
This is not necessary anymore as we have placed 2.6
as minimum Python version, however some tests are still using
this structure.
"""
obj = ctx.__enter__()
try:
result = fn(obj, *arg, **kw)
ctx.__exit__(None, None, None)
return result
except:
exc_info = sys.exc_info()
raise_ = ctx.__exit__(*exc_info)
if not raise_:
raise
else:
return raise_
def rowset(results):
"""Converts the results of sql execution into a plain set of column tuples.
Useful for asserting the results of an unordered query.
"""
return {tuple(row) for row in results}
def fail(msg):
assert False, msg
@decorator
def provide_metadata(fn, *args, **kw):
"""Provide bound MetaData for a single test, dropping afterwards.
Legacy; use the "metadata" pytest fixture.
"""
from . import fixtures
metadata = schema.MetaData()
self = args[0]
prev_meta = getattr(self, "metadata", None)
self.metadata = metadata
try:
return fn(*args, **kw)
finally:
# close out some things that get in the way of dropping tables.
# when using the "metadata" fixture, there is a set ordering
# of things that makes sure things are cleaned up in order, however
# the simple "decorator" nature of this legacy function means
# we have to hardcode some of that cleanup ahead of time.
# close ORM sessions
fixtures.close_all_sessions()
# integrate with the "connection" fixture as there are many
# tests where it is used along with provide_metadata
cfc = fixtures.base._connection_fixture_connection
if cfc:
# TODO: this warning can be used to find all the places
# this is used with connection fixture
# warn("mixing legacy provide metadata with connection fixture")
drop_all_tables_from_metadata(metadata, cfc)
# as the provide_metadata fixture is often used with "testing.db",
# when we do the drop we have to commit the transaction so that
# the DB is actually updated as the CREATE would have been
# committed
cfc.get_transaction().commit()
else:
drop_all_tables_from_metadata(metadata, config.db)
self.metadata = prev_meta
def flag_combinations(*combinations):
"""A facade around @testing.combinations() oriented towards boolean
keyword-based arguments.
Basically generates a nice looking identifier based on the keywords
and also sets up the argument names.
E.g.::
@testing.flag_combinations(
dict(lazy=False, passive=False),
dict(lazy=True, passive=False),
dict(lazy=False, passive=True),
dict(lazy=False, passive=True, raiseload=True),
)
def test_fn(lazy, passive, raiseload): ...
would result in::
@testing.combinations(
("", False, False, False),
("lazy", True, False, False),
("lazy_passive", True, True, False),
("lazy_passive", True, True, True),
id_="iaaa",
argnames="lazy,passive,raiseload",
)
def test_fn(lazy, passive, raiseload): ...
"""
keys = set()
for d in combinations:
keys.update(d)
keys = sorted(keys)
return config.combinations(
*[
("_".join(k for k in keys if d.get(k, False)),)
+ tuple(d.get(k, False) for k in keys)
for d in combinations
],
id_="i" + ("a" * len(keys)),
argnames=",".join(keys),
)
def lambda_combinations(lambda_arg_sets, **kw):
args = inspect_getfullargspec(lambda_arg_sets)
arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]])
def create_fixture(pos):
def fixture(**kw):
return lambda_arg_sets(**kw)[pos]
fixture.__name__ = "fixture_%3.3d" % pos
return fixture
return config.combinations(
*[(create_fixture(i),) for i in range(len(arg_sets))], **kw
)
def resolve_lambda(__fn, **kw):
"""Given a no-arg lambda and a namespace, return a new lambda that
has all the values filled in.
This is used so that we can have module-level fixtures that
refer to instance-level variables using lambdas.
"""
pos_args = inspect_getfullargspec(__fn)[0]
pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
glb = dict(__fn.__globals__)
glb.update(kw)
new_fn = types.FunctionType(__fn.__code__, glb)
return new_fn(**pass_pos_args)
def metadata_fixture(ddl="function"):
"""Provide MetaData for a pytest fixture."""
def decorate(fn):
def run_ddl(self):
metadata = self.metadata = schema.MetaData()
try:
result = fn(self, metadata)
metadata.create_all(config.db)
# TODO:
# somehow get a per-function dml erase fixture here
yield result
finally:
metadata.drop_all(config.db)
return config.fixture(scope=ddl)(run_ddl)
return decorate
def force_drop_names(*names):
"""Force the given table names to be dropped after test complete,
isolating for foreign key cycles
"""
@decorator
def go(fn, *args, **kw):
try:
return fn(*args, **kw)
finally:
drop_all_tables(config.db, inspect(config.db), include_names=names)
return go
class adict(dict):
"""Dict keys available as attributes. Shadows."""
def __getattribute__(self, key):
try:
return self[key]
except KeyError:
return dict.__getattribute__(self, key)
def __call__(self, *keys):
return tuple([self[key] for key in keys])
get_all = __call__
def drop_all_tables_from_metadata(metadata, engine_or_connection):
from . import engines
def go(connection):
engines.testing_reaper.prepare_for_drop_tables(connection)
if not connection.dialect.supports_alter:
from . import assertions
with assertions.expect_warnings(
"Can't sort tables", assert_=False
):
metadata.drop_all(connection)
else:
metadata.drop_all(connection)
if not isinstance(engine_or_connection, Connection):
with engine_or_connection.begin() as connection:
go(connection)
else:
go(engine_or_connection)
def drop_all_tables(
engine,
inspector,
schema=None,
consider_schemas=(None,),
include_names=None,
):
if include_names is not None:
include_names = set(include_names)
if schema is not None:
assert consider_schemas == (
None,
), "consider_schemas and schema are mutually exclusive"
consider_schemas = (schema,)
with engine.begin() as conn:
for table_key, fkcs in reversed(
inspector.sort_tables_on_foreign_key_dependency(
consider_schemas=consider_schemas
)
):
if table_key:
if (
include_names is not None
and table_key[1] not in include_names
):
continue
conn.execute(
DropTable(
Table(table_key[1], MetaData(), schema=table_key[0])
)
)
elif fkcs:
if not engine.dialect.supports_alter:
continue
for t_key, fkc in fkcs:
if (
include_names is not None
and t_key[1] not in include_names
):
continue
tb = Table(
t_key[1],
MetaData(),
Column("x", Integer),
Column("y", Integer),
schema=t_key[0],
)
conn.execute(
DropConstraint(
ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
)
)
def teardown_events(event_cls):
@decorator
def decorate(fn, *arg, **kw):
try:
return fn(*arg, **kw)
finally:
event_cls._clear()
return decorate
def total_size(o):
"""Returns the approximate memory footprint an object and all of its
contents.
source: https://code.activestate.com/recipes/577504/
"""
def dict_handler(d):
return chain.from_iterable(d.items())
all_handlers = {
tuple: iter,
list: iter,
deque: iter,
dict: dict_handler,
set: iter,
frozenset: iter,
}
seen = set() # track which object id's have already been seen
default_size = getsizeof(0) # estimate sizeof object without __sizeof__
def sizeof(o):
if id(o) in seen: # do not double count the same object
return 0
seen.add(id(o))
s = getsizeof(o, default_size)
for typ, handler in all_handlers.items():
if isinstance(o, typ):
s += sum(map(sizeof, handler(o)))
break
return s
return sizeof(o)
def count_cache_key_tuples(tup):
"""given a cache key tuple, counts how many instances of actual
tuples are found.
used to alert large jumps in cache key complexity.
"""
stack = [tup]
sentinel = object()
num_elements = 0
while stack:
elem = stack.pop(0)
if elem is sentinel:
num_elements += 1
elif isinstance(elem, tuple):
if elem:
stack = list(elem) + [sentinel] + stack
return num_elements
@contextlib.contextmanager
def skip_if_timeout(seconds: float, cleanup: Any = None):
now = time.time()
yield
sec = time.time() - now
if sec > seconds:
try:
cleanup()
finally:
config.skip_test(
f"test took too long ({sec:.4f} seconds > {seconds})"
)

View File

@ -0,0 +1,52 @@
# testing/warnings.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import warnings
from . import assertions
from .. import exc
from .. import exc as sa_exc
from ..exc import SATestSuiteWarning
from ..util.langhelpers import _warnings_warn
def warn_test_suite(message):
_warnings_warn(message, category=SATestSuiteWarning)
def setup_filters():
"""hook for setting up warnings filters.
SQLAlchemy-specific classes must only be here and not in pytest config,
as we need to delay importing SQLAlchemy until conftest.py has been
processed.
NOTE: filters on subclasses of DeprecationWarning or
PendingDeprecationWarning have no effect if added here, since pytest
will add at each test the following filters
``always::PendingDeprecationWarning`` and ``always::DeprecationWarning``
that will take precedence over any added here.
"""
warnings.filterwarnings("error", category=exc.SAWarning)
warnings.filterwarnings("always", category=exc.SATestSuiteWarning)
def assert_warnings(fn, warning_msgs, regex=False):
"""Assert that each of the given warnings are emitted by fn.
Deprecated. Please use assertions.expect_warnings().
"""
with assertions._expect_warnings(
sa_exc.SAWarning, warning_msgs, regex=regex
):
return fn()