Update 2025-04-24_11:44:19
This commit is contained in:
@ -0,0 +1,503 @@
|
||||
# testing/fixtures/sql.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
|
||||
import sqlalchemy as sa
|
||||
from .base import TestBase
|
||||
from .. import config
|
||||
from .. import mock
|
||||
from ..assertions import eq_
|
||||
from ..assertions import ne_
|
||||
from ..util import adict
|
||||
from ..util import drop_all_tables_from_metadata
|
||||
from ... import event
|
||||
from ... import util
|
||||
from ...schema import sort_tables_and_constraints
|
||||
from ...sql import visitors
|
||||
from ...sql.elements import ClauseElement
|
||||
|
||||
|
||||
class TablesTest(TestBase):
|
||||
# 'once', None
|
||||
run_setup_bind = "once"
|
||||
|
||||
# 'once', 'each', None
|
||||
run_define_tables = "once"
|
||||
|
||||
# 'once', 'each', None
|
||||
run_create_tables = "once"
|
||||
|
||||
# 'once', 'each', None
|
||||
run_inserts = "each"
|
||||
|
||||
# 'each', None
|
||||
run_deletes = "each"
|
||||
|
||||
# 'once', None
|
||||
run_dispose_bind = None
|
||||
|
||||
bind = None
|
||||
_tables_metadata = None
|
||||
tables = None
|
||||
other = None
|
||||
sequences = None
|
||||
|
||||
@config.fixture(autouse=True, scope="class")
|
||||
def _setup_tables_test_class(self):
|
||||
cls = self.__class__
|
||||
cls._init_class()
|
||||
|
||||
cls._setup_once_tables()
|
||||
|
||||
cls._setup_once_inserts()
|
||||
|
||||
yield
|
||||
|
||||
cls._teardown_once_metadata_bind()
|
||||
|
||||
@config.fixture(autouse=True, scope="function")
|
||||
def _setup_tables_test_instance(self):
|
||||
self._setup_each_tables()
|
||||
self._setup_each_inserts()
|
||||
|
||||
yield
|
||||
|
||||
self._teardown_each_tables()
|
||||
|
||||
@property
|
||||
def tables_test_metadata(self):
|
||||
return self._tables_metadata
|
||||
|
||||
@classmethod
|
||||
def _init_class(cls):
|
||||
if cls.run_define_tables == "each":
|
||||
if cls.run_create_tables == "once":
|
||||
cls.run_create_tables = "each"
|
||||
assert cls.run_inserts in ("each", None)
|
||||
|
||||
cls.other = adict()
|
||||
cls.tables = adict()
|
||||
cls.sequences = adict()
|
||||
|
||||
cls.bind = cls.setup_bind()
|
||||
cls._tables_metadata = sa.MetaData()
|
||||
|
||||
@classmethod
|
||||
def _setup_once_inserts(cls):
|
||||
if cls.run_inserts == "once":
|
||||
cls._load_fixtures()
|
||||
with cls.bind.begin() as conn:
|
||||
cls.insert_data(conn)
|
||||
|
||||
@classmethod
|
||||
def _setup_once_tables(cls):
|
||||
if cls.run_define_tables == "once":
|
||||
cls.define_tables(cls._tables_metadata)
|
||||
if cls.run_create_tables == "once":
|
||||
cls._tables_metadata.create_all(cls.bind)
|
||||
cls.tables.update(cls._tables_metadata.tables)
|
||||
cls.sequences.update(cls._tables_metadata._sequences)
|
||||
|
||||
def _setup_each_tables(self):
|
||||
if self.run_define_tables == "each":
|
||||
self.define_tables(self._tables_metadata)
|
||||
if self.run_create_tables == "each":
|
||||
self._tables_metadata.create_all(self.bind)
|
||||
self.tables.update(self._tables_metadata.tables)
|
||||
self.sequences.update(self._tables_metadata._sequences)
|
||||
elif self.run_create_tables == "each":
|
||||
self._tables_metadata.create_all(self.bind)
|
||||
|
||||
def _setup_each_inserts(self):
|
||||
if self.run_inserts == "each":
|
||||
self._load_fixtures()
|
||||
with self.bind.begin() as conn:
|
||||
self.insert_data(conn)
|
||||
|
||||
def _teardown_each_tables(self):
|
||||
if self.run_define_tables == "each":
|
||||
self.tables.clear()
|
||||
if self.run_create_tables == "each":
|
||||
drop_all_tables_from_metadata(self._tables_metadata, self.bind)
|
||||
self._tables_metadata.clear()
|
||||
elif self.run_create_tables == "each":
|
||||
drop_all_tables_from_metadata(self._tables_metadata, self.bind)
|
||||
|
||||
savepoints = getattr(config.requirements, "savepoints", False)
|
||||
if savepoints:
|
||||
savepoints = savepoints.enabled
|
||||
|
||||
# no need to run deletes if tables are recreated on setup
|
||||
if (
|
||||
self.run_define_tables != "each"
|
||||
and self.run_create_tables != "each"
|
||||
and self.run_deletes == "each"
|
||||
):
|
||||
with self.bind.begin() as conn:
|
||||
for table in reversed(
|
||||
[
|
||||
t
|
||||
for (t, fks) in sort_tables_and_constraints(
|
||||
self._tables_metadata.tables.values()
|
||||
)
|
||||
if t is not None
|
||||
]
|
||||
):
|
||||
try:
|
||||
if savepoints:
|
||||
with conn.begin_nested():
|
||||
conn.execute(table.delete())
|
||||
else:
|
||||
conn.execute(table.delete())
|
||||
except sa.exc.DBAPIError as ex:
|
||||
print(
|
||||
("Error emptying table %s: %r" % (table, ex)),
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _teardown_once_metadata_bind(cls):
|
||||
if cls.run_create_tables:
|
||||
drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
|
||||
|
||||
if cls.run_dispose_bind == "once":
|
||||
cls.dispose_bind(cls.bind)
|
||||
|
||||
cls._tables_metadata.bind = None
|
||||
|
||||
if cls.run_setup_bind is not None:
|
||||
cls.bind = None
|
||||
|
||||
@classmethod
|
||||
def setup_bind(cls):
|
||||
return config.db
|
||||
|
||||
@classmethod
|
||||
def dispose_bind(cls, bind):
|
||||
if hasattr(bind, "dispose"):
|
||||
bind.dispose()
|
||||
elif hasattr(bind, "close"):
|
||||
bind.close()
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def fixtures(cls):
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
pass
|
||||
|
||||
def sql_count_(self, count, fn):
|
||||
self.assert_sql_count(self.bind, fn, count)
|
||||
|
||||
def sql_eq_(self, callable_, statements):
|
||||
self.assert_sql(self.bind, callable_, statements)
|
||||
|
||||
@classmethod
|
||||
def _load_fixtures(cls):
|
||||
"""Insert rows as represented by the fixtures() method."""
|
||||
headers, rows = {}, {}
|
||||
for table, data in cls.fixtures().items():
|
||||
if len(data) < 2:
|
||||
continue
|
||||
if isinstance(table, str):
|
||||
table = cls.tables[table]
|
||||
headers[table] = data[0]
|
||||
rows[table] = data[1:]
|
||||
for table, fks in sort_tables_and_constraints(
|
||||
cls._tables_metadata.tables.values()
|
||||
):
|
||||
if table is None:
|
||||
continue
|
||||
if table not in headers:
|
||||
continue
|
||||
with cls.bind.begin() as conn:
|
||||
conn.execute(
|
||||
table.insert(),
|
||||
[
|
||||
dict(zip(headers[table], column_values))
|
||||
for column_values in rows[table]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class NoCache:
|
||||
@config.fixture(autouse=True, scope="function")
|
||||
def _disable_cache(self):
|
||||
_cache = config.db._compiled_cache
|
||||
config.db._compiled_cache = None
|
||||
yield
|
||||
config.db._compiled_cache = _cache
|
||||
|
||||
|
||||
class RemovesEvents:
|
||||
@util.memoized_property
|
||||
def _event_fns(self):
|
||||
return set()
|
||||
|
||||
def event_listen(self, target, name, fn, **kw):
|
||||
self._event_fns.add((target, name, fn))
|
||||
event.listen(target, name, fn, **kw)
|
||||
|
||||
@config.fixture(autouse=True, scope="function")
|
||||
def _remove_events(self):
|
||||
yield
|
||||
for key in self._event_fns:
|
||||
event.remove(*key)
|
||||
|
||||
|
||||
class ComputedReflectionFixtureTest(TablesTest):
|
||||
run_inserts = run_deletes = None
|
||||
|
||||
__backend__ = True
|
||||
__requires__ = ("computed_columns", "table_reflection")
|
||||
|
||||
regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
|
||||
|
||||
def normalize(self, text):
|
||||
return self.regexp.sub("", text).lower()
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
from ... import Integer
|
||||
from ... import testing
|
||||
from ...schema import Column
|
||||
from ...schema import Computed
|
||||
from ...schema import Table
|
||||
|
||||
Table(
|
||||
"computed_default_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("normal", Integer),
|
||||
Column("computed_col", Integer, Computed("normal + 42")),
|
||||
Column("with_default", Integer, server_default="42"),
|
||||
)
|
||||
|
||||
t = Table(
|
||||
"computed_column_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("normal", Integer),
|
||||
Column("computed_no_flag", Integer, Computed("normal + 42")),
|
||||
)
|
||||
|
||||
if testing.requires.schemas.enabled:
|
||||
t2 = Table(
|
||||
"computed_column_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("normal", Integer),
|
||||
Column("computed_no_flag", Integer, Computed("normal / 42")),
|
||||
schema=config.test_schema,
|
||||
)
|
||||
|
||||
if testing.requires.computed_columns_virtual.enabled:
|
||||
t.append_column(
|
||||
Column(
|
||||
"computed_virtual",
|
||||
Integer,
|
||||
Computed("normal + 2", persisted=False),
|
||||
)
|
||||
)
|
||||
if testing.requires.schemas.enabled:
|
||||
t2.append_column(
|
||||
Column(
|
||||
"computed_virtual",
|
||||
Integer,
|
||||
Computed("normal / 2", persisted=False),
|
||||
)
|
||||
)
|
||||
if testing.requires.computed_columns_stored.enabled:
|
||||
t.append_column(
|
||||
Column(
|
||||
"computed_stored",
|
||||
Integer,
|
||||
Computed("normal - 42", persisted=True),
|
||||
)
|
||||
)
|
||||
if testing.requires.schemas.enabled:
|
||||
t2.append_column(
|
||||
Column(
|
||||
"computed_stored",
|
||||
Integer,
|
||||
Computed("normal * 42", persisted=True),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CacheKeyFixture:
|
||||
def _compare_equal(self, a, b, compare_values):
|
||||
a_key = a._generate_cache_key()
|
||||
b_key = b._generate_cache_key()
|
||||
|
||||
if a_key is None:
|
||||
assert a._annotations.get("nocache")
|
||||
|
||||
assert b_key is None
|
||||
else:
|
||||
eq_(a_key.key, b_key.key)
|
||||
eq_(hash(a_key.key), hash(b_key.key))
|
||||
|
||||
for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
|
||||
assert a_param.compare(b_param, compare_values=compare_values)
|
||||
return a_key, b_key
|
||||
|
||||
def _run_cache_key_fixture(self, fixture, compare_values):
|
||||
case_a = fixture()
|
||||
case_b = fixture()
|
||||
|
||||
for a, b in itertools.combinations_with_replacement(
|
||||
range(len(case_a)), 2
|
||||
):
|
||||
if a == b:
|
||||
a_key, b_key = self._compare_equal(
|
||||
case_a[a], case_b[b], compare_values
|
||||
)
|
||||
if a_key is None:
|
||||
continue
|
||||
else:
|
||||
a_key = case_a[a]._generate_cache_key()
|
||||
b_key = case_b[b]._generate_cache_key()
|
||||
|
||||
if a_key is None or b_key is None:
|
||||
if a_key is None:
|
||||
assert case_a[a]._annotations.get("nocache")
|
||||
if b_key is None:
|
||||
assert case_b[b]._annotations.get("nocache")
|
||||
continue
|
||||
|
||||
if a_key.key == b_key.key:
|
||||
for a_param, b_param in zip(
|
||||
a_key.bindparams, b_key.bindparams
|
||||
):
|
||||
if not a_param.compare(
|
||||
b_param, compare_values=compare_values
|
||||
):
|
||||
break
|
||||
else:
|
||||
# this fails unconditionally since we could not
|
||||
# find bound parameter values that differed.
|
||||
# Usually we intended to get two distinct keys here
|
||||
# so the failure will be more descriptive using the
|
||||
# ne_() assertion.
|
||||
ne_(a_key.key, b_key.key)
|
||||
else:
|
||||
ne_(a_key.key, b_key.key)
|
||||
|
||||
# ClauseElement-specific test to ensure the cache key
|
||||
# collected all the bound parameters that aren't marked
|
||||
# as "literal execute"
|
||||
if isinstance(case_a[a], ClauseElement) and isinstance(
|
||||
case_b[b], ClauseElement
|
||||
):
|
||||
assert_a_params = []
|
||||
assert_b_params = []
|
||||
|
||||
for elem in visitors.iterate(case_a[a]):
|
||||
if elem.__visit_name__ == "bindparam":
|
||||
assert_a_params.append(elem)
|
||||
|
||||
for elem in visitors.iterate(case_b[b]):
|
||||
if elem.__visit_name__ == "bindparam":
|
||||
assert_b_params.append(elem)
|
||||
|
||||
# note we're asserting the order of the params as well as
|
||||
# if there are dupes or not. ordering has to be
|
||||
# deterministic and matches what a traversal would provide.
|
||||
eq_(
|
||||
sorted(a_key.bindparams, key=lambda b: b.key),
|
||||
sorted(
|
||||
util.unique_list(assert_a_params), key=lambda b: b.key
|
||||
),
|
||||
)
|
||||
eq_(
|
||||
sorted(b_key.bindparams, key=lambda b: b.key),
|
||||
sorted(
|
||||
util.unique_list(assert_b_params), key=lambda b: b.key
|
||||
),
|
||||
)
|
||||
|
||||
def _run_cache_key_equal_fixture(self, fixture, compare_values):
|
||||
case_a = fixture()
|
||||
case_b = fixture()
|
||||
|
||||
for a, b in itertools.combinations_with_replacement(
|
||||
range(len(case_a)), 2
|
||||
):
|
||||
self._compare_equal(case_a[a], case_b[b], compare_values)
|
||||
|
||||
|
||||
def insertmanyvalues_fixture(
|
||||
connection, randomize_rows=False, warn_on_downgraded=False
|
||||
):
|
||||
dialect = connection.dialect
|
||||
orig_dialect = dialect._deliver_insertmanyvalues_batches
|
||||
orig_conn = connection._exec_insertmany_context
|
||||
|
||||
class RandomCursor:
|
||||
__slots__ = ("cursor",)
|
||||
|
||||
def __init__(self, cursor):
|
||||
self.cursor = cursor
|
||||
|
||||
# only this method is called by the deliver method.
|
||||
# by not having the other methods we assert that those aren't being
|
||||
# used
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self.cursor.description
|
||||
|
||||
def fetchall(self):
|
||||
rows = self.cursor.fetchall()
|
||||
rows = list(rows)
|
||||
random.shuffle(rows)
|
||||
return rows
|
||||
|
||||
def _deliver_insertmanyvalues_batches(
|
||||
connection,
|
||||
cursor,
|
||||
statement,
|
||||
parameters,
|
||||
generic_setinputsizes,
|
||||
context,
|
||||
):
|
||||
if randomize_rows:
|
||||
cursor = RandomCursor(cursor)
|
||||
for batch in orig_dialect(
|
||||
connection,
|
||||
cursor,
|
||||
statement,
|
||||
parameters,
|
||||
generic_setinputsizes,
|
||||
context,
|
||||
):
|
||||
if warn_on_downgraded and batch.is_downgraded:
|
||||
util.warn("Batches were downgraded for sorted INSERT")
|
||||
|
||||
yield batch
|
||||
|
||||
def _exec_insertmany_context(dialect, context):
|
||||
with mock.patch.object(
|
||||
dialect,
|
||||
"_deliver_insertmanyvalues_batches",
|
||||
new=_deliver_insertmanyvalues_batches,
|
||||
):
|
||||
return orig_conn(dialect, context)
|
||||
|
||||
connection._exec_insertmany_context = _exec_insertmany_context
|
Reference in New Issue
Block a user