Update 2025-04-13_16:25:39
This commit is contained in:
474
venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py
Normal file
474
venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py
Normal file
@ -0,0 +1,474 @@
|
||||
# testing/engines.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import re
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
from . import config
|
||||
from .util import decorator
|
||||
from .util import gc_collect
|
||||
from .. import event
|
||||
from .. import pool
|
||||
from ..util import await_only
|
||||
from ..util.typing import Literal
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..engine import Engine
|
||||
from ..engine.url import URL
|
||||
from ..ext.asyncio import AsyncEngine
|
||||
|
||||
|
||||
class ConnectionKiller:
|
||||
def __init__(self):
|
||||
self.proxy_refs = weakref.WeakKeyDictionary()
|
||||
self.testing_engines = collections.defaultdict(set)
|
||||
self.dbapi_connections = set()
|
||||
|
||||
def add_pool(self, pool):
|
||||
event.listen(pool, "checkout", self._add_conn)
|
||||
event.listen(pool, "checkin", self._remove_conn)
|
||||
event.listen(pool, "close", self._remove_conn)
|
||||
event.listen(pool, "close_detached", self._remove_conn)
|
||||
# note we are keeping "invalidated" here, as those are still
|
||||
# opened connections we would like to roll back
|
||||
|
||||
def _add_conn(self, dbapi_con, con_record, con_proxy):
|
||||
self.dbapi_connections.add(dbapi_con)
|
||||
self.proxy_refs[con_proxy] = True
|
||||
|
||||
def _remove_conn(self, dbapi_conn, *arg):
|
||||
self.dbapi_connections.discard(dbapi_conn)
|
||||
|
||||
def add_engine(self, engine, scope):
|
||||
self.add_pool(engine.pool)
|
||||
|
||||
assert scope in ("class", "global", "function", "fixture")
|
||||
self.testing_engines[scope].add(engine)
|
||||
|
||||
def _safe(self, fn):
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"testing_reaper couldn't rollback/close connection: %s" % e
|
||||
)
|
||||
|
||||
def rollback_all(self):
|
||||
for rec in list(self.proxy_refs):
|
||||
if rec is not None and rec.is_valid:
|
||||
self._safe(rec.rollback)
|
||||
|
||||
def checkin_all(self):
|
||||
# run pool.checkin() for all ConnectionFairy instances we have
|
||||
# tracked.
|
||||
|
||||
for rec in list(self.proxy_refs):
|
||||
if rec is not None and rec.is_valid:
|
||||
self.dbapi_connections.discard(rec.dbapi_connection)
|
||||
self._safe(rec._checkin)
|
||||
|
||||
# for fairy refs that were GCed and could not close the connection,
|
||||
# such as asyncio, roll back those remaining connections
|
||||
for con in self.dbapi_connections:
|
||||
self._safe(con.rollback)
|
||||
self.dbapi_connections.clear()
|
||||
|
||||
def close_all(self):
|
||||
self.checkin_all()
|
||||
|
||||
def prepare_for_drop_tables(self, connection):
|
||||
# don't do aggressive checks for third party test suites
|
||||
if not config.bootstrapped_as_sqlalchemy:
|
||||
return
|
||||
|
||||
from . import provision
|
||||
|
||||
provision.prepare_for_drop_tables(connection.engine.url, connection)
|
||||
|
||||
def _drop_testing_engines(self, scope):
|
||||
eng = self.testing_engines[scope]
|
||||
for rec in list(eng):
|
||||
for proxy_ref in list(self.proxy_refs):
|
||||
if proxy_ref is not None and proxy_ref.is_valid:
|
||||
if (
|
||||
proxy_ref._pool is not None
|
||||
and proxy_ref._pool is rec.pool
|
||||
):
|
||||
self._safe(proxy_ref._checkin)
|
||||
|
||||
if hasattr(rec, "sync_engine"):
|
||||
await_only(rec.dispose())
|
||||
else:
|
||||
rec.dispose()
|
||||
eng.clear()
|
||||
|
||||
def after_test(self):
|
||||
self._drop_testing_engines("function")
|
||||
|
||||
def after_test_outside_fixtures(self, test):
|
||||
# don't do aggressive checks for third party test suites
|
||||
if not config.bootstrapped_as_sqlalchemy:
|
||||
return
|
||||
|
||||
if test.__class__.__leave_connections_for_teardown__:
|
||||
return
|
||||
|
||||
self.checkin_all()
|
||||
|
||||
# on PostgreSQL, this will test for any "idle in transaction"
|
||||
# connections. useful to identify tests with unusual patterns
|
||||
# that can't be cleaned up correctly.
|
||||
from . import provision
|
||||
|
||||
with config.db.connect() as conn:
|
||||
provision.prepare_for_drop_tables(conn.engine.url, conn)
|
||||
|
||||
def stop_test_class_inside_fixtures(self):
|
||||
self.checkin_all()
|
||||
self._drop_testing_engines("function")
|
||||
self._drop_testing_engines("class")
|
||||
|
||||
def stop_test_class_outside_fixtures(self):
|
||||
# ensure no refs to checked out connections at all.
|
||||
|
||||
if pool.base._strong_ref_connection_records:
|
||||
gc_collect()
|
||||
|
||||
if pool.base._strong_ref_connection_records:
|
||||
ln = len(pool.base._strong_ref_connection_records)
|
||||
pool.base._strong_ref_connection_records.clear()
|
||||
assert (
|
||||
False
|
||||
), "%d connection recs not cleared after test suite" % (ln)
|
||||
|
||||
def final_cleanup(self):
|
||||
self.checkin_all()
|
||||
for scope in self.testing_engines:
|
||||
self._drop_testing_engines(scope)
|
||||
|
||||
def assert_all_closed(self):
|
||||
for rec in self.proxy_refs:
|
||||
if rec.is_valid:
|
||||
assert False
|
||||
|
||||
|
||||
testing_reaper = ConnectionKiller()
|
||||
|
||||
|
||||
@decorator
|
||||
def assert_conns_closed(fn, *args, **kw):
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.assert_all_closed()
|
||||
|
||||
|
||||
@decorator
|
||||
def rollback_open_connections(fn, *args, **kw):
|
||||
"""Decorator that rolls back all open connections after fn execution."""
|
||||
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.rollback_all()
|
||||
|
||||
|
||||
@decorator
|
||||
def close_first(fn, *args, **kw):
|
||||
"""Decorator that closes all connections before fn execution."""
|
||||
|
||||
testing_reaper.checkin_all()
|
||||
fn(*args, **kw)
|
||||
|
||||
|
||||
@decorator
|
||||
def close_open_connections(fn, *args, **kw):
|
||||
"""Decorator that closes all connections after fn execution."""
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.checkin_all()
|
||||
|
||||
|
||||
def all_dialects(exclude=None):
|
||||
import sqlalchemy.dialects as d
|
||||
|
||||
for name in d.__all__:
|
||||
# TEMPORARY
|
||||
if exclude and name in exclude:
|
||||
continue
|
||||
mod = getattr(d, name, None)
|
||||
if not mod:
|
||||
mod = getattr(
|
||||
__import__("sqlalchemy.dialects.%s" % name).dialects, name
|
||||
)
|
||||
yield mod.dialect()
|
||||
|
||||
|
||||
class ReconnectFixture:
|
||||
def __init__(self, dbapi):
|
||||
self.dbapi = dbapi
|
||||
self.connections = []
|
||||
self.is_stopped = False
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.dbapi, key)
|
||||
|
||||
def connect(self, *args, **kwargs):
|
||||
conn = self.dbapi.connect(*args, **kwargs)
|
||||
if self.is_stopped:
|
||||
self._safe(conn.close)
|
||||
curs = conn.cursor() # should fail on Oracle etc.
|
||||
# should fail for everything that didn't fail
|
||||
# above, connection is closed
|
||||
curs.execute("select 1")
|
||||
assert False, "simulated connect failure didn't work"
|
||||
else:
|
||||
self.connections.append(conn)
|
||||
return conn
|
||||
|
||||
def _safe(self, fn):
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
|
||||
|
||||
def shutdown(self, stop=False):
|
||||
# TODO: this doesn't cover all cases
|
||||
# as nicely as we'd like, namely MySQLdb.
|
||||
# would need to implement R. Brewer's
|
||||
# proxy server idea to get better
|
||||
# coverage.
|
||||
self.is_stopped = stop
|
||||
for c in list(self.connections):
|
||||
self._safe(c.close)
|
||||
self.connections = []
|
||||
|
||||
def restart(self):
|
||||
self.is_stopped = False
|
||||
|
||||
|
||||
def reconnecting_engine(url=None, options=None):
|
||||
url = url or config.db.url
|
||||
dbapi = config.db.dialect.dbapi
|
||||
if not options:
|
||||
options = {}
|
||||
options["module"] = ReconnectFixture(dbapi)
|
||||
engine = testing_engine(url, options)
|
||||
_dispose = engine.dispose
|
||||
|
||||
def dispose():
|
||||
engine.dialect.dbapi.shutdown()
|
||||
engine.dialect.dbapi.is_stopped = False
|
||||
_dispose()
|
||||
|
||||
engine.test_shutdown = engine.dialect.dbapi.shutdown
|
||||
engine.test_restart = engine.dialect.dbapi.restart
|
||||
engine.dispose = dispose
|
||||
return engine
|
||||
|
||||
|
||||
@typing.overload
|
||||
def testing_engine(
|
||||
url: Optional[URL] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
asyncio: Literal[False] = False,
|
||||
transfer_staticpool: bool = False,
|
||||
) -> Engine: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
def testing_engine(
|
||||
url: Optional[URL] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
asyncio: Literal[True] = True,
|
||||
transfer_staticpool: bool = False,
|
||||
) -> AsyncEngine: ...
|
||||
|
||||
|
||||
def testing_engine(
|
||||
url=None,
|
||||
options=None,
|
||||
asyncio=False,
|
||||
transfer_staticpool=False,
|
||||
share_pool=False,
|
||||
_sqlite_savepoint=False,
|
||||
):
|
||||
if asyncio:
|
||||
assert not _sqlite_savepoint
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine as create_engine,
|
||||
)
|
||||
else:
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
if not options:
|
||||
use_reaper = True
|
||||
scope = "function"
|
||||
sqlite_savepoint = False
|
||||
else:
|
||||
use_reaper = options.pop("use_reaper", True)
|
||||
scope = options.pop("scope", "function")
|
||||
sqlite_savepoint = options.pop("sqlite_savepoint", False)
|
||||
|
||||
url = url or config.db.url
|
||||
|
||||
url = make_url(url)
|
||||
|
||||
if (
|
||||
config.db is None or url.drivername == config.db.url.drivername
|
||||
) and config.db_opts:
|
||||
use_options = config.db_opts.copy()
|
||||
else:
|
||||
use_options = {}
|
||||
|
||||
if options is not None:
|
||||
use_options.update(options)
|
||||
|
||||
engine = create_engine(url, **use_options)
|
||||
|
||||
if sqlite_savepoint and engine.name == "sqlite":
|
||||
# apply SQLite savepoint workaround
|
||||
@event.listens_for(engine, "connect")
|
||||
def do_connect(dbapi_connection, connection_record):
|
||||
dbapi_connection.isolation_level = None
|
||||
|
||||
@event.listens_for(engine, "begin")
|
||||
def do_begin(conn):
|
||||
conn.exec_driver_sql("BEGIN")
|
||||
|
||||
if transfer_staticpool:
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
if config.db is not None and isinstance(config.db.pool, StaticPool):
|
||||
use_reaper = False
|
||||
engine.pool._transfer_from(config.db.pool)
|
||||
elif share_pool:
|
||||
engine.pool = config.db.pool
|
||||
|
||||
if scope == "global":
|
||||
if asyncio:
|
||||
engine.sync_engine._has_events = True
|
||||
else:
|
||||
engine._has_events = (
|
||||
True # enable event blocks, helps with profiling
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(engine.pool, pool.QueuePool)
|
||||
and "pool" not in use_options
|
||||
and "pool_timeout" not in use_options
|
||||
and "max_overflow" not in use_options
|
||||
):
|
||||
engine.pool._timeout = 0
|
||||
engine.pool._max_overflow = 0
|
||||
if use_reaper:
|
||||
testing_reaper.add_engine(engine, scope)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def mock_engine(dialect_name=None):
|
||||
"""Provides a mocking engine based on the current testing.db.
|
||||
|
||||
This is normally used to test DDL generation flow as emitted
|
||||
by an Engine.
|
||||
|
||||
It should not be used in other cases, as assert_compile() and
|
||||
assert_sql_execution() are much better choices with fewer
|
||||
moving parts.
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_mock_engine
|
||||
|
||||
if not dialect_name:
|
||||
dialect_name = config.db.name
|
||||
|
||||
buffer = []
|
||||
|
||||
def executor(sql, *a, **kw):
|
||||
buffer.append(sql)
|
||||
|
||||
def assert_sql(stmts):
|
||||
recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
|
||||
assert recv == stmts, recv
|
||||
|
||||
def print_sql():
|
||||
d = engine.dialect
|
||||
return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
|
||||
|
||||
engine = create_mock_engine(dialect_name + "://", executor)
|
||||
assert not hasattr(engine, "mock")
|
||||
engine.mock = buffer
|
||||
engine.assert_sql = assert_sql
|
||||
engine.print_sql = print_sql
|
||||
return engine
|
||||
|
||||
|
||||
class DBAPIProxyCursor:
|
||||
"""Proxy a DBAPI cursor.
|
||||
|
||||
Tests can provide subclasses of this to intercept
|
||||
DBAPI-level cursor operations.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, engine, conn, *args, **kwargs):
|
||||
self.engine = engine
|
||||
self.connection = conn
|
||||
self.cursor = conn.cursor(*args, **kwargs)
|
||||
|
||||
def execute(self, stmt, parameters=None, **kw):
|
||||
if parameters:
|
||||
return self.cursor.execute(stmt, parameters, **kw)
|
||||
else:
|
||||
return self.cursor.execute(stmt, **kw)
|
||||
|
||||
def executemany(self, stmt, params, **kw):
|
||||
return self.cursor.executemany(stmt, params, **kw)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.cursor, key)
|
||||
|
||||
|
||||
class DBAPIProxyConnection:
|
||||
"""Proxy a DBAPI connection.
|
||||
|
||||
Tests can provide subclasses of this to intercept
|
||||
DBAPI-level connection operations.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, engine, conn, cursor_cls):
|
||||
self.conn = conn
|
||||
self.engine = engine
|
||||
self.cursor_cls = cursor_cls
|
||||
|
||||
def cursor(self, *args, **kwargs):
|
||||
return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.conn, key)
|
Reference in New Issue
Block a user