Update 2025-04-13_16:25:39

This commit is contained in:
root
2025-04-13 16:25:41 +02:00
commit 4c711360d3
2979 changed files with 666585 additions and 0 deletions

View File

@ -0,0 +1,6 @@
# testing/plugin/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

View File

@ -0,0 +1,51 @@
# testing/plugin/bootstrap.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
Bootstrapper for test framework plugins.
The entire rationale for this system is to get the modules in plugin/
imported without importing all of the supporting library, so that we can
set up things for testing before coverage starts.
The rationale for all of plugin/ being *in* the supporting library in the
first place is so that the testing and plugin suite is available to other
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
of the same test environment and standard suites available to
SQLAlchemy/Alembic themselves without the need to ship/install a separate
package outside of SQLAlchemy.
"""
import importlib.util
import os
import sys
bootstrap_file = locals()["bootstrap_file"]
to_bootstrap = locals()["to_bootstrap"]
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
spec = importlib.util.spec_from_file_location(name, path)
assert spec is not None
assert spec.loader is not None
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa

View File

@ -0,0 +1,779 @@
# testing/plugin/plugin_base.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import abc
from argparse import Namespace
import configparser
import logging
import os
from pathlib import Path
import re
import sys
from typing import Any
from sqlalchemy.testing import asyncio
"""Testing extensions.
this module is designed to work as a testing-framework-agnostic library,
created so that multiple test frameworks can be supported at once
(mostly so that we can migrate to new ones). The current target
is pytest.
"""
# flag which indicates we are in the SQLAlchemy testing suite,
# and not that of Alembic or a third party dialect.
bootstrapped_as_sqlalchemy = False
log = logging.getLogger("sqlalchemy.testing.plugin_base")
# late imports
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
provision = None
assertions = None
requirements = None
config = None
testing = None
util = None
file_config = None
logging = None
include_tags = set()
exclude_tags = set()
options: Namespace = None # type: ignore
def setup_options(make_option):
make_option(
"--log-info",
action="callback",
type=str,
callback=_log,
help="turn on info logging for <LOG> (multiple OK)",
)
make_option(
"--log-debug",
action="callback",
type=str,
callback=_log,
help="turn on debug logging for <LOG> (multiple OK)",
)
make_option(
"--db",
action="append",
type=str,
dest="db",
help="Use prefab database uri. Multiple OK, "
"first one is run by default.",
)
make_option(
"--dbs",
action="callback",
zeroarg_callback=_list_dbs,
help="List available prefab dbs",
)
make_option(
"--dburi",
action="append",
type=str,
dest="dburi",
help="Database uri. Multiple OK, first one is run by default.",
)
make_option(
"--dbdriver",
action="append",
type=str,
dest="dbdriver",
help="Additional database drivers to include in tests. "
"These are linked to the existing database URLs by the "
"provisioning system.",
)
make_option(
"--dropfirst",
action="store_true",
dest="dropfirst",
help="Drop all tables in the target database first",
)
make_option(
"--disable-asyncio",
action="store_true",
help="disable test / fixtures / provisoning running in asyncio",
)
make_option(
"--backend-only",
action="callback",
zeroarg_callback=_set_tag_include("backend"),
help=(
"Run only tests marked with __backend__ or __sparse_backend__; "
"this is now equivalent to the pytest -m backend mark expression"
),
)
make_option(
"--nomemory",
action="callback",
zeroarg_callback=_set_tag_exclude("memory_intensive"),
help="Don't run memory profiling tests; "
"this is now equivalent to the pytest -m 'not memory_intensive' "
"mark expression",
)
make_option(
"--notimingintensive",
action="callback",
zeroarg_callback=_set_tag_exclude("timing_intensive"),
help="Don't run timing intensive tests; "
"this is now equivalent to the pytest -m 'not timing_intensive' "
"mark expression",
)
make_option(
"--nomypy",
action="callback",
zeroarg_callback=_set_tag_exclude("mypy"),
help="Don't run mypy typing tests; "
"this is now equivalent to the pytest -m 'not mypy' mark expression",
)
make_option(
"--profile-sort",
type=str,
default="cumulative",
dest="profilesort",
help="Type of sort for profiling standard output",
)
make_option(
"--profile-dump",
type=str,
dest="profiledump",
help="Filename where a single profile run will be dumped",
)
make_option(
"--low-connections",
action="store_true",
dest="low_connections",
help="Use a low number of distinct connections - "
"i.e. for Oracle TNS",
)
make_option(
"--write-idents",
type=str,
dest="write_idents",
help="write out generated follower idents to <file>, "
"when -n<num> is used",
)
make_option(
"--requirements",
action="callback",
type=str,
callback=_requirements_opt,
help="requirements class for testing, overrides setup.cfg",
)
make_option(
"--include-tag",
action="callback",
callback=_include_tag,
type=str,
help="Include tests with tag <tag>; "
"legacy, use pytest -m 'tag' instead",
)
make_option(
"--exclude-tag",
action="callback",
callback=_exclude_tag,
type=str,
help="Exclude tests with tag <tag>; "
"legacy, use pytest -m 'not tag' instead",
)
make_option(
"--write-profiles",
action="store_true",
dest="write_profiles",
default=False,
help="Write/update failing profiling data.",
)
make_option(
"--force-write-profiles",
action="store_true",
dest="force_write_profiles",
default=False,
help="Unconditionally write/update profiling data.",
)
make_option(
"--dump-pyannotate",
type=str,
dest="dump_pyannotate",
help="Run pyannotate and dump json info to given file",
)
make_option(
"--mypy-extra-test-path",
type=str,
action="append",
default=[],
dest="mypy_extra_test_paths",
help="Additional test directories to add to the mypy tests. "
"This is used only when running mypy tests. Multiple OK",
)
# db specific options
make_option(
"--postgresql-templatedb",
type=str,
help="name of template database to use for PostgreSQL "
"CREATE DATABASE (defaults to current database)",
)
make_option(
"--oracledb-thick-mode",
action="store_true",
help="enables the 'thick mode' when testing with oracle+oracledb",
)
def configure_follower(follower_ident):
"""Configure required state for a follower.
This invokes in the parent process and typically includes
database creation.
"""
from sqlalchemy.testing import provision
provision.FOLLOWER_IDENT = follower_ident
def memoize_important_follower_config(dict_):
"""Store important configuration we will need to send to a follower.
This invokes in the parent process after normal config is set up.
Hook is currently not used.
"""
def restore_important_follower_config(dict_):
"""Restore important configuration needed by a follower.
This invokes in the follower process.
Hook is currently not used.
"""
def read_config(root_path):
global file_config
file_config = configparser.ConfigParser()
file_config.read(
[str(root_path / "setup.cfg"), str(root_path / "test.cfg")]
)
def pre_begin(opt):
"""things to set up early, before coverage might be setup."""
global options
options = opt
for fn in pre_configure:
fn(options, file_config)
def set_coverage_flag(value):
options.has_coverage = value
def post_begin():
"""things to set up later, once we know coverage is running."""
# Lazy setup of other options (post coverage)
for fn in post_configure:
fn(options, file_config)
# late imports, has to happen after config.
global util, fixtures, engines, exclusions, assertions, provision
global warnings, profiling, config, testing
from sqlalchemy import testing # noqa
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
from sqlalchemy.testing import assertions, warnings, profiling # noqa
from sqlalchemy.testing import config, provision # noqa
from sqlalchemy import util # noqa
warnings.setup_filters()
def _log(opt_str, value, parser):
global logging
if not logging:
import logging
logging.basicConfig()
if opt_str.endswith("-info"):
logging.getLogger(value).setLevel(logging.INFO)
elif opt_str.endswith("-debug"):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
if file_config is None:
# assume the current working directory is the one containing the
# setup file
read_config(Path.cwd())
print("Available --db options (use --dburi to override)")
for macro in sorted(file_config.options("db")):
print("%20s\t%s" % (macro, file_config.get("db", macro)))
sys.exit(0)
def _requirements_opt(opt_str, value, parser):
_setup_requirements(value)
def _set_tag_include(tag):
def _do_include_tag(opt_str, value, parser):
_include_tag(opt_str, tag, parser)
return _do_include_tag
def _set_tag_exclude(tag):
def _do_exclude_tag(opt_str, value, parser):
_exclude_tag(opt_str, tag, parser)
return _do_exclude_tag
def _exclude_tag(opt_str, value, parser):
exclude_tags.add(value.replace("-", "_"))
def _include_tag(opt_str, value, parser):
include_tags.add(value.replace("-", "_"))
pre_configure = []
post_configure = []
def pre(fn):
pre_configure.append(fn)
return fn
def post(fn):
post_configure.append(fn)
return fn
@pre
def _setup_options(opt, file_config):
global options
options = opt
@pre
def _register_sqlite_numeric_dialect(opt, file_config):
from sqlalchemy.dialects import registry
registry.register(
"sqlite.pysqlite_numeric",
"sqlalchemy.dialects.sqlite.pysqlite",
"_SQLiteDialect_pysqlite_numeric",
)
registry.register(
"sqlite.pysqlite_dollar",
"sqlalchemy.dialects.sqlite.pysqlite",
"_SQLiteDialect_pysqlite_dollar",
)
@post
def __ensure_cext(opt, file_config):
if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":
from sqlalchemy.util import has_compiled_ext
try:
has_compiled_ext(raise_=True)
except ImportError as err:
raise AssertionError(
"REQUIRE_SQLALCHEMY_CEXT is set but can't import the "
"cython extensions"
) from err
@post
def _init_symbols(options, file_config):
from sqlalchemy.testing import config
config._fixture_functions = _fixture_fn_class()
@pre
def _set_disable_asyncio(opt, file_config):
if opt.disable_asyncio:
asyncio.ENABLE_ASYNCIO = False
@post
def _engine_uri(options, file_config):
from sqlalchemy import testing
from sqlalchemy.testing import config
from sqlalchemy.testing import provision
from sqlalchemy.engine import url as sa_url
if options.dburi:
db_urls = list(options.dburi)
else:
db_urls = []
extra_drivers = options.dbdriver or []
if options.db:
for db_token in options.db:
for db in re.split(r"[,\s]+", db_token):
if db not in file_config.options("db"):
raise RuntimeError(
"Unknown URI specifier '%s'. "
"Specify --dbs for known uris." % db
)
else:
db_urls.append(file_config.get("db", db))
if not db_urls:
db_urls.append(file_config.get("db", "default"))
config._current = None
if options.write_idents and provision.FOLLOWER_IDENT:
for db_url in [sa_url.make_url(db_url) for db_url in db_urls]:
with open(options.write_idents, "a") as file_:
file_.write(
f"{provision.FOLLOWER_IDENT} "
f"{db_url.render_as_string(hide_password=False)}\n"
)
expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
for db_url in expanded_urls:
log.info("Adding database URL: %s", db_url)
cfg = provision.setup_config(
db_url, options, file_config, provision.FOLLOWER_IDENT
)
if not config._current:
cfg.set_as_current(cfg, testing)
@post
def _requirements(options, file_config):
requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
def _setup_requirements(argument):
from sqlalchemy.testing import config
from sqlalchemy import testing
modname, clsname = argument.split(":")
# importlib.import_module() only introduced in 2.7, a little
# late
mod = __import__(modname)
for component in modname.split(".")[1:]:
mod = getattr(mod, component)
req_cls = getattr(mod, clsname)
config.requirements = testing.requires = req_cls()
config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
@post
def _prep_testing_database(options, file_config):
from sqlalchemy.testing import config
if options.dropfirst:
from sqlalchemy.testing import provision
for cfg in config.Config.all_configs():
provision.drop_all_schema_objects(cfg, cfg.db)
@post
def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
config.options = options
config.file_config = file_config
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
profiling._profile_stats = profiling.ProfileStatsFile(
file_config.get("sqla_testing", "profile_file"),
sort=options.profilesort,
dump=options.profiledump,
)
def want_class(name, cls):
if not issubclass(cls, fixtures.TestBase):
return False
elif name.startswith("_"):
return False
else:
return True
def want_method(cls, fn):
if not fn.__name__.startswith("test_"):
return False
elif fn.__module__ is None:
return False
else:
return True
def generate_sub_tests(cls, module, markers):
if "backend" in markers or "sparse_backend" in markers:
sparse = "sparse_backend" in markers
for cfg in _possible_configs_for_cls(cls, sparse=sparse):
orig_name = cls.__name__
# we can have special chars in these names except for the
# pytest junit plugin, which is tripped up by the brackets
# and periods, so sanitize
alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
alpha_name = re.sub(r"_+$", "", alpha_name)
name = "%s_%s" % (cls.__name__, alpha_name)
subcls = type(
name,
(cls,),
{"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
)
setattr(module, name, subcls)
yield subcls
else:
yield cls
def start_test_class_outside_fixtures(cls):
_do_skips(cls)
_setup_engine(cls)
def stop_test_class(cls):
# close sessions, immediate connections, etc.
fixtures.stop_test_class_inside_fixtures(cls)
# close outstanding connection pool connections, dispose of
# additional engines
engines.testing_reaper.stop_test_class_inside_fixtures()
def stop_test_class_outside_fixtures(cls):
engines.testing_reaper.stop_test_class_outside_fixtures()
provision.stop_test_class_outside_fixtures(config, config.db, cls)
try:
if not options.low_connections:
assertions.global_cleanup_assertions()
finally:
_restore_engine()
def _restore_engine():
if config._current:
config._current.reset(testing)
def final_process_cleanup():
engines.testing_reaper.final_cleanup()
assertions.global_cleanup_assertions()
_restore_engine()
def _setup_engine(cls):
if getattr(cls, "__engine_options__", None):
opts = dict(cls.__engine_options__)
opts["scope"] = "class"
eng = engines.testing_engine(options=opts)
config._current.push_engine(eng, testing)
def before_test(test, test_module_name, test_class, test_name):
# format looks like:
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
profiling._start_current_test(id_)
def after_test(test):
fixtures.after_test()
engines.testing_reaper.after_test()
def after_test_fixtures(test):
engines.testing_reaper.after_test_outside_fixtures(test)
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
all_configs = set(config.Config.all_configs())
if cls.__unsupported_on__:
spec = exclusions.db_spec(*cls.__unsupported_on__)
for config_obj in list(all_configs):
if spec(config_obj):
all_configs.remove(config_obj)
if getattr(cls, "__only_on__", None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
if getattr(cls, "__only_on_config__", None):
all_configs.intersection_update([cls.__only_on_config__])
if hasattr(cls, "__requires__"):
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__requires__:
check = getattr(requirements, requirement)
skip_reasons = check.matching_config_reasons(config_obj)
if skip_reasons:
all_configs.remove(config_obj)
if reasons is not None:
reasons.extend(skip_reasons)
break
if hasattr(cls, "__prefer_requires__"):
non_preferred = set()
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__prefer_requires__:
check = getattr(requirements, requirement)
if not check.enabled_for_config(config_obj):
non_preferred.add(config_obj)
if all_configs.difference(non_preferred):
all_configs.difference_update(non_preferred)
if sparse:
# pick only one config from each base dialect
# sorted so we get the same backend each time selecting the highest
# server version info.
per_dialect = {}
for cfg in reversed(
sorted(
all_configs,
key=lambda cfg: (
cfg.db.name,
cfg.db.driver,
cfg.db.dialect.server_version_info,
),
)
):
db = cfg.db.name
if db not in per_dialect:
per_dialect[db] = cfg
return per_dialect.values()
return all_configs
def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
if getattr(cls, "__skip_if__", False):
for c in getattr(cls, "__skip_if__"):
if c():
config.skip_test(
"'%s' skipped by %s" % (cls.__name__, c.__name__)
)
if not all_configs:
msg = "'%s.%s' unsupported on any DB implementation %s%s" % (
cls.__module__,
cls.__name__,
", ".join(
"'%s(%s)+%s'"
% (
config_obj.db.name,
".".join(
str(dig)
for dig in exclusions._server_version(config_obj.db)
),
config_obj.db.driver,
)
for config_obj in config.Config.all_configs()
),
", ".join(reasons),
)
config.skip_test(msg)
elif hasattr(cls, "__prefer_backends__"):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
for config_obj in all_configs:
if not spec(config_obj):
non_preferred.add(config_obj)
if all_configs.difference(non_preferred):
all_configs.difference_update(non_preferred)
if config._current not in all_configs:
_setup_config(all_configs.pop(), cls)
def _setup_config(config_obj, ctx):
config._current.push(config_obj, testing)
class FixtureFunctions(abc.ABC):
@abc.abstractmethod
def skip_test_exception(self, *arg, **kw):
raise NotImplementedError()
@abc.abstractmethod
def combinations(self, *args, **kw):
raise NotImplementedError()
@abc.abstractmethod
def param_ident(self, *args, **kw):
raise NotImplementedError()
@abc.abstractmethod
def fixture(self, *arg, **kw):
raise NotImplementedError()
def get_current_test_name(self):
raise NotImplementedError()
@abc.abstractmethod
def mark_base_test_class(self) -> Any:
raise NotImplementedError()
@abc.abstractproperty
def add_to_marker(self):
raise NotImplementedError()
_fixture_fn_class = None
def set_fixture_functions(fixture_fn_class):
global _fixture_fn_class
_fixture_fn_class = fixture_fn_class

View File

@ -0,0 +1,868 @@
# testing/plugin/pytestplugin.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import argparse
import collections
from functools import update_wrapper
import inspect
import itertools
import operator
import os
import re
import sys
from typing import TYPE_CHECKING
import uuid
import pytest
try:
# installed by bootstrap.py
if not TYPE_CHECKING:
import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
def pytest_addoption(parser):
group = parser.getgroup("sqlalchemy")
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
class CallableAction(argparse.Action):
def __call__(
self, parser, namespace, values, option_string=None
):
callback_(option_string, values, parser)
kw["action"] = CallableAction
zeroarg_callback = kw.pop("zeroarg_callback", None)
if zeroarg_callback:
class CallableAction(argparse.Action):
def __init__(
self,
option_strings,
dest,
default=False,
required=False,
help=None, # noqa
):
super().__init__(
option_strings=option_strings,
dest=dest,
nargs=0,
const=True,
default=default,
required=required,
help=help,
)
def __call__(
self, parser, namespace, values, option_string=None
):
zeroarg_callback(option_string, values, parser)
kw["action"] = CallableAction
group.addoption(name, **kw)
plugin_base.setup_options(make_option)
def pytest_configure(config: pytest.Config):
plugin_base.read_config(config.rootpath)
if plugin_base.exclude_tags or plugin_base.include_tags:
new_expr = " and ".join(
list(plugin_base.include_tags)
+ [f"not {tag}" for tag in plugin_base.exclude_tags]
)
if config.option.markexpr:
config.option.markexpr += f" and {new_expr}"
else:
config.option.markexpr = new_expr
if config.pluginmanager.hasplugin("xdist"):
config.pluginmanager.register(XDistHooks())
if hasattr(config, "workerinput"):
plugin_base.restore_important_follower_config(config.workerinput)
plugin_base.configure_follower(config.workerinput["follower_ident"])
else:
if config.option.write_idents and os.path.exists(
config.option.write_idents
):
os.remove(config.option.write_idents)
plugin_base.pre_begin(config.option)
plugin_base.set_coverage_flag(
bool(getattr(config.option, "cov_source", False))
)
plugin_base.set_fixture_functions(PytestFixtureFunctions)
if config.option.dump_pyannotate:
global DUMP_PYANNOTATE
DUMP_PYANNOTATE = True
DUMP_PYANNOTATE = False
@pytest.fixture(autouse=True)
def collect_types_fixture():
if DUMP_PYANNOTATE:
from pyannotate_runtime import collect_types
collect_types.start()
yield
if DUMP_PYANNOTATE:
collect_types.stop()
def _log_sqlalchemy_info(session):
import sqlalchemy
from sqlalchemy import __version__
from sqlalchemy.util import has_compiled_ext
from sqlalchemy.util._has_cy import _CYEXTENSION_MSG
greet = "sqlalchemy installation"
site = "no user site" if sys.flags.no_user_site else "user site loaded"
msgs = [
f"SQLAlchemy {__version__} ({site})",
f"Path: {sqlalchemy.__file__}",
]
if has_compiled_ext():
from sqlalchemy.cyextension import util
msgs.append(f"compiled extension enabled, e.g. {util.__file__} ")
else:
msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}")
pm = session.config.pluginmanager.get_plugin("terminalreporter")
if pm:
pm.write_sep("=", greet)
for m in msgs:
pm.write_line(m)
else:
# fancy pants reporter not found, fallback to plain print
print("=" * 25, greet, "=" * 25)
for m in msgs:
print(m)
def pytest_sessionstart(session):
from sqlalchemy.testing import asyncio
_log_sqlalchemy_info(session)
asyncio._assume_async(plugin_base.post_begin)
def pytest_sessionfinish(session):
from sqlalchemy.testing import asyncio
asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
if session.config.option.dump_pyannotate:
from pyannotate_runtime import collect_types
collect_types.dump_stats(session.config.option.dump_pyannotate)
def pytest_unconfigure(config):
from sqlalchemy.testing import asyncio
asyncio._shutdown()
def pytest_collection_finish(session):
if session.config.option.dump_pyannotate:
from pyannotate_runtime import collect_types
lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
def _filter(filename):
filename = os.path.normpath(os.path.abspath(filename))
if "lib/sqlalchemy" not in os.path.commonpath(
[filename, lib_sqlalchemy]
):
return None
if "testing" in filename:
return None
return filename
collect_types.init_types_collection(filter_filename=_filter)
class XDistHooks:
def pytest_configure_node(self, node):
from sqlalchemy.testing import provision
from sqlalchemy.testing import asyncio
# the master for each node fills workerinput dictionary
# which pytest-xdist will transfer to the subprocess
plugin_base.memoize_important_follower_config(node.workerinput)
node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
asyncio._maybe_async_provisioning(
provision.create_follower_db, node.workerinput["follower_ident"]
)
def pytest_testnodedown(self, node, error):
from sqlalchemy.testing import provision
from sqlalchemy.testing import asyncio
asyncio._maybe_async_provisioning(
provision.drop_follower_db, node.workerinput["follower_ident"]
)
def pytest_collection_modifyitems(session, config, items):
# look for all those classes that specify __backend__ and
# expand them out into per-database test cases.
# this is much easier to do within pytest_pycollect_makeitem, however
# pytest is iterating through cls.__dict__ as makeitem is
# called which causes a "dictionary changed size" error on py3k.
# I'd submit a pullreq for them to turn it into a list first, but
# it's to suit the rather odd use case here which is that we are adding
# new classes to a module on the fly.
from sqlalchemy.testing import asyncio
rebuilt_items = collections.defaultdict(
lambda: collections.defaultdict(list)
)
items[:] = [
item
for item in items
if item.getparent(pytest.Class) is not None
and not item.getparent(pytest.Class).name.startswith("_")
]
test_classes = {item.getparent(pytest.Class) for item in items}
def collect(element):
for inst_or_fn in element.collect():
if isinstance(inst_or_fn, pytest.Collector):
yield from collect(inst_or_fn)
else:
yield inst_or_fn
def setup_test_classes():
for test_class in test_classes:
# transfer legacy __backend__ and __sparse_backend__ symbols
# to be markers
add_markers = set()
if getattr(test_class.cls, "__backend__", False) or getattr(
test_class.cls, "__only_on__", False
):
add_markers = {"backend"}
elif getattr(test_class.cls, "__sparse_backend__", False):
add_markers = {"sparse_backend"}
else:
add_markers = frozenset()
existing_markers = {
mark.name for mark in test_class.iter_markers()
}
add_markers = add_markers - existing_markers
all_markers = existing_markers.union(add_markers)
for marker in add_markers:
test_class.add_marker(marker)
for sub_cls in plugin_base.generate_sub_tests(
test_class.cls, test_class.module, all_markers
):
if sub_cls is not test_class.cls:
per_cls_dict = rebuilt_items[test_class.cls]
module = test_class.getparent(pytest.Module)
new_cls = pytest.Class.from_parent(
name=sub_cls.__name__, parent=module
)
for marker in add_markers:
new_cls.add_marker(marker)
for fn in collect(new_cls):
per_cls_dict[fn.name].append(fn)
# class requirements will sometimes need to access the DB to check
# capabilities, so need to do this for async
asyncio._maybe_async_provisioning(setup_test_classes)
newitems = []
for item in items:
cls_ = item.cls
if cls_ in rebuilt_items:
newitems.extend(rebuilt_items[cls_][item.name])
else:
newitems.append(item)
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
items[:] = sorted(
newitems,
key=lambda item: (
item.getparent(pytest.Module).name,
item.getparent(pytest.Class).name,
item.name,
),
)
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
from sqlalchemy.testing import config
if config.any_async:
obj = _apply_maybe_async(obj)
return [
pytest.Class.from_parent(
name=parametrize_cls.__name__, parent=collector
)
for parametrize_cls in _parametrize_cls(collector.module, obj)
]
elif (
inspect.isfunction(obj)
and collector.cls is not None
and plugin_base.want_method(collector.cls, obj)
):
# None means, fall back to default logic, which includes
# method-level parametrize
return None
else:
# empty list means skip this item
return []
def _is_wrapped_coroutine_function(fn):
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return inspect.iscoroutinefunction(fn)
def _apply_maybe_async(obj, recurse=True):
from sqlalchemy.testing import asyncio
for name, value in vars(obj).items():
if (
(callable(value) or isinstance(value, classmethod))
and not getattr(value, "_maybe_async_applied", False)
and (name.startswith("test_"))
and not _is_wrapped_coroutine_function(value)
):
is_classmethod = False
if isinstance(value, classmethod):
value = value.__func__
is_classmethod = True
@_pytest_fn_decorator
def make_async(fn, *args, **kwargs):
return asyncio._maybe_async(fn, *args, **kwargs)
do_async = make_async(value)
if is_classmethod:
do_async = classmethod(do_async)
do_async._maybe_async_applied = True
setattr(obj, name, do_async)
if recurse:
for cls in obj.mro()[1:]:
if cls != object:
_apply_maybe_async(cls, False)
return obj
def _parametrize_cls(module, cls):
"""implement a class-based version of pytest parametrize."""
if "_sa_parametrize" not in cls.__dict__:
return [cls]
_sa_parametrize = cls._sa_parametrize
classes = []
for full_param_set in itertools.product(
*[params for argname, params in _sa_parametrize]
):
cls_variables = {}
for argname, param in zip(
[_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
):
if not argname:
raise TypeError("need argnames for class-based combinations")
argname_split = re.split(r",\s*", argname)
for arg, val in zip(argname_split, param.values):
cls_variables[arg] = val
parametrized_name = "_".join(
re.sub(r"\W", "", token)
for param in full_param_set
for token in param.id.split("-")
)
name = "%s_%s" % (cls.__name__, parametrized_name)
newcls = type.__new__(type, name, (cls,), cls_variables)
setattr(module, name, newcls)
classes.append(newcls)
return classes
_current_class = None
def pytest_runtest_setup(item):
from sqlalchemy.testing import asyncio
# pytest_runtest_setup runs *before* pytest fixtures with scope="class".
# plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
# for the whole class and has to run things that are across all current
# databases, so we run this outside of the pytest fixture system altogether
# and ensure asyncio greenlet if any engines are async
global _current_class
if isinstance(item, pytest.Function) and _current_class is None:
asyncio._maybe_async_provisioning(
plugin_base.start_test_class_outside_fixtures,
item.cls,
)
_current_class = item.getparent(pytest.Class)
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item, nextitem):
# runs inside of pytest function fixture scope
# after test function runs
from sqlalchemy.testing import asyncio
asyncio._maybe_async(plugin_base.after_test, item)
yield
# this is now after all the fixture teardown have run, the class can be
# finalized. Since pytest v7 this finalizer can no longer be added in
# pytest_runtest_setup since the class has not yet been setup at that
# time.
# See https://github.com/pytest-dev/pytest/issues/9343
global _current_class, _current_report
if _current_class is not None and (
# last test or a new class
nextitem is None
or nextitem.getparent(pytest.Class) is not _current_class
):
_current_class = None
try:
asyncio._maybe_async_provisioning(
plugin_base.stop_test_class_outside_fixtures, item.cls
)
except Exception as e:
# in case of an exception during teardown attach the original
# error to the exception message, otherwise it will get lost
if _current_report.failed:
if not e.args:
e.args = (
"__Original test failure__:\n"
+ _current_report.longreprtext,
)
elif e.args[-1] and isinstance(e.args[-1], str):
args = list(e.args)
args[-1] += (
"\n__Original test failure__:\n"
+ _current_report.longreprtext
)
e.args = tuple(args)
else:
e.args += (
"__Original test failure__",
_current_report.longreprtext,
)
raise
finally:
_current_report = None
def pytest_runtest_call(item):
# runs inside of pytest function fixture scope
# before test function runs
from sqlalchemy.testing import asyncio
asyncio._maybe_async(
plugin_base.before_test,
item,
item.module.__name__,
item.cls,
item.name,
)
_current_report = None
def pytest_runtest_logreport(report):
global _current_report
if report.when == "call":
_current_report = report
@pytest.fixture(scope="class")
def setup_class_methods(request):
from sqlalchemy.testing import asyncio
cls = request.cls
if hasattr(cls, "setup_test_class"):
asyncio._maybe_async(cls.setup_test_class)
yield
if hasattr(cls, "teardown_test_class"):
asyncio._maybe_async(cls.teardown_test_class)
asyncio._maybe_async(plugin_base.stop_test_class, cls)
@pytest.fixture(scope="function")
def setup_test_methods(request):
from sqlalchemy.testing import asyncio
# called for each test
self = request.instance
# before this fixture runs:
# 1. function level "autouse" fixtures under py3k (examples: TablesTest
# define tables / data, MappedTest define tables / mappers / data)
# 2. was for p2k. no longer applies
# 3. run outer xdist-style setup
if hasattr(self, "setup_test"):
asyncio._maybe_async(self.setup_test)
# alembic test suite is using setUp and tearDown
# xdist methods; support these in the test suite
# for the near term
if hasattr(self, "setUp"):
asyncio._maybe_async(self.setUp)
# inside the yield:
# 4. function level fixtures defined on test functions themselves,
# e.g. "connection", "metadata" run next
# 5. pytest hook pytest_runtest_call then runs
# 6. test itself runs
yield
# yield finishes:
# 7. function level fixtures defined on test functions
# themselves, e.g. "connection" rolls back the transaction, "metadata"
# emits drop all
# 8. pytest hook pytest_runtest_teardown hook runs, this is associated
# with fixtures close all sessions, provisioning.stop_test_class(),
# engines.testing_reaper -> ensure all connection pool connections
# are returned, engines created by testing_engine that aren't the
# config engine are disposed
asyncio._maybe_async(plugin_base.after_test_fixtures, self)
# 10. run xdist-style teardown
if hasattr(self, "tearDown"):
asyncio._maybe_async(self.tearDown)
if hasattr(self, "teardown_test"):
asyncio._maybe_async(self.teardown_test)
# 11. was for p2k. no longer applies
# 12. function level "autouse" fixtures under py3k (examples: TablesTest /
# MappedTest delete table data, possibly drop tables and clear mappers
# depending on the flags defined by the test class)
def _pytest_fn_decorator(target):
"""Port of langhelpers.decorator with pytest-specific tricks."""
from sqlalchemy.util.langhelpers import format_argspec_plus
from sqlalchemy.util.compat import inspect_getfullargspec
def _exec_code_in_env(code, env, fn_name):
# note this is affected by "from __future__ import annotations" at
# the top; exec'ed code will use non-evaluated annotations
# which allows us to be more flexible with code rendering
# in format_argpsec_plus()
exec(code, env)
return env[fn_name]
def decorate(fn, add_positional_parameters=()):
spec = inspect_getfullargspec(fn)
if add_positional_parameters:
spec.args.extend(add_positional_parameters)
metadata = dict(
__target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
)
metadata.update(format_argspec_plus(spec, grouped=False))
code = (
"""\
def %(name)s%(grouped_args)s:
return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
"""
% metadata
)
decorated = _exec_code_in_env(
code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
)
if not add_positional_parameters:
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
decorated.__wrapped__ = fn
return update_wrapper(decorated, fn)
else:
# this is the pytest hacky part. don't do a full update wrapper
# because pytest is really being sneaky about finding the args
# for the wrapped function
decorated.__module__ = fn.__module__
decorated.__name__ = fn.__name__
if hasattr(fn, "pytestmark"):
decorated.pytestmark = fn.pytestmark
return decorated
return decorate
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def skip_test_exception(self, *arg, **kw):
return pytest.skip.Exception(*arg, **kw)
@property
def add_to_marker(self):
return pytest.mark
def mark_base_test_class(self):
return pytest.mark.usefixtures(
"setup_class_methods", "setup_test_methods"
)
_combination_id_fns = {
"i": lambda obj: obj,
"r": repr,
"s": str,
"n": lambda obj: (
obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
),
}
def combinations(self, *arg_sets, **kw):
"""Facade for pytest.mark.parametrize.
Automatically derives argument names from the callable which in our
case is always a method on a class with positional arguments.
ids for parameter sets are derived using an optional template.
"""
from sqlalchemy.testing import exclusions
if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
arg_sets = list(arg_sets[0])
argnames = kw.pop("argnames", None)
def _filter_exclusions(args):
result = []
gathered_exclusions = []
for a in args:
if isinstance(a, exclusions.compound):
gathered_exclusions.append(a)
else:
result.append(a)
return result, gathered_exclusions
id_ = kw.pop("id_", None)
tobuild_pytest_params = []
has_exclusions = False
if id_:
_combination_id_fns = self._combination_id_fns
# because itemgetter is not consistent for one argument vs.
# multiple, make it multiple in all cases and use a slice
# to omit the first argument
_arg_getter = operator.itemgetter(
0,
*[
idx
for idx, char in enumerate(id_)
if char in ("n", "r", "s", "a")
],
)
fns = [
(operator.itemgetter(idx), _combination_id_fns[char])
for idx, char in enumerate(id_)
if char in _combination_id_fns
]
for arg in arg_sets:
if not isinstance(arg, tuple):
arg = (arg,)
fn_params, param_exclusions = _filter_exclusions(arg)
parameters = _arg_getter(fn_params)[1:]
if param_exclusions:
has_exclusions = True
tobuild_pytest_params.append(
(
parameters,
param_exclusions,
"-".join(
comb_fn(getter(arg)) for getter, comb_fn in fns
),
)
)
else:
for arg in arg_sets:
if not isinstance(arg, tuple):
arg = (arg,)
fn_params, param_exclusions = _filter_exclusions(arg)
if param_exclusions:
has_exclusions = True
tobuild_pytest_params.append(
(fn_params, param_exclusions, None)
)
pytest_params = []
for parameters, param_exclusions, id_ in tobuild_pytest_params:
if has_exclusions:
parameters += (param_exclusions,)
param = pytest.param(*parameters, id=id_)
pytest_params.append(param)
def decorate(fn):
if inspect.isclass(fn):
if has_exclusions:
raise NotImplementedError(
"exclusions not supported for class level combinations"
)
if "_sa_parametrize" not in fn.__dict__:
fn._sa_parametrize = []
fn._sa_parametrize.append((argnames, pytest_params))
return fn
else:
_fn_argnames = inspect.getfullargspec(fn).args[1:]
if argnames is None:
_argnames = _fn_argnames
else:
_argnames = re.split(r", *", argnames)
if has_exclusions:
existing_exl = sum(
1 for n in _fn_argnames if n.startswith("_exclusions")
)
current_exclusion_name = f"_exclusions_{existing_exl}"
_argnames += [current_exclusion_name]
@_pytest_fn_decorator
def check_exclusions(fn, *args, **kw):
_exclusions = args[-1]
if _exclusions:
exlu = exclusions.compound().add(*_exclusions)
fn = exlu(fn)
return fn(*args[:-1], **kw)
fn = check_exclusions(
fn, add_positional_parameters=(current_exclusion_name,)
)
return pytest.mark.parametrize(_argnames, pytest_params)(fn)
return decorate
def param_ident(self, *parameters):
ident = parameters[0]
return pytest.param(*parameters[1:], id=ident)
def fixture(self, *arg, **kw):
from sqlalchemy.testing import config
from sqlalchemy.testing import asyncio
# wrapping pytest.fixture function. determine if
# decorator was called as @fixture or @fixture().
if len(arg) > 0 and callable(arg[0]):
# was called as @fixture(), we have the function to wrap.
fn = arg[0]
arg = arg[1:]
else:
# was called as @fixture, don't have the function yet.
fn = None
# create a pytest.fixture marker. because the fn is not being
# passed, this is always a pytest.FixtureFunctionMarker()
# object (or whatever pytest is calling it when you read this)
# that is waiting for a function.
fixture = pytest.fixture(*arg, **kw)
# now apply wrappers to the function, including fixture itself
def wrap(fn):
if config.any_async:
fn = asyncio._maybe_async_wrapper(fn)
# other wrappers may be added here
# now apply FixtureFunctionMarker
fn = fixture(fn)
return fn
if fn:
return wrap(fn)
else:
return wrap
def get_current_test_name(self):
return os.environ.get("PYTEST_CURRENT_TEST")
def async_test(self, fn):
from sqlalchemy.testing import asyncio
@_pytest_fn_decorator
def decorate(fn, *args, **kwargs):
asyncio._run_coroutine_function(fn, *args, **kwargs)
return decorate(fn)