Update 2025-04-13_16:25:39

This commit is contained in:
root
2025-04-13 16:25:41 +02:00
commit 4c711360d3
2979 changed files with 666585 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
# testing/suite/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .test_cte import * # noqa
from .test_ddl import * # noqa
from .test_deprecations import * # noqa
from .test_dialect import * # noqa
from .test_insert import * # noqa
from .test_reflection import * # noqa
from .test_results import * # noqa
from .test_rowcount import * # noqa
from .test_select import * # noqa
from .test_sequence import * # noqa
from .test_types import * # noqa
from .test_unicode_ddl import * # noqa
from .test_update_delete import * # noqa

View File

@@ -0,0 +1,211 @@
# testing/suite/test_cte.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import ForeignKey
from ... import Integer
from ... import select
from ... import String
from ... import testing
class CTETest(fixtures.TablesTest):
    """Round-trip tests for CTE (WITH clause) support on a live backend.

    Runs only on dialects that declare the ``ctes`` requirement; rows are
    re-inserted/deleted around each test because several tests run DML.
    """

    __backend__ = True
    __requires__ = ("ctes",)
    run_inserts = "each"
    run_deletes = "each"
    @classmethod
    def define_tables(cls, metadata):
        """Define a self-referential source table plus a plain DML target."""
        # the self-referential FK lets the recursive CTE test walk the
        # parent_id chain
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", ForeignKey("some_table.id")),
        )
        Table(
            "some_other_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", Integer),
        )
    @classmethod
    def insert_data(cls, connection):
        """Insert a small tree: 1 is the root; 2, 3 under 1; 4, 5 under 3."""
        connection.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "data": "d1", "parent_id": None},
                {"id": 2, "data": "d2", "parent_id": 1},
                {"id": 3, "data": "d3", "parent_id": 1},
                {"id": 4, "data": "d4", "parent_id": 3},
                {"id": 5, "data": "d5", "parent_id": 3},
            ],
        )
    def test_select_nonrecursive_round_trip(self, connection):
        """SELECT from a plain CTE; only "d4" satisfies both IN filters."""
        some_table = self.tables.some_table
        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        result = connection.execute(
            select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
        )
        eq_(result.fetchall(), [("d4",)])
    def test_select_recursive_round_trip(self, connection):
        """Recursive CTE walking up parent_id; UNION ALL keeps duplicates,
        hence the repeated "d3"/"d1" rows in the expected result."""
        some_table = self.tables.some_table
        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte", recursive=True)
        )
        cte_alias = cte.alias("c1")
        st1 = some_table.alias()
        # note that SQL Server requires this to be UNION ALL,
        # can't be UNION
        cte = cte.union_all(
            select(st1).where(st1.c.id == cte_alias.c.parent_id)
        )
        result = connection.execute(
            select(cte.c.data)
            .where(cte.c.data != "d2")
            .order_by(cte.c.data.desc())
        )
        eq_(
            result.fetchall(),
            [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
        )
    def test_insert_from_select_round_trip(self, connection):
        """INSERT ... FROM SELECT where the SELECT draws from a CTE."""
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table
        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(cte)
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
        )
    @testing.requires.ctes_with_update_delete
    @testing.requires.update_from
    def test_update_from_round_trip(self, connection):
        """UPDATE ... FROM joining the target table against a CTE."""
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table
        # seed the target table with a full copy of some_table
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )
        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.update()
            .values(parent_id=5)
            .where(some_other_table.c.data == cte.c.data)
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [
                (1, "d1", None),
                (2, "d2", 5),
                (3, "d3", 5),
                (4, "d4", 5),
                (5, "d5", 3),
            ],
        )
    @testing.requires.ctes_with_update_delete
    @testing.requires.delete_from
    def test_delete_from_round_trip(self, connection):
        """DELETE ... FROM/USING joining the target table against a CTE."""
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table
        # seed the target table with a full copy of some_table
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )
        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.delete().where(
                some_other_table.c.data == cte.c.data
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(1, "d1", None), (5, "d5", 3)],
        )
    @testing.requires.ctes_with_update_delete
    def test_delete_scalar_subq_round_trip(self, connection):
        """DELETE using a correlated scalar subquery that selects from
        a CTE."""
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table
        # seed the target table with a full copy of some_table
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )
        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.delete().where(
                some_other_table.c.data
                == select(cte.c.data)
                .where(cte.c.id == some_other_table.c.id)
                .scalar_subquery()
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(1, "d1", None), (5, "d5", 3)],
        )

View File

@@ -0,0 +1,389 @@
# testing/suite/test_ddl.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import random
from . import testing
from .. import config
from .. import fixtures
from .. import util
from ..assertions import eq_
from ..assertions import is_false
from ..assertions import is_true
from ..config import requirements
from ..schema import Table
from ... import CheckConstraint
from ... import Column
from ... import ForeignKeyConstraint
from ... import Index
from ... import inspect
from ... import Integer
from ... import schema
from ... import String
from ... import UniqueConstraint
class TableDDLTest(fixtures.TestBase):
    """Basic DDL round trips: CREATE/DROP TABLE, table comments, and the
    IF EXISTS / IF NOT EXISTS variants."""

    __backend__ = True
    def _simple_fixture(self, schema=None):
        # NOTE: the ``schema`` parameter shadows the imported
        # ``sqlalchemy.schema`` module within this method (harmless here,
        # only Table/Column are used).
        return Table(
            "test_table",
            self.metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            schema=schema,
        )
    def _underscore_fixture(self):
        # table/column names with leading underscores, to exercise quoting
        return Table(
            "_test_table",
            self.metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("_data", String(50)),
        )
    def _table_index_fixture(self, schema=None):
        """Return a (table, index-on-data-column) pair."""
        table = self._simple_fixture(schema=schema)
        idx = Index("test_index", table.c.data)
        return table, idx
    def _simple_roundtrip(self, table):
        """Insert one row and read it back, proving the table is usable."""
        with config.db.begin() as conn:
            conn.execute(table.insert().values((1, "some data")))
            result = conn.execute(table.select())
            eq_(result.first(), (1, "some data"))
    @requirements.create_table
    @util.provide_metadata
    def test_create_table(self):
        table = self._simple_fixture()
        table.create(config.db, checkfirst=False)
        self._simple_roundtrip(table)
    @requirements.create_table
    @requirements.schemas
    @util.provide_metadata
    def test_create_table_schema(self):
        """CREATE TABLE in an alternate schema."""
        table = self._simple_fixture(schema=config.test_schema)
        table.create(config.db, checkfirst=False)
        self._simple_roundtrip(table)
    @requirements.drop_table
    @util.provide_metadata
    def test_drop_table(self):
        table = self._simple_fixture()
        table.create(config.db, checkfirst=False)
        table.drop(config.db, checkfirst=False)
    @requirements.create_table
    @util.provide_metadata
    def test_underscore_names(self):
        table = self._underscore_fixture()
        table.create(config.db, checkfirst=False)
        self._simple_roundtrip(table)
    @requirements.comment_reflection
    @util.provide_metadata
    def test_add_table_comment(self, connection):
        """SetTableComment DDL round-trips through comment reflection."""
        table = self._simple_fixture()
        table.create(connection, checkfirst=False)
        table.comment = "a comment"
        connection.execute(schema.SetTableComment(table))
        eq_(
            inspect(connection).get_table_comment("test_table"),
            {"text": "a comment"},
        )
    @requirements.comment_reflection
    @util.provide_metadata
    def test_drop_table_comment(self, connection):
        """DropTableComment leaves the reflected comment text as None."""
        table = self._simple_fixture()
        table.create(connection, checkfirst=False)
        table.comment = "a comment"
        connection.execute(schema.SetTableComment(table))
        connection.execute(schema.DropTableComment(table))
        eq_(
            inspect(connection).get_table_comment("test_table"), {"text": None}
        )
    @requirements.table_ddl_if_exists
    @util.provide_metadata
    def test_create_table_if_not_exists(self, connection):
        """CREATE TABLE IF NOT EXISTS creates once, then is a no-op."""
        table = self._simple_fixture()
        connection.execute(schema.CreateTable(table, if_not_exists=True))
        is_true(inspect(connection).has_table("test_table"))
        # second execution must not raise
        connection.execute(schema.CreateTable(table, if_not_exists=True))
    @requirements.index_ddl_if_exists
    @util.provide_metadata
    def test_create_index_if_not_exists(self, connection):
        """CREATE INDEX IF NOT EXISTS creates once, then is a no-op."""
        table, idx = self._table_index_fixture()
        connection.execute(schema.CreateTable(table, if_not_exists=True))
        is_true(inspect(connection).has_table("test_table"))
        # CreateTable does not create the Index; verify it's absent first
        is_false(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )
        connection.execute(schema.CreateIndex(idx, if_not_exists=True))
        is_true(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )
        # second execution must not raise
        connection.execute(schema.CreateIndex(idx, if_not_exists=True))
    @requirements.table_ddl_if_exists
    @util.provide_metadata
    def test_drop_table_if_exists(self, connection):
        """DROP TABLE IF EXISTS drops once, then is a no-op."""
        table = self._simple_fixture()
        table.create(connection)
        is_true(inspect(connection).has_table("test_table"))
        connection.execute(schema.DropTable(table, if_exists=True))
        is_false(inspect(connection).has_table("test_table"))
        # second execution must not raise
        connection.execute(schema.DropTable(table, if_exists=True))
    @requirements.index_ddl_if_exists
    @util.provide_metadata
    def test_drop_index_if_exists(self, connection):
        """DROP INDEX IF EXISTS drops once, then is a no-op."""
        table, idx = self._table_index_fixture()
        table.create(connection)
        is_true(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )
        connection.execute(schema.DropIndex(idx, if_exists=True))
        is_false(
            "test_index"
            in [
                ix["name"]
                for ix in inspect(connection).get_indexes("test_table")
            ]
        )
        # second execution must not raise
        connection.execute(schema.DropIndex(idx, if_exists=True))
class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
    """Re-run all of ``TableDDLTest`` with the future-engine fixture
    mixed in; no additional tests of its own."""

    pass
class LongNameBlowoutTest(fixtures.TestBase):
    """test the creation of a variety of DDL structures and ensure
    label length limits pass on backends

    Each helper (``fk``/``pk``/``ix``/``uq``/``ck``) installs a naming
    convention whose generated constraint name is padded with a long
    random suffix, creates the schema, and returns
    ``(generated_name, reflected_name_or_None)``.
    """

    __backend__ = True
    def fk(self, metadata, connection):
        """Foreign key with a convention-generated name well over 255
        characters."""
        convention = {
            "fk": "foreign_key_%(table_name)s_"
            "%(column_0_N_name)s_"
            "%(referred_table_name)s_"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(20))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention
        Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            test_needs_fk=True,
        )
        cons = ForeignKeyConstraint(
            ["aid"], ["a_things_with_stuff.id_long_column_name"]
        )
        Table(
            "b_related_things_of_value",
            metadata,
            Column(
                "aid",
            ),
            cons,
            test_needs_fk=True,
        )
        actual_name = cons.name
        metadata.create_all(connection)
        if testing.requires.foreign_key_constraint_name_reflection.enabled:
            insp = inspect(connection)
            fks = insp.get_foreign_keys("b_related_things_of_value")
            reflected_name = fks[0]["name"]
            return actual_name, reflected_name
        else:
            # backend can't reflect FK constraint names; signal "no name"
            return actual_name, None
    def pk(self, metadata, connection):
        """Primary key with a convention-generated name well over 255
        characters."""
        convention = {
            "pk": "primary_key_%(table_name)s_"
            "%(column_0_N_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention
        a = Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("id_another_long_name", Integer, primary_key=True),
        )
        cons = a.primary_key
        actual_name = cons.name
        metadata.create_all(connection)
        insp = inspect(connection)
        pk = insp.get_pk_constraint("a_things_with_stuff")
        reflected_name = pk["name"]
        return actual_name, reflected_name
    def ix(self, metadata, connection):
        """Index with a convention-generated name well over 255
        characters."""
        convention = {
            "ix": "index_%(table_name)s_"
            "%(column_0_N_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention
        a = Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("id_another_long_name", Integer),
        )
        # name=None so the convention generates the index name
        cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
        actual_name = cons.name
        metadata.create_all(connection)
        insp = inspect(connection)
        ix = insp.get_indexes("a_things_with_stuff")
        reflected_name = ix[0]["name"]
        return actual_name, reflected_name
    def uq(self, metadata, connection):
        """Unique constraint with a convention-generated name well over
        255 characters."""
        convention = {
            "uq": "unique_constraint_%(table_name)s_"
            "%(column_0_N_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention
        cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
        Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("id_another_long_name", Integer),
            cons,
        )
        actual_name = cons.name
        metadata.create_all(connection)
        insp = inspect(connection)
        uq = insp.get_unique_constraints("a_things_with_stuff")
        reflected_name = uq[0]["name"]
        return actual_name, reflected_name
    def ck(self, metadata, connection):
        """Check constraint with a convention-generated name well over
        255 characters."""
        convention = {
            "ck": "check_constraint_%(table_name)s"
            + (
                "_".join(
                    "".join(random.choice("abcdef") for j in range(30))
                    for i in range(10)
                )
            ),
        }
        metadata.naming_convention = convention
        cons = CheckConstraint("some_long_column_name > 5")
        Table(
            "a_things_with_stuff",
            metadata,
            Column("id_long_column_name", Integer, primary_key=True),
            Column("some_long_column_name", Integer),
            cons,
        )
        actual_name = cons.name
        metadata.create_all(connection)
        insp = inspect(connection)
        ck = insp.get_check_constraints("a_things_with_stuff")
        reflected_name = ck[0]["name"]
        return actual_name, reflected_name
    @testing.combinations(
        ("fk",),
        ("pk",),
        ("ix",),
        ("ck", testing.requires.check_constraint_reflection.as_skips()),
        ("uq", testing.requires.unique_constraint_reflection.as_skips()),
        argnames="type_",
    )
    def test_long_convention_name(self, type_, metadata, connection):
        """Create the structure via the helper named by ``type_`` and
        compare the generated name against what the backend reflects."""
        actual_name, reflected_name = getattr(self, type_)(
            metadata, connection
        )
        assert len(actual_name) > 255
        if reflected_name is not None:
            overlap = actual_name[0 : len(reflected_name)]
            if len(overlap) < len(actual_name):
                # name was truncated by the backend; compare the common
                # prefix, allowing the final five characters of the
                # reflected name to differ from the original
                eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
            else:
                eq_(overlap, reflected_name)
# explicit public API of this module for ``from ... import *``
__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")

View File

@@ -0,0 +1,153 @@
# testing/suite/test_deprecations.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import Integer
from ... import select
from ... import testing
from ... import union
class DeprecatedCompoundSelectTest(fixtures.TablesTest):
    """Round trips for the deprecated ``SelectBase.c`` /
    ``SelectBase.columns`` access on compound (UNION) selects; every
    test expects the deprecation warning to be emitted."""

    __backend__ = True
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("x", Integer),
            Column("y", Integer),
        )
    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "x": 1, "y": 2},
                {"id": 2, "x": 2, "y": 3},
                {"id": 3, "x": 3, "y": 4},
                {"id": 4, "x": 4, "y": 5},
            ],
        )
    def _assert_result(self, conn, select, result, params=()):
        # NOTE: the ``select`` parameter shadows the imported ``select()``
        # construct within this helper; it receives a full statement.
        eq_(conn.execute(select, params).fetchall(), result)
    def test_plain_union(self, connection):
        """ORDER BY via ``u1.c.id`` on a plain UNION warns but works."""
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2)
        s2 = select(table).where(table.c.id == 3)
        u1 = union(s1, s2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )
    # note we've had to remove one use case entirely, which is this
    # one. the Select gets its FROMS from the WHERE clause and the
    # columns clause, but not the ORDER BY, which means the old ".c" system
    # allowed you to "order_by(s.c.foo)" to get an unnamed column in the
    # ORDER BY without adding the SELECT into the FROM and breaking the
    # query. Users will have to adjust for this use case if they were doing
    # it before.
    def _dont_test_select_from_plain_union(self, connection):
        # intentionally disabled; see the note above
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2)
        s2 = select(table).where(table.c.id == 3)
        u1 = union(s1, s2).alias().select()
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )
    @testing.requires.order_by_col_from_union
    @testing.requires.parens_in_union_contained_select_w_limit_offset
    def test_limit_offset_selectable_in_unions(self, connection):
        """UNION of two LIMITed, ordered selects, ordered via ``u1.c``."""
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
        s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )
    @testing.requires.parens_in_union_contained_select_wo_limit_offset
    def test_order_by_selectable_in_unions(self, connection):
        """UNION of two ordered (but un-LIMITed) selects."""
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
        s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )
    def test_distinct_selectable_in_unions(self, connection):
        """UNION of two DISTINCT selects."""
        table = self.tables.some_table
        s1 = select(table).where(table.c.id == 2).distinct()
        s2 = select(table).where(table.c.id == 3).distinct()
        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )
    def test_limit_offset_aliased_selectable_in_unions(self, connection):
        """UNION of two LIMITed selects wrapped via alias().select()."""
        table = self.tables.some_table
        s1 = (
            select(table)
            .where(table.c.id == 2)
            .limit(1)
            .order_by(table.c.id)
            .alias()
            .select()
        )
        s2 = (
            select(table)
            .where(table.c.id == 3)
            .limit(1)
            .order_by(table.c.id)
            .alias()
            .select()
        )
        u1 = union(s1, s2).limit(2)
        with testing.expect_deprecated(
            "The SelectBase.c and SelectBase.columns "
            "attributes are deprecated"
        ):
            self._assert_result(
                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
            )

View File

@@ -0,0 +1,740 @@
# testing/suite/test_dialect.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import importlib
from . import testing
from .. import assert_raises
from .. import config
from .. import engines
from .. import eq_
from .. import fixtures
from .. import is_not_none
from .. import is_true
from .. import ne_
from .. import provide_metadata
from ..assertions import expect_raises
from ..assertions import expect_raises_message
from ..config import requirements
from ..provision import set_default_schema_on_connection
from ..schema import Column
from ..schema import Table
from ... import bindparam
from ... import dialects
from ... import event
from ... import exc
from ... import Integer
from ... import literal_column
from ... import select
from ... import String
from ...sql.compiler import Compiled
from ...util import inspect_getfullargspec
class PingTest(fixtures.TestBase):
    """Exercise the dialect-level ``do_ping()`` liveness check."""

    __backend__ = True

    def test_do_ping(self):
        with testing.db.connect() as conn:
            dbapi_conn = conn.connection.dbapi_connection
            is_true(testing.db.dialect.do_ping(dbapi_conn))
class ArgSignatureTest(fixtures.TestBase):
    """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have
    ``**kw``, for #8988.
    This test uses runtime code inspection. Does not need to be a
    ``__backend__`` test as it only needs to run once provided all target
    dialects have been imported.
    For third party dialects, the suite would be run with that third
    party as a "--dburi", which means its compiler classes will have been
    imported by the time this test runs.
    """

    def _all_subclasses(): # type: ignore # noqa
        # deliberately defined without ``self``: this generator is invoked
        # once at class-body evaluation time (in ``params=`` below), not on
        # instances. Importing every builtin dialect first ensures their
        # Compiled subclasses exist before the hierarchy is walked.
        for d in dialects.__all__:
            if not d.startswith("_"):
                importlib.import_module("sqlalchemy.dialects.%s" % d)
        # breadth-first walk of the full Compiled subclass tree
        stack = [Compiled]
        while stack:
            cls = stack.pop(0)
            stack.extend(cls.__subclasses__())
            yield cls
    @testing.fixture(params=list(_all_subclasses()))
    def all_subclasses(self, request):
        # parametrized fixture: one test run per Compiled subclass
        yield request.param
    def test_all_visit_methods_accept_kw(self, all_subclasses):
        """Each ``visit_*`` method defined directly on the class must
        declare a ``**kw`` catch-all."""
        cls = all_subclasses
        for k in cls.__dict__:
            if k.startswith("visit_"):
                meth = getattr(cls, k)
                insp = inspect_getfullargspec(meth)
                is_not_none(
                    insp.varkw,
                    f"Compiler visit method {cls.__name__}.{k}() does "
                    "not accommodate for **kw in its argument signature",
                )
class ExceptionTest(fixtures.TablesTest):
    """Test basic exception wrapping.
    DBAPIs vary a lot in exception behavior so to actually anticipate
    specific exceptions from real round trips, we need to be conservative.
    """

    run_deletes = "each"
    __backend__ = True
    @classmethod
    def define_tables(cls, metadata):
        # manual (non-autoincrement) PK so duplicate ids can be forced
        Table(
            "manual_pk",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )
    @requirements.duplicate_key_raises_integrity_error
    def test_integrity_error(self):
        """A duplicate-PK insert raises ``exc.IntegrityError``."""
        with config.db.connect() as conn:
            trans = conn.begin()
            conn.execute(
                self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
            )
            assert_raises(
                exc.IntegrityError,
                conn.execute,
                self.tables.manual_pk.insert(),
                {"id": 1, "data": "d1"},
            )
            trans.rollback()
    def test_exception_with_non_ascii(self):
        """A DBAPI error whose message may contain non-ASCII text must
        still stringify cleanly through ``exc.DBAPIError``."""
        with config.db.connect() as conn:
            try:
                # try to create an error message that likely has non-ascii
                # characters in the DBAPI's message string. unfortunately
                # there's no way to make this happen with some drivers like
                # mysqlclient, pymysql. this at least does produce a non-
                # ascii error message for cx_oracle, psycopg2
                conn.execute(select(literal_column("méil")))
                assert False
            except exc.DBAPIError as err:
                err_str = str(err)
                assert str(err.orig) in str(err)
                assert isinstance(err_str, str)
class IsolationLevelTest(fixtures.TestBase):
    """Tests for getting/setting transaction isolation levels on a live
    connection; requires the ``isolation_level`` capability."""

    __backend__ = True
    __requires__ = ("isolation_level",)
    def _get_non_default_isolation_level(self):
        """Return a supported level other than the default and AUTOCOMMIT,
        or skip the test if none exists."""
        levels = requirements.get_isolation_levels(config)
        default = levels["default"]
        supported = levels["supported"]
        s = set(supported).difference(["AUTOCOMMIT", default])
        if s:
            return s.pop()
        else:
            config.skip_test("no non-default isolation level available")
    def test_default_isolation_level(self):
        """The dialect reports the backend's documented default level."""
        eq_(
            config.db.dialect.default_isolation_level,
            requirements.get_isolation_levels(config)["default"],
        )
    def test_non_default_isolation_level(self):
        """Setting a non-default level takes effect; reset restores the
        original level."""
        non_default = self._get_non_default_isolation_level()
        with config.db.connect() as conn:
            existing = conn.get_isolation_level()
            ne_(existing, non_default)
            conn.execution_options(isolation_level=non_default)
            eq_(conn.get_isolation_level(), non_default)
            conn.dialect.reset_isolation_level(
                conn.connection.dbapi_connection
            )
            eq_(conn.get_isolation_level(), existing)
    def test_all_levels(self):
        """Every supported level (except AUTOCOMMIT) can be set and
        survives a begin/rollback cycle."""
        levels = requirements.get_isolation_levels(config)
        all_levels = levels["supported"]
        for level in set(all_levels).difference(["AUTOCOMMIT"]):
            with config.db.connect() as conn:
                conn.execution_options(isolation_level=level)
                eq_(conn.get_isolation_level(), level)
                trans = conn.begin()
                trans.rollback()
                eq_(conn.get_isolation_level(), level)
            # a fresh connection reverts to the default level
            with config.db.connect() as conn:
                eq_(
                    conn.get_isolation_level(),
                    levels["default"],
                )
    @testing.requires.get_isolation_level_values
    def test_invalid_level_execution_option(self, connection_no_trans):
        """test for the new get_isolation_level_values() method"""
        connection = connection_no_trans
        with expect_raises_message(
            exc.ArgumentError,
            "Invalid value '%s' for isolation_level. "
            "Valid isolation levels for '%s' are %s"
            % (
                "FOO",
                connection.dialect.name,
                ", ".join(
                    requirements.get_isolation_levels(config)["supported"]
                ),
            ),
        ):
            connection.execution_options(isolation_level="FOO")
    @testing.requires.get_isolation_level_values
    @testing.requires.dialect_level_isolation_level_param
    def test_invalid_level_engine_param(self, testing_engine):
        """test for the new get_isolation_level_values() method
        and support for the dialect-level 'isolation_level' parameter.
        """
        eng = testing_engine(options=dict(isolation_level="FOO"))
        with expect_raises_message(
            exc.ArgumentError,
            "Invalid value '%s' for isolation_level. "
            "Valid isolation levels for '%s' are %s"
            % (
                "FOO",
                eng.dialect.name,
                ", ".join(
                    requirements.get_isolation_levels(config)["supported"]
                ),
            ),
        ):
            eng.connect()
    @testing.requires.independent_readonly_connections
    def test_dialect_user_setting_is_restored(self, testing_engine):
        """A per-connection execution_options() override does not leak
        into subsequent connections from the same engine."""
        levels = requirements.get_isolation_levels(config)
        default = levels["default"]
        # pick a deterministic non-default, non-AUTOCOMMIT level
        supported = (
            sorted(
                set(levels["supported"]).difference([default, "AUTOCOMMIT"])
            )
        )[0]
        e = testing_engine(options={"isolation_level": supported})
        with e.connect() as conn:
            eq_(conn.get_isolation_level(), supported)
        with e.connect() as conn:
            conn.execution_options(isolation_level=default)
            eq_(conn.get_isolation_level(), default)
        with e.connect() as conn:
            eq_(conn.get_isolation_level(), supported)
class AutocommitIsolationTest(fixtures.TablesTest):
    """Tests for the AUTOCOMMIT isolation level: writes performed inside
    a rolled-back transaction persist only when autocommit is active."""

    run_deletes = "each"
    __requires__ = ("autocommit",)
    __backend__ = True
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            test_needs_acid=True,
        )
    def _test_conn_autocommits(self, conn, autocommit):
        """Insert a row, roll back, then check whether the row survived:
        it does iff ``autocommit`` is in effect on ``conn``."""
        trans = conn.begin()
        conn.execute(
            self.tables.some_table.insert(), {"id": 1, "data": "some data"}
        )
        trans.rollback()
        # scalar() returns None when the rollback removed the row
        eq_(
            conn.scalar(select(self.tables.some_table.c.id)),
            1 if autocommit else None,
        )
        # clear any implicit transaction state before cleaning up
        conn.rollback()
        with conn.begin():
            conn.execute(self.tables.some_table.delete())
    def test_autocommit_on(self, connection_no_trans):
        """AUTOCOMMIT set via execution_options applies; resetting the
        isolation level restores normal transactional behavior."""
        conn = connection_no_trans
        c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(c2, True)
        c2.dialect.reset_isolation_level(c2.connection.dbapi_connection)
        self._test_conn_autocommits(conn, False)
    def test_autocommit_off(self, connection_no_trans):
        """Without AUTOCOMMIT, the rolled-back insert does not persist."""
        conn = connection_no_trans
        self._test_conn_autocommits(conn, False)
    def test_turn_autocommit_off_via_default_iso_level(
        self, connection_no_trans
    ):
        """Switching back to the backend's default level turns
        autocommit behavior off again."""
        conn = connection_no_trans
        conn = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(conn, True)
        conn.execution_options(
            isolation_level=requirements.get_isolation_levels(config)[
                "default"
            ]
        )
        self._test_conn_autocommits(conn, False)
    @testing.requires.independent_readonly_connections
    @testing.variation("use_dialect_setting", [True, False])
    def test_dialect_autocommit_is_restored(
        self, testing_engine, use_dialect_setting
    ):
        """test #10147"""
        # AUTOCOMMIT configured either at the dialect level or via
        # engine-level execution_options; a per-connection override back
        # to the default must not leak into later connections
        if use_dialect_setting:
            e = testing_engine(options={"isolation_level": "AUTOCOMMIT"})
        else:
            e = testing_engine().execution_options(
                isolation_level="AUTOCOMMIT"
            )
        levels = requirements.get_isolation_levels(config)
        default = levels["default"]
        with e.connect() as conn:
            self._test_conn_autocommits(conn, True)
        with e.connect() as conn:
            conn.execution_options(isolation_level=default)
            self._test_conn_autocommits(conn, False)
        with e.connect() as conn:
            self._test_conn_autocommits(conn, True)
class EscapingTest(fixtures.TestBase):
    """Round trip of literal percent signs through insert and select."""

    @provide_metadata
    def test_percent_sign_round_trip(self):
        """test that the DBAPI accommodates for escaped / nonescaped
        percent signs in a way that matches the compiler
        """
        table = Table("t", self.metadata, Column("data", String(50)))
        table.create(config.db)
        values = ["some % value", "some %% other value"]
        with config.db.begin() as conn:
            # insert both rows first, then select each one back by
            # matching against an inline SQL literal
            for value in values:
                conn.execute(table.insert(), {"data": value})
            for value in values:
                eq_(
                    conn.scalar(
                        select(table.c.data).where(
                            table.c.data == literal_column("'" + value + "'")
                        )
                    ),
                    value,
                )
class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
    """Tests for switching the connection's default schema from a
    "connect" event handler and its effect on the dialect's cached
    ``default_schema_name``."""

    __backend__ = True
    __requires__ = ("default_schema_name_switch",)
    def test_control_case(self):
        """With no event handler, the dialect's default schema name is
        unchanged after connecting."""
        default_schema_name = config.db.dialect.default_schema_name
        eng = engines.testing_engine()
        with eng.connect():
            pass
        eq_(eng.dialect.default_schema_name, default_schema_name)
    def test_wont_work_wo_insert(self):
        """A "connect" handler appended normally (no ``insert=True``) does
        switch the connection's schema, but runs too late to change the
        dialect's cached ``default_schema_name``."""
        default_schema_name = config.db.dialect.default_schema_name
        eng = engines.testing_engine()
        @event.listens_for(eng, "connect")
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )
        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)
        eq_(eng.dialect.default_schema_name, default_schema_name)
    def test_schema_change_on_connect(self):
        """With ``insert=True`` the handler runs first, so the dialect's
        cached ``default_schema_name`` picks up the switched schema."""
        eng = engines.testing_engine()
        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )
        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)
        eq_(eng.dialect.default_schema_name, config.test_schema)
    def test_schema_change_works_w_transactions(self):
        """The switched schema remains in effect across a begin/rollback
        cycle on the same connection."""
        eng = engines.testing_engine()
        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, *arg):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )
        with eng.connect() as conn:
            trans = conn.begin()
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)
            trans.rollback()
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)
        eq_(eng.dialect.default_schema_name, config.test_schema)
class FutureWeCanSetDefaultSchemaWEventsTest(
    fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
):
    """Re-run ``WeCanSetDefaultSchemaWEventsTest`` with the future-engine
    fixture mixed in; no additional tests of its own."""

    pass
class DifficultParametersTest(fixtures.TestBase):
    """Round trips for column and bound-parameter names containing
    characters that stress each dialect's identifier / parameter
    escaping (percent signs, colons, parens, spaces, slashes, etc.)."""

    __backend__ = True
    # each entry is one troublesome name, parametrized as ``paramname``
    tough_parameters = testing.combinations(
        ("boring",),
        ("per cent",),
        ("per % cent",),
        ("%percent",),
        ("par(ens)",),
        ("percent%(ens)yah",),
        ("col:ons",),
        ("_starts_with_underscore",),
        ("dot.s",),
        ("more :: %colons%",),
        ("_name",),
        ("___name",),
        ("[BracketsAndCase]",),
        ("42numbers",),
        ("percent%signs",),
        ("has spaces",),
        ("/slashes/",),
        ("more/slashes",),
        ("q?marks",),
        ("1param",),
        ("1col:on",),
        argnames="paramname",
    )
    @tough_parameters
    @config.requirements.unusual_column_name_characters
    def test_round_trip_same_named_column(
        self, paramname, connection, metadata
    ):
        """Use the troublesome name as a column name in DDL, DML,
        criteria, an explicit bindparam, and a result-row key."""
        name = paramname
        t = Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True),
            Column(name, String(50), nullable=False),
        )
        # table is created
        t.create(connection)
        # automatic param generated by insert
        connection.execute(t.insert().values({"id": 1, name: "some name"}))
        # automatic param generated by criteria, plus selecting the column
        stmt = select(t.c[name]).where(t.c[name] == "some name")
        eq_(connection.scalar(stmt), "some name")
        # use the name in a param explicitly
        stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
        row = connection.execute(stmt, {name: "some name"}).first()
        # name works as the key from cursor.description
        eq_(row._mapping[name], "some name")
        # use expanding IN
        stmt = select(t.c[name]).where(
            t.c[name].in_(["some name", "some other_name"])
        )
        row = connection.execute(stmt).first()
    @testing.fixture
    def multirow_fixture(self, metadata, connection):
        """Create and populate a four-row table for the bindparam tests."""
        mytable = Table(
            "mytable",
            metadata,
            Column("myid", Integer),
            Column("name", String(50)),
            Column("desc", String(50)),
        )
        mytable.create(connection)
        connection.execute(
            mytable.insert(),
            [
                {"myid": 1, "name": "a", "desc": "a_desc"},
                {"myid": 2, "name": "b", "desc": "b_desc"},
                {"myid": 3, "name": "c", "desc": "c_desc"},
                {"myid": 4, "name": "d", "desc": "d_desc"},
            ],
        )
        yield mytable
    @tough_parameters
    def test_standalone_bindparam_escape(
        self, paramname, connection, multirow_fixture
    ):
        """The troublesome name as a standalone bindparam key; the
        execute-time value overrides the bindparam default."""
        tbl1 = multirow_fixture
        stmt = select(tbl1.c.myid).where(
            tbl1.c.name == bindparam(paramname, value="x")
        )
        res = connection.scalar(stmt, {paramname: "c"})
        eq_(res, 3)
    @tough_parameters
    def test_standalone_bindparam_escape_expanding(
        self, paramname, connection, multirow_fixture
    ):
        """The troublesome name as an expanding-IN bindparam key."""
        tbl1 = multirow_fixture
        stmt = (
            select(tbl1.c.myid)
            .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"])))
            .order_by(tbl1.c.myid)
        )
        res = connection.scalars(stmt, {paramname: ["d", "a"]}).all()
        eq_(res, [1, 4])
class ReturningGuardsTest(fixtures.TablesTest):
    """test that the various 'returning' flags are set appropriately"""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # autoincrement=False: the tests supply explicit ids so they can
        # compare them against what RETURNING hands back
        Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )

    @testing.fixture
    def run_stmt(self, connection):
        t = self.tables.t

        def go(stmt, executemany, id_param_name, expect_success):
            """Add RETURNING to ``stmt``, execute it, and check the outcome.

            :param stmt: DML statement against table ``t``
            :param executemany: if True, execute with a list of three
             parameter sets, else a single parameter set
            :param id_param_name: name of the bound parameter carrying the id
            :param expect_success: the dialect capability flag under test;
             when False an error is expected instead of returned rows

            """
            stmt = stmt.returning(t.c.id)

            if executemany:
                if not expect_success:
                    # for RETURNING executemany(), we raise our own
                    # error as this is independent of general RETURNING
                    # support
                    with expect_raises_message(
                        exc.StatementError,
                        rf"Dialect {connection.dialect.name}\+"
                        f"{connection.dialect.driver} with "
                        f"current server capabilities does not support "
                        f".*RETURNING when executemany is used",
                    ):
                        result = connection.execute(
                            stmt,
                            [
                                {id_param_name: 1, "data": "d1"},
                                {id_param_name: 2, "data": "d2"},
                                {id_param_name: 3, "data": "d3"},
                            ],
                        )
                else:
                    result = connection.execute(
                        stmt,
                        [
                            {id_param_name: 1, "data": "d1"},
                            {id_param_name: 2, "data": "d2"},
                            {id_param_name: 3, "data": "d3"},
                        ],
                    )
                    eq_(result.all(), [(1,), (2,), (3,)])
            else:
                if not expect_success:
                    # for RETURNING execute(), we pass all the way to the DB
                    # and let it fail
                    with expect_raises(exc.DBAPIError):
                        connection.execute(
                            stmt, {id_param_name: 1, "data": "d1"}
                        )
                else:
                    result = connection.execute(
                        stmt, {id_param_name: 1, "data": "d1"}
                    )
                    eq_(result.all(), [(1,)])

        return go

    def test_insert_single(self, connection, run_stmt):
        t = self.tables.t

        stmt = t.insert()

        run_stmt(stmt, False, "id", connection.dialect.insert_returning)

    def test_insert_many(self, connection, run_stmt):
        t = self.tables.t

        stmt = t.insert()

        run_stmt(
            stmt, True, "id", connection.dialect.insert_executemany_returning
        )

    def test_update_single(self, connection, run_stmt):
        t = self.tables.t

        # seed rows for the UPDATE to match
        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.update().where(t.c.id == bindparam("b_id"))

        run_stmt(stmt, False, "b_id", connection.dialect.update_returning)

    def test_update_many(self, connection, run_stmt):
        t = self.tables.t

        # seed rows for the UPDATE to match
        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.update().where(t.c.id == bindparam("b_id"))

        run_stmt(
            stmt, True, "b_id", connection.dialect.update_executemany_returning
        )

    def test_delete_single(self, connection, run_stmt):
        t = self.tables.t

        # seed rows for the DELETE to match
        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.delete().where(t.c.id == bindparam("b_id"))

        run_stmt(stmt, False, "b_id", connection.dialect.delete_returning)

    def test_delete_many(self, connection, run_stmt):
        t = self.tables.t

        # seed rows for the DELETE to match
        connection.execute(
            t.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

        stmt = t.delete().where(t.c.id == bindparam("b_id"))

        run_stmt(
            stmt, True, "b_id", connection.dialect.delete_executemany_returning
        )

View File

@ -0,0 +1,630 @@
# testing/suite/test_insert.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from decimal import Decimal
import uuid
from . import testing
from .. import fixtures
from ..assertions import eq_
from ..config import requirements
from ..schema import Column
from ..schema import Table
from ... import Double
from ... import Float
from ... import Identity
from ... import Integer
from ... import literal
from ... import literal_column
from ... import Numeric
from ... import select
from ... import String
from ...types import LargeBinary
from ...types import UUID
from ...types import Uuid
class LastrowidTest(fixtures.TablesTest):
    """Round trips for dialects that retrieve new primary keys via the
    DBAPI ``cursor.lastrowid`` mechanism (implicit RETURNING disabled)."""

    run_deletes = "each"

    __backend__ = True

    __requires__ = "implements_get_lastrowid", "autoincrement_insert"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "autoinc_pk",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
            implicit_returning=False,
        )

        Table(
            "manual_pk",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            implicit_returning=False,
        )

    def _assert_round_trip(self, table, conn):
        # a single row should exist with the dialect's starting sequence value
        fetched = conn.execute(table.select()).first()
        expected = (conn.dialect.default_sequence_base, "some data")
        eq_(fetched, expected)

    def test_autoincrement_on_insert(self, connection):
        stmt = self.tables.autoinc_pk.insert()
        connection.execute(stmt, {"data": "some data"})
        self._assert_round_trip(self.tables.autoinc_pk, connection)

    def test_last_inserted_id(self, connection):
        result = connection.execute(
            self.tables.autoinc_pk.insert(), {"data": "some data"}
        )
        fetched_pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
        eq_(result.inserted_primary_key, (fetched_pk,))

    @requirements.dbapi_lastrowid
    def test_native_lastrowid_autoinc(self, connection):
        # the raw DBAPI lastrowid should agree with the stored pk
        result = connection.execute(
            self.tables.autoinc_pk.insert(), {"data": "some data"}
        )
        fetched_pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
        eq_(result.lastrowid, fetched_pk)
class InsertBehaviorTest(fixtures.TablesTest):
    """Assorted INSERT behaviors: result-object state after execute,
    empty inserts, and INSERT ... FROM SELECT variants."""

    run_deletes = "each"
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # surrogate-pk table with implicit returning left enabled
        Table(
            "autoinc_pk",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )
        # explicit-pk table used as the source for INSERT..FROM SELECT
        Table(
            "manual_pk",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )
        # same as autoinc_pk but with implicit RETURNING switched off
        Table(
            "no_implicit_returning",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
            implicit_returning=False,
        )
        # table with both a plain scalar default and a SQL-expression default
        Table(
            "includes_defaults",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
            Column("x", Integer, default=5),
            Column(
                "y",
                Integer,
                default=literal_column("2", type_=Integer) + literal(2),
            ),
        )

    @testing.variation("style", ["plain", "return_defaults"])
    @testing.variation("executemany", [True, False])
    def test_no_results_for_non_returning_insert(
        self, connection, style, executemany
    ):
        """test another INSERT issue found during #10453"""

        table = self.tables.no_implicit_returning

        stmt = table.insert()
        if style.return_defaults:
            stmt = stmt.return_defaults()

        if executemany:
            data = [
                {"data": "d1"},
                {"data": "d2"},
                {"data": "d3"},
                {"data": "d4"},
                {"data": "d5"},
            ]
        else:
            data = {"data": "d1"}

        r = connection.execute(stmt, data)
        # with implicit returning disabled, no result rows may be produced
        assert not r.returns_rows

    @requirements.autoincrement_insert
    def test_autoclose_on_insert(self, connection):
        # a plain INSERT result should be soft-closed but not hard-closed
        r = connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        assert r._soft_closed
        assert not r.closed
        assert r.is_insert

        # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
        # an insert where the PK was taken from a row that the dialect
        # selected, as is the case for mssql/pyodbc, will still report
        # returns_rows as true because there's a cursor description. in that
        # case, the row had to have been consumed at least.
        assert not r.returns_rows or r.fetchone() is None

    @requirements.insert_returning
    def test_autoclose_on_insert_implicit_returning(self, connection):
        r = connection.execute(
            # return_defaults() ensures RETURNING will be used,
            # new in 2.0 as sqlite/mariadb offer both RETURNING and
            # cursor.lastrowid
            self.tables.autoinc_pk.insert().return_defaults(),
            dict(data="some data"),
        )
        assert r._soft_closed
        assert not r.closed
        assert r.is_insert

        # note we are experimenting with having this be True
        # as of I8091919d45421e3f53029b8660427f844fee0228 .
        # implicit returning has fetched the row, but it still is a
        # "returns rows"
        assert r.returns_rows

        # and we should be able to fetchone() on it, we just get no row
        eq_(r.fetchone(), None)

        # and the keys, etc.
        eq_(r.keys(), ["id"])

        # but the dialect took in the row already. not really sure
        # what the best behavior is.

    @requirements.empty_inserts
    def test_empty_insert(self, connection):
        # INSERT with no parameters at all; the row is built from defaults
        r = connection.execute(self.tables.autoinc_pk.insert())
        assert r._soft_closed
        assert not r.closed

        r = connection.execute(
            self.tables.autoinc_pk.select().where(
                self.tables.autoinc_pk.c.id != None
            )
        )
        eq_(len(r.all()), 1)

    @requirements.empty_inserts_executemany
    def test_empty_insert_multiple(self, connection):
        # executemany() flavor of the empty insert; three all-default rows
        r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
        assert r._soft_closed
        assert not r.closed

        r = connection.execute(
            self.tables.autoinc_pk.select().where(
                self.tables.autoinc_pk.c.id != None
            )
        )

        eq_(len(r.all()), 3)

    @requirements.insert_from_select
    def test_insert_from_select_autoinc(self, connection):
        # INSERT..FROM SELECT into an autoincrement table; no usable
        # inserted_primary_key is available in this case
        src_table = self.tables.manual_pk
        dest_table = self.tables.autoinc_pk
        connection.execute(
            src_table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ],
        )

        result = connection.execute(
            dest_table.insert().from_select(
                ("data",),
                select(src_table.c.data).where(
                    src_table.c.data.in_(["data2", "data3"])
                ),
            )
        )

        eq_(result.inserted_primary_key, (None,))

        result = connection.execute(
            select(dest_table.c.data).order_by(dest_table.c.data)
        )
        eq_(result.fetchall(), [("data2",), ("data3",)])

    @requirements.insert_from_select
    def test_insert_from_select_autoinc_no_rows(self, connection):
        # same as above but the SELECT matches nothing; zero rows inserted
        src_table = self.tables.manual_pk
        dest_table = self.tables.autoinc_pk

        result = connection.execute(
            dest_table.insert().from_select(
                ("data",),
                select(src_table.c.data).where(
                    src_table.c.data.in_(["data2", "data3"])
                ),
            )
        )
        eq_(result.inserted_primary_key, (None,))

        result = connection.execute(
            select(dest_table.c.data).order_by(dest_table.c.data)
        )

        eq_(result.fetchall(), [])

    @requirements.insert_from_select
    def test_insert_from_select(self, connection):
        # self-INSERT from a SELECT against the same table, ids offset by 5
        table = self.tables.manual_pk
        connection.execute(
            table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ],
        )

        connection.execute(
            table.insert()
            .inline()
            .from_select(
                ("id", "data"),
                select(table.c.id + 5, table.c.data).where(
                    table.c.data.in_(["data2", "data3"])
                ),
            )
        )

        eq_(
            connection.execute(
                select(table.c.data).order_by(table.c.data)
            ).fetchall(),
            [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
        )

    @requirements.insert_from_select
    def test_insert_from_select_with_defaults(self, connection):
        # like test_insert_from_select, but the target table has column
        # defaults (x=5, y=2+2) which must fire for the selected rows too
        table = self.tables.includes_defaults
        connection.execute(
            table.insert(),
            [
                dict(id=1, data="data1"),
                dict(id=2, data="data2"),
                dict(id=3, data="data3"),
            ],
        )

        connection.execute(
            table.insert()
            .inline()
            .from_select(
                ("id", "data"),
                select(table.c.id + 5, table.c.data).where(
                    table.c.data.in_(["data2", "data3"])
                ),
            )
        )

        eq_(
            connection.execute(
                select(table).order_by(table.c.data, table.c.id)
            ).fetchall(),
            [
                (1, "data1", 5, 4),
                (2, "data2", 5, 4),
                (7, "data2", 5, 4),
                (3, "data3", 5, 4),
                (8, "data3", 5, 4),
            ],
        )
class ReturningTest(fixtures.TablesTest):
    """INSERT..RETURNING round trips, explicit and implicit, including
    insertmanyvalues and datatype-specific RETURNING behavior."""

    run_create_tables = "each"
    __requires__ = "insert_returning", "autoincrement_insert"
    __backend__ = True

    def _assert_round_trip(self, table, conn):
        # a single row should exist with the dialect's starting sequence value
        row = conn.execute(table.select()).first()
        eq_(
            row,
            (
                conn.dialect.default_sequence_base,
                "some data",
            ),
        )

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "autoinc_pk",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

    @requirements.fetch_rows_post_commit
    def test_explicit_returning_pk_autocommit(self, connection):
        table = self.tables.autoinc_pk
        r = connection.execute(
            table.insert().returning(table.c.id), dict(data="some data")
        )
        pk = r.first()[0]
        fetched_pk = connection.scalar(select(table.c.id))
        eq_(fetched_pk, pk)

    def test_explicit_returning_pk_no_autocommit(self, connection):
        table = self.tables.autoinc_pk
        r = connection.execute(
            table.insert().returning(table.c.id), dict(data="some data")
        )
        pk = r.first()[0]
        fetched_pk = connection.scalar(select(table.c.id))
        eq_(fetched_pk, pk)

    def test_autoincrement_on_insert_implicit_returning(self, connection):
        connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        self._assert_round_trip(self.tables.autoinc_pk, connection)

    def test_last_inserted_id_implicit_returning(self, connection):
        r = connection.execute(
            self.tables.autoinc_pk.insert(), dict(data="some data")
        )
        pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
        eq_(r.inserted_primary_key, (pk,))

    @requirements.insert_executemany_returning
    def test_insertmanyvalues_returning(self, connection):
        # RETURNING ids from an executemany should match what's in the table
        r = connection.execute(
            self.tables.autoinc_pk.insert().returning(
                self.tables.autoinc_pk.c.id
            ),
            [
                {"data": "d1"},
                {"data": "d2"},
                {"data": "d3"},
                {"data": "d4"},
                {"data": "d5"},
            ],
        )
        rall = r.all()

        pks = connection.execute(select(self.tables.autoinc_pk.c.id))

        eq_(rall, pks.all())

    @testing.combinations(
        (Double(), 8.5514716, True),
        (
            Double(53),
            8.5514716,
            True,
            testing.requires.float_or_double_precision_behaves_generically,
        ),
        (Float(), 8.5514, True),
        (
            Float(8),
            8.5514,
            True,
            testing.requires.float_or_double_precision_behaves_generically,
        ),
        (
            Numeric(precision=15, scale=12, asdecimal=False),
            8.5514716,
            True,
            testing.requires.literal_float_coercion,
        ),
        (
            Numeric(precision=15, scale=12, asdecimal=True),
            Decimal("8.5514716"),
            False,
        ),
        argnames="type_,value,do_rounding",
    )
    @testing.variation("sort_by_parameter_order", [True, False])
    @testing.variation("multiple_rows", [True, False])
    def test_insert_w_floats(
        self,
        connection,
        metadata,
        sort_by_parameter_order,
        type_,
        value,
        do_rounding,
        multiple_rows,
    ):
        """test #9701.

        this tests insertmanyvalues as well as decimal / floating point
        RETURNING types

        """

        t = Table(
            # Oracle backends seems to be getting confused if
            # this table is named the same as the one
            # in test_imv_returning_datatypes.  use a different name
            "f_t",
            metadata,
            Column("id", Integer, Identity(), primary_key=True),
            Column("value", type_),
        )

        t.create(connection)

        result = connection.execute(
            t.insert().returning(
                t.c.id,
                t.c.value,
                sort_by_parameter_order=bool(sort_by_parameter_order),
            ),
            (
                [{"value": value} for i in range(10)]
                if multiple_rows
                else {"value": value}
            ),
        )

        if multiple_rows:
            i_range = range(1, 11)
        else:
            i_range = range(1, 2)

        # we want to test only that we are getting floating points back
        # with some degree of the original value maintained, that it is not
        # being truncated to an integer. there's too much variation in how
        # drivers return floats, which should not be relied upon to be
        # exact, for us to just compare as is (works for PG drivers but not
        # others) so we use rounding here.  There's precedent for this
        # in suite/test_types.py::NumericTest as well

        if do_rounding:
            eq_(
                {(id_, round(val_, 5)) for id_, val_ in result},
                {(id_, round(value, 5)) for id_ in i_range},
            )

            eq_(
                {
                    round(val_, 5)
                    for val_ in connection.scalars(select(t.c.value))
                },
                {round(value, 5)},
            )
        else:
            eq_(
                set(result),
                {(id_, value) for id_ in i_range},
            )

            eq_(
                set(connection.scalars(select(t.c.value))),
                {value},
            )

    @testing.combinations(
        (
            "non_native_uuid",
            Uuid(native_uuid=False),
            uuid.uuid4(),
        ),
        (
            "non_native_uuid_str",
            Uuid(as_uuid=False, native_uuid=False),
            str(uuid.uuid4()),
        ),
        (
            "generic_native_uuid",
            Uuid(native_uuid=True),
            uuid.uuid4(),
            testing.requires.uuid_data_type,
        ),
        (
            "generic_native_uuid_str",
            Uuid(as_uuid=False, native_uuid=True),
            str(uuid.uuid4()),
            testing.requires.uuid_data_type,
        ),
        ("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type),
        (
            "LargeBinary1",
            LargeBinary(),
            b"this is binary",
        ),
        ("LargeBinary2", LargeBinary(), b"7\xe7\x9f"),
        argnames="type_,value",
        id_="iaa",
    )
    @testing.variation("sort_by_parameter_order", [True, False])
    @testing.variation("multiple_rows", [True, False])
    @testing.requires.insert_returning
    def test_imv_returning_datatypes(
        self,
        connection,
        metadata,
        sort_by_parameter_order,
        type_,
        value,
        multiple_rows,
    ):
        """test #9739, #9808 (similar to #9701).

        this tests insertmanyvalues in conjunction with various datatypes.

        These tests are particularly for the asyncpg driver which needs
        most types to be explicitly cast for the new IMV format

        """
        t = Table(
            "d_t",
            metadata,
            Column("id", Integer, Identity(), primary_key=True),
            Column("value", type_),
        )

        t.create(connection)

        result = connection.execute(
            t.insert().returning(
                t.c.id,
                t.c.value,
                sort_by_parameter_order=bool(sort_by_parameter_order),
            ),
            (
                [{"value": value} for i in range(10)]
                if multiple_rows
                else {"value": value}
            ),
        )

        if multiple_rows:
            i_range = range(1, 11)
        else:
            i_range = range(1, 2)

        eq_(
            set(result),
            {(id_, value) for id_ in i_range},
        )

        eq_(
            set(connection.scalars(select(t.c.value))),
            {value},
        )
# names exported via ``from .test_insert import *`` in suite/__init__.py
__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,504 @@
# testing/suite/test_results.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import datetime
import re
from .. import engines
from .. import fixtures
from ..assertions import eq_
from ..config import requirements
from ..schema import Column
from ..schema import Table
from ... import DateTime
from ... import func
from ... import Integer
from ... import select
from ... import sql
from ... import String
from ... import testing
from ... import text
class RowFetchTest(fixtures.TablesTest):
    """Row access patterns: attribute, string key, integer index and
    Column-object key, plus duplicate names and scalar subqueries."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "plain_pk",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )
        Table(
            "has_dates",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("today", DateTime),
        )

    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.plain_pk.insert(),
            [{"id": num, "data": "d%d" % num} for num in (1, 2, 3)],
        )

        connection.execute(
            cls.tables.has_dates.insert(),
            [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
        )

    def _first_row(self, connection):
        # first row of plain_pk, ordered by primary key
        table = self.tables.plain_pk
        return connection.execute(
            table.select().order_by(table.c.id)
        ).first()

    def test_via_attr(self, connection):
        row = self._first_row(connection)

        eq_(row.id, 1)
        eq_(row.data, "d1")

    def test_via_string(self, connection):
        mapping = self._first_row(connection)._mapping

        eq_(mapping["id"], 1)
        eq_(mapping["data"], "d1")

    def test_via_int(self, connection):
        row = self._first_row(connection)

        eq_(row[0], 1)
        eq_(row[1], "d1")

    def test_via_col_object(self, connection):
        row = self._first_row(connection)
        cols = self.tables.plain_pk.c

        eq_(row._mapping[cols.id], 1)
        eq_(row._mapping[cols.data], "d1")

    @requirements.duplicate_names_in_cursor_description
    def test_row_with_dupe_names(self, connection):
        table = self.tables.plain_pk
        stmt = select(
            table.c.data,
            table.c.data.label("data"),
        ).order_by(table.c.id)

        result = connection.execute(stmt)
        row = result.first()
        eq_(result.keys(), ["data", "data"])
        eq_(row, ("d1", "d1"))

    def test_row_w_scalar_select(self, connection):
        """test that a scalar select as a column is returned as such
        and that type conversion works OK.

        (this is half a SQLAlchemy Core test and half to catch database
        backends that may have unusual behavior with scalar selects.)

        """
        datetable = self.tables.has_dates
        subq = select(datetable.alias("x").c.today).scalar_subquery()
        stmt = select(datetable.c.id, subq.label("somelabel"))
        row = connection.execute(stmt).first()
        eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
class PercentSchemaNamesTest(fixtures.TablesTest):
    """tests using percent signs, spaces in table and column names.

    This didn't work for PostgreSQL / MySQL drivers for a long time
    but is now supported.

    """

    __requires__ = ("percent_schema_names",)

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # a real Table plus a "lightweight" sql.table() construct with the
        # same percent/space identifiers
        cls.tables.percent_table = Table(
            "percent%table",
            metadata,
            Column("percent%", Integer),
            Column("spaces % more spaces", Integer),
        )
        cls.tables.lightweight_percent_table = sql.table(
            "percent%table",
            sql.column("percent%"),
            sql.column("spaces % more spaces"),
        )

    def test_single_roundtrip(self, connection):
        # one execute() call per row
        percent_table = self.tables.percent_table
        for params in [
            {"percent%": 5, "spaces % more spaces": 12},
            {"percent%": 7, "spaces % more spaces": 11},
            {"percent%": 9, "spaces % more spaces": 10},
            {"percent%": 11, "spaces % more spaces": 9},
        ]:
            connection.execute(percent_table.insert(), params)
        self._assert_table(connection)

    def test_executemany_roundtrip(self, connection):
        # same data as test_single_roundtrip via executemany()
        percent_table = self.tables.percent_table
        connection.execute(
            percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
        )
        connection.execute(
            percent_table.insert(),
            [
                {"percent%": 7, "spaces % more spaces": 11},
                {"percent%": 9, "spaces % more spaces": 10},
                {"percent%": 11, "spaces % more spaces": 9},
            ],
        )
        self._assert_table(connection)

    @requirements.insert_executemany_returning
    def test_executemany_returning_roundtrip(self, connection):
        # executemany() with RETURNING of the percent-named columns
        percent_table = self.tables.percent_table
        connection.execute(
            percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
        )
        result = connection.execute(
            percent_table.insert().returning(
                percent_table.c["percent%"],
                percent_table.c["spaces % more spaces"],
            ),
            [
                {"percent%": 7, "spaces % more spaces": 11},
                {"percent%": 9, "spaces % more spaces": 10},
                {"percent%": 11, "spaces % more spaces": 9},
            ],
        )
        eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
        self._assert_table(connection)

    def _assert_table(self, conn):
        """Verify SELECT, WHERE..IN, row-mapping access and UPDATE for the
        percent-named table, against the full Table, the lightweight
        construct, and aliases of each."""
        percent_table = self.tables.percent_table
        lightweight_percent_table = self.tables.lightweight_percent_table

        for table in (
            percent_table,
            percent_table.alias(),
            lightweight_percent_table,
            lightweight_percent_table.alias(),
        ):
            eq_(
                list(
                    conn.execute(table.select().order_by(table.c["percent%"]))
                ),
                [(5, 12), (7, 11), (9, 10), (11, 9)],
            )

            eq_(
                list(
                    conn.execute(
                        table.select()
                        .where(table.c["spaces % more spaces"].in_([9, 10]))
                        .order_by(table.c["percent%"])
                    )
                ),
                [(9, 10), (11, 9)],
            )

            row = conn.execute(
                table.select().order_by(table.c["percent%"])
            ).first()
            eq_(row._mapping["percent%"], 5)
            eq_(row._mapping["spaces % more spaces"], 12)

            eq_(row._mapping[table.c["percent%"]], 5)
            eq_(row._mapping[table.c["spaces % more spaces"]], 12)

        conn.execute(
            percent_table.update().values(
                {percent_table.c["spaces % more spaces"]: 15}
            )
        )

        eq_(
            list(
                conn.execute(
                    percent_table.select().order_by(
                        percent_table.c["percent%"]
                    )
                )
            ),
            [(5, 15), (7, 15), (9, 15), (11, 15)],
        )
class ServerSideCursorsTest(
    fixtures.TestBase, testing.AssertsExecutionResults
):
    """Tests for server-side (unbuffered / streaming) cursor selection via
    engine options, execution options, and statement options."""

    __requires__ = ("server_side_cursors",)

    __backend__ = True

    def _is_server_side(self, cursor):
        """Return True if ``cursor`` is a server-side cursor, using
        driver-specific introspection."""
        # TODO: this is a huge issue as it prevents these tests from being
        # usable by third party dialects.
        if self.engine.dialect.driver == "psycopg2":
            return bool(cursor.name)
        elif self.engine.dialect.driver == "pymysql":
            sscursor = __import__("pymysql.cursors").cursors.SSCursor
            return isinstance(cursor, sscursor)
        elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"):
            return cursor.server_side
        elif self.engine.dialect.driver == "mysqldb":
            sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
            return isinstance(cursor, sscursor)
        elif self.engine.dialect.driver == "mariadbconnector":
            return not cursor.buffered
        elif self.engine.dialect.driver == "mysqlconnector":
            return "buffered" not in type(cursor).__name__.lower()
        elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
            return cursor.server_side
        elif self.engine.dialect.driver == "pg8000":
            return getattr(cursor, "server_side", False)
        elif self.engine.dialect.driver == "psycopg":
            return bool(getattr(cursor, "name", False))
        elif self.engine.dialect.driver == "oracledb":
            return getattr(cursor, "server_side", False)
        else:
            return False

    def _fixture(self, server_side_cursors):
        """Create a testing engine; the deprecated engine-wide
        ``server_side_cursors`` flag warns when enabled."""
        if server_side_cursors:
            with testing.expect_deprecated(
                "The create_engine.server_side_cursors parameter is "
                "deprecated and will be removed in a future release.  "
                "Please use the Connection.execution_options.stream_results "
                "parameter."
            ):
                self.engine = engines.testing_engine(
                    options={"server_side_cursors": server_side_cursors}
                )
        else:
            self.engine = engines.testing_engine(
                options={"server_side_cursors": server_side_cursors}
            )
        return self.engine

    def stringify(self, str_):
        # replace "SELECT <n>" with the dialect-specific compiled form of
        # select(<n>) so raw strings work on backends needing e.g. FROM DUAL
        return re.compile(r"SELECT (\d+)", re.I).sub(
            lambda m: str(select(int(m.group(1))).compile(testing.db)), str_
        )

    @testing.combinations(
        ("global_string", True, lambda stringify: stringify("select 1"), True),
        (
            "global_text",
            True,
            lambda stringify: text(stringify("select 1")),
            True,
        ),
        ("global_expr", True, select(1), True),
        (
            "global_off_explicit",
            False,
            lambda stringify: text(stringify("select 1")),
            False,
        ),
        (
            "stmt_option",
            False,
            select(1).execution_options(stream_results=True),
            True,
        ),
        (
            "stmt_option_disabled",
            True,
            select(1).execution_options(stream_results=False),
            False,
        ),
        ("for_update_expr", True, select(1).with_for_update(), True),
        # TODO: need a real requirement for this, or dont use this test
        (
            "for_update_string",
            True,
            lambda stringify: stringify("SELECT 1 FOR UPDATE"),
            True,
            testing.skip_if(["sqlite", "mssql"]),
        ),
        (
            "text_no_ss",
            False,
            lambda stringify: text(stringify("select 42")),
            False,
        ),
        (
            "text_ss_option",
            False,
            lambda stringify: text(stringify("select 42")).execution_options(
                stream_results=True
            ),
            True,
        ),
        id_="iaaa",
        argnames="engine_ss_arg, statement, cursor_ss_status",
    )
    def test_ss_cursor_status(
        self, engine_ss_arg, statement, cursor_ss_status
    ):
        # each combination: engine-wide flag + statement form -> expected
        # server-side status of the resulting cursor
        engine = self._fixture(engine_ss_arg)
        with engine.begin() as conn:
            if callable(statement):
                statement = testing.resolve_lambda(
                    statement, stringify=self.stringify
                )

            if isinstance(statement, str):
                result = conn.exec_driver_sql(statement)
            else:
                result = conn.execute(statement)
            eq_(self._is_server_side(result.cursor), cursor_ss_status)
            result.close()

    def test_conn_option(self):
        engine = self._fixture(False)

        with engine.connect() as conn:
            # should be enabled for this one
            result = conn.execution_options(
                stream_results=True
            ).exec_driver_sql(self.stringify("select 1"))
            assert self._is_server_side(result.cursor)

            # the connection has autobegun, which means at the end of the
            # block, we will roll back, which on MySQL at least will fail
            # with "Commands out of sync" if the result set
            # is not closed, so we close it first.
            #
            # fun fact!  why did we not have this result.close() in this test
            # before 2.0? don't we roll back in the connection pool
            # unconditionally? yes!  and in fact if you run this test in 1.4
            # with stdout shown, there is in fact "Exception during reset or
            # similar" with "Commands out sync" emitted a warning!  2.0's
            # architecture finds and fixes what was previously an expensive
            # silent error condition.
            result.close()

    def test_stmt_enabled_conn_option_disabled(self):
        engine = self._fixture(False)

        s = select(1).execution_options(stream_results=True)

        with engine.connect() as conn:
            # not this one
            result = conn.execution_options(stream_results=False).execute(s)
            assert not self._is_server_side(result.cursor)

    def test_aliases_and_ss(self):
        engine = self._fixture(False)
        s1 = (
            select(sql.literal_column("1").label("x"))
            .execution_options(stream_results=True)
            .subquery()
        )

        # options don't propagate out when subquery is used as a FROM clause
        with engine.begin() as conn:
            result = conn.execute(s1.select())
            assert not self._is_server_side(result.cursor)
            result.close()

        s2 = select(1).select_from(s1)
        with engine.begin() as conn:
            result = conn.execute(s2)
            assert not self._is_server_side(result.cursor)
            result.close()

    def test_roundtrip_fetchall(self, metadata):
        # full CRUD round trip with the engine-wide server-side flag on
        md = self.metadata

        engine = self._fixture(True)
        test_table = Table(
            "test_table",
            md,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)
            connection.execute(test_table.insert(), dict(data="data1"))
            connection.execute(test_table.insert(), dict(data="data2"))
            eq_(
                connection.execute(
                    test_table.select().order_by(test_table.c.id)
                ).fetchall(),
                [(1, "data1"), (2, "data2")],
            )
            connection.execute(
                test_table.update()
                .where(test_table.c.id == 2)
                .values(data=test_table.c.data + " updated")
            )
            eq_(
                connection.execute(
                    test_table.select().order_by(test_table.c.id)
                ).fetchall(),
                [(1, "data1"), (2, "data2 updated")],
            )
            connection.execute(test_table.delete())
            eq_(
                connection.scalar(
                    select(func.count("*")).select_from(test_table)
                ),
                0,
            )

    def test_roundtrip_fetchmany(self, metadata):
        # incremental fetches from a streaming result
        md = self.metadata

        engine = self._fixture(True)
        test_table = Table(
            "test_table",
            md,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)
            connection.execute(
                test_table.insert(),
                [dict(data="data%d" % i) for i in range(1, 20)],
            )

            result = connection.execute(
                test_table.select().order_by(test_table.c.id)
            )

            eq_(
                result.fetchmany(5),
                [(i, "data%d" % i) for i in range(1, 6)],
            )
            eq_(
                result.fetchmany(10),
                [(i, "data%d" % i) for i in range(6, 16)],
            )
            eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])

View File

@ -0,0 +1,258 @@
# testing/suite/test_rowcount.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
class RowCountTest(fixtures.TablesTest):
"""test rowcount functionality"""
__requires__ = ("sane_rowcount",)
__backend__ = True
    @classmethod
    def define_tables(cls, metadata):
        # department is a single-character code ("A"/"B"/"C" in the fixture
        # data); employee_id is assigned explicitly, not autoincremented
        Table(
            "employees",
            metadata,
            Column(
                "employee_id",
                Integer,
                autoincrement=False,
                primary_key=True,
            ),
            Column("name", String(50)),
            Column("department", String(1)),
        )
@classmethod
def insert_data(cls, connection):
cls.data = data = [
("Angela", "A"),
("Andrew", "A"),
("Anand", "A"),
("Bob", "B"),
("Bobette", "B"),
("Buffy", "B"),
("Charlie", "C"),
("Cynthia", "C"),
("Chris", "C"),
]
employees_table = cls.tables.employees
connection.execute(
employees_table.insert(),
[
{"employee_id": i, "name": n, "department": d}
for i, (n, d) in enumerate(data)
],
)
def test_basic(self, connection):
employees_table = self.tables.employees
s = select(
employees_table.c.name, employees_table.c.department
).order_by(employees_table.c.employee_id)
rows = connection.execute(s).fetchall()
eq_(rows, self.data)
    @testing.variation("statement", ["update", "delete", "insert", "select"])
    @testing.variation("close_first", [True, False])
    def test_non_rowcount_scenarios_no_raise(
        self, connection, statement, close_first
    ):
        # accessing .rowcount should never raise, for any statement type,
        # even after the result is closed; -1 is the pep-249 "unknown" value
        employees_table = self.tables.employees

        # WHERE matches 3, 3 rows changed
        department = employees_table.c.department

        if statement.update:
            r = connection.execute(
                employees_table.update().where(department == "C"),
                {"department": "Z"},
            )
        elif statement.delete:
            r = connection.execute(
                employees_table.delete().where(department == "C"),
                {"department": "Z"},
            )
        elif statement.insert:
            r = connection.execute(
                employees_table.insert(),
                [
                    {"employee_id": 25, "name": "none 1", "department": "X"},
                    {"employee_id": 26, "name": "none 2", "department": "Z"},
                    {"employee_id": 27, "name": "none 3", "department": "Z"},
                ],
            )
        elif statement.select:
            s = select(
                employees_table.c.name, employees_table.c.department
            ).where(employees_table.c.department == "C")
            r = connection.execute(s)
            r.all()
        else:
            statement.fail()

        if close_first:
            r.close()

        # each branch affects/returns 3 rows; -1 is also acceptable
        assert r.rowcount in (-1, 3)
def test_update_rowcount1(self, connection):
employees_table = self.tables.employees
# WHERE matches 3, 3 rows changed
department = employees_table.c.department
r = connection.execute(
employees_table.update().where(department == "C"),
{"department": "Z"},
)
assert r.rowcount == 3
def test_update_rowcount2(self, connection):
employees_table = self.tables.employees
# WHERE matches 3, 0 rows changed
department = employees_table.c.department
r = connection.execute(
employees_table.update().where(department == "C"),
{"department": "C"},
)
eq_(r.rowcount, 3)
@testing.variation("implicit_returning", [True, False])
@testing.variation(
"dml",
[
("update", testing.requires.update_returning),
("delete", testing.requires.delete_returning),
],
)
def test_update_delete_rowcount_return_defaults(
self, connection, implicit_returning, dml
):
"""note this test should succeed for all RETURNING backends
as of 2.0. In
Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
len(rows) when we have implicit returning
"""
if implicit_returning:
employees_table = self.tables.employees
else:
employees_table = Table(
"employees",
MetaData(),
Column(
"employee_id",
Integer,
autoincrement=False,
primary_key=True,
),
Column("name", String(50)),
Column("department", String(1)),
implicit_returning=False,
)
department = employees_table.c.department
if dml.update:
stmt = (
employees_table.update()
.where(department == "C")
.values(name=employees_table.c.department + "Z")
.return_defaults()
)
elif dml.delete:
stmt = (
employees_table.delete()
.where(department == "C")
.return_defaults()
)
else:
dml.fail()
r = connection.execute(stmt)
eq_(r.rowcount, 3)
def test_raw_sql_rowcount(self, connection):
# test issue #3622, make sure eager rowcount is called for text
result = connection.exec_driver_sql(
"update employees set department='Z' where department='C'"
)
eq_(result.rowcount, 3)
def test_text_rowcount(self, connection):
# test issue #3622, make sure eager rowcount is called for text
result = connection.execute(
text("update employees set department='Z' where department='C'")
)
eq_(result.rowcount, 3)
def test_delete_rowcount(self, connection):
employees_table = self.tables.employees
# WHERE matches 3, 3 rows deleted
department = employees_table.c.department
r = connection.execute(
employees_table.delete().where(department == "C")
)
eq_(r.rowcount, 3)
@testing.requires.sane_multi_rowcount
def test_multi_update_rowcount(self, connection):
employees_table = self.tables.employees
stmt = (
employees_table.update()
.where(employees_table.c.name == bindparam("emp_name"))
.values(department="C")
)
r = connection.execute(
stmt,
[
{"emp_name": "Bob"},
{"emp_name": "Cynthia"},
{"emp_name": "nonexistent"},
],
)
eq_(r.rowcount, 2)
@testing.requires.sane_multi_rowcount
def test_multi_delete_rowcount(self, connection):
employees_table = self.tables.employees
stmt = employees_table.delete().where(
employees_table.c.name == bindparam("emp_name")
)
r = connection.execute(
stmt,
[
{"emp_name": "Bob"},
{"emp_name": "Cynthia"},
{"emp_name": "nonexistent"},
],
)
eq_(r.rowcount, 2)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,317 @@
# testing/suite/test_sequence.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import config
from .. import fixtures
from ..assertions import eq_
from ..assertions import is_true
from ..config import requirements
from ..provision import normalize_sequence
from ..schema import Column
from ..schema import Table
from ... import inspect
from ... import Integer
from ... import MetaData
from ... import Sequence
from ... import String
from ... import testing
class SequenceTest(fixtures.TablesTest):
    """Round-trip tests for sequence-driven primary keys."""

    __requires__ = ("sequences",)
    __backend__ = True

    # sequences hold state between tests; rebuild tables for each test
    run_create_tables = "each"

    @classmethod
    def define_tables(cls, metadata):
        # plain sequence-backed PK
        Table(
            "seq_pk",
            metadata,
            Column(
                "id",
                Integer,
                normalize_sequence(config, Sequence("tab_id_seq")),
                primary_key=True,
            ),
            Column("data", String(50)),
        )

        # optional=True: the sequence is only used if the backend can't
        # generate the value some other way
        Table(
            "seq_opt_pk",
            metadata,
            Column(
                "id",
                Integer,
                normalize_sequence(
                    config,
                    Sequence("tab_id_seq", data_type=Integer, optional=True),
                ),
                primary_key=True,
            ),
            Column("data", String(50)),
        )

        # sequence PK with RETURNING disabled
        Table(
            "seq_no_returning",
            metadata,
            Column(
                "id",
                Integer,
                normalize_sequence(config, Sequence("noret_id_seq")),
                primary_key=True,
            ),
            Column("data", String(50)),
            implicit_returning=False,
        )

        # schema-qualified variant, only where the backend supports schemas
        if testing.requires.schemas.enabled:
            Table(
                "seq_no_returning_sch",
                metadata,
                Column(
                    "id",
                    Integer,
                    normalize_sequence(
                        config,
                        Sequence(
                            "noret_sch_id_seq", schema=config.test_schema
                        ),
                    ),
                    primary_key=True,
                ),
                Column("data", String(50)),
                implicit_returning=False,
                schema=config.test_schema,
            )

    def test_insert_roundtrip(self, connection):
        connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
        self._assert_round_trip(self.tables.seq_pk, connection)

    def test_insert_lastrowid(self, connection):
        # the first sequence value is the dialect's default_sequence_base
        r = connection.execute(
            self.tables.seq_pk.insert(), dict(data="some data")
        )
        eq_(
            r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
        )

    def test_nextval_direct(self, connection):
        # executing the column default directly fires the sequence
        r = connection.scalar(self.tables.seq_pk.c.id.default)
        eq_(r, testing.db.dialect.default_sequence_base)

    @requirements.sequences_optional
    def test_optional_seq(self, connection):
        r = connection.execute(
            self.tables.seq_opt_pk.insert(), dict(data="some data")
        )
        eq_(r.inserted_primary_key, (1,))

    def _assert_round_trip(self, table, conn):
        # helper: the single inserted row came back with the first
        # sequence value as its id
        row = conn.execute(table.select()).first()
        eq_(row, (testing.db.dialect.default_sequence_base, "some data"))

    def test_insert_roundtrip_no_implicit_returning(self, connection):
        connection.execute(
            self.tables.seq_no_returning.insert(), dict(data="some data")
        )
        self._assert_round_trip(self.tables.seq_no_returning, connection)

    @testing.combinations((True,), (False,), argnames="implicit_returning")
    @testing.requires.schemas
    def test_insert_roundtrip_translate(self, connection, implicit_returning):
        # "alt_schema" is mapped onto the real test schema via
        # schema_translate_map; the sequence must be translated too
        seq_no_returning = Table(
            "seq_no_returning_sch",
            MetaData(),
            Column(
                "id",
                Integer,
                normalize_sequence(
                    config, Sequence("noret_sch_id_seq", schema="alt_schema")
                ),
                primary_key=True,
            ),
            Column("data", String(50)),
            implicit_returning=implicit_returning,
            schema="alt_schema",
        )

        connection = connection.execution_options(
            schema_translate_map={"alt_schema": config.test_schema}
        )

        connection.execute(seq_no_returning.insert(), dict(data="some data"))
        self._assert_round_trip(seq_no_returning, connection)

    @testing.requires.schemas
    def test_nextval_direct_schema_translate(self, connection):
        # standalone sequence execution also honors schema translation
        seq = normalize_sequence(
            config, Sequence("noret_sch_id_seq", schema="alt_schema")
        )
        connection = connection.execution_options(
            schema_translate_map={"alt_schema": config.test_schema}
        )

        r = connection.scalar(seq)
        eq_(r, testing.db.dialect.default_sequence_base)
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
    """Compilation checks for sequence expressions."""

    __requires__ = ("sequences",)
    __backend__ = True

    def test_literal_binds_inline_compile(self, connection):
        # With literal_binds, the sequence column renders the dialect's
        # "next value" SQL inline and the literal integer renders as-is.
        dialect = connection.dialect

        tbl = Table(
            "x",
            MetaData(),
            Column(
                "y", Integer, normalize_sequence(config, Sequence("y_seq"))
            ),
            Column("q", Integer),
        )

        # what does this dialect's compiler emit for "next value of y_seq"?
        compiler = dialect.statement_compiler(statement=None, dialect=dialect)
        nextval_sql = compiler.visit_sequence(
            normalize_sequence(config, Sequence("y_seq"))
        )

        insert_stmt = tbl.insert().values(q=5)
        expected = "INSERT INTO x (y, q) VALUES (%s, 5)" % (nextval_sql,)

        self.assert_compile(
            insert_stmt,
            expected,
            literal_binds=True,
            dialect=dialect,
        )
class HasSequenceTest(fixtures.TablesTest):
    """Reflection tests for has_sequence() / get_sequence_names()."""

    run_deletes = None
    __requires__ = ("sequences",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # default-schema sequences
        normalize_sequence(config, Sequence("user_id_seq", metadata=metadata))
        normalize_sequence(
            config,
            Sequence(
                "other_seq",
                metadata=metadata,
                nomaxvalue=True,
                nominvalue=True,
            ),
        )
        if testing.requires.schemas.enabled:
            # same name as the default-schema sequence, plus a name that
            # exists only in the remote schema
            normalize_sequence(
                config,
                Sequence(
                    "user_id_seq", schema=config.test_schema, metadata=metadata
                ),
            )
            normalize_sequence(
                config,
                Sequence(
                    "schema_seq", schema=config.test_schema, metadata=metadata
                ),
            )
        # a table whose name must NOT be reported as a sequence
        Table(
            "user_id_table",
            metadata,
            Column("id", Integer, primary_key=True),
        )

    def test_has_sequence(self, connection):
        eq_(inspect(connection).has_sequence("user_id_seq"), True)

    def test_has_sequence_cache(self, connection, metadata):
        # the Inspector caches reflection results; a sequence created
        # after the first lookup is invisible until clear_cache()
        insp = inspect(connection)
        eq_(insp.has_sequence("user_id_seq"), True)
        ss = normalize_sequence(config, Sequence("new_seq", metadata=metadata))
        eq_(insp.has_sequence("new_seq"), False)
        ss.create(connection)
        try:
            eq_(insp.has_sequence("new_seq"), False)
            insp.clear_cache()
            eq_(insp.has_sequence("new_seq"), True)
        finally:
            ss.drop(connection)

    def test_has_sequence_other_object(self, connection):
        # a table name is not a sequence name
        eq_(inspect(connection).has_sequence("user_id_table"), False)

    @testing.requires.schemas
    def test_has_sequence_schema(self, connection):
        eq_(
            inspect(connection).has_sequence(
                "user_id_seq", schema=config.test_schema
            ),
            True,
        )

    def test_has_sequence_neg(self, connection):
        eq_(inspect(connection).has_sequence("some_sequence"), False)

    @testing.requires.schemas
    def test_has_sequence_schemas_neg(self, connection):
        eq_(
            inspect(connection).has_sequence(
                "some_sequence", schema=config.test_schema
            ),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_default_not_in_remote(self, connection):
        # "other_sequence" exists in neither schema; in particular a
        # default-schema-only name must not leak into the remote schema
        eq_(
            inspect(connection).has_sequence(
                "other_sequence", schema=config.test_schema
            ),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_remote_not_in_default(self, connection):
        # "schema_seq" lives only in the remote schema
        eq_(inspect(connection).has_sequence("schema_seq"), False)

    def test_get_sequence_names(self, connection):
        # intersection rather than equality: the backend may have other
        # unrelated sequences lying around
        exp = {"other_seq", "user_id_seq"}

        res = set(inspect(connection).get_sequence_names())
        is_true(res.intersection(exp) == exp)
        is_true("schema_seq" not in res)

    @testing.requires.schemas
    def test_get_sequence_names_no_sequence_schema(self, connection):
        eq_(
            inspect(connection).get_sequence_names(
                schema=config.test_schema_2
            ),
            [],
        )

    @testing.requires.schemas
    def test_get_sequence_names_sequences_schema(self, connection):
        eq_(
            sorted(
                inspect(connection).get_sequence_names(
                    schema=config.test_schema
                )
            ),
            ["schema_seq", "user_id_seq"],
        )
class HasSequenceTestEmpty(fixtures.TestBase):
    """Sequence-name reflection against a database with no sequences."""

    __requires__ = ("sequences",)
    __backend__ = True

    def test_get_sequence_names_no_sequence(self, connection):
        # nothing was created, so reflection reports an empty list
        inspector = inspect(connection)
        eq_(inspector.get_sequence_names(), [])

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,189 @@
# testing/suite/test_unicode_ddl.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from sqlalchemy import desc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import testing
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
class UnicodeSchemaTest(fixtures.TablesTest):
    """Round-trip and reflection tests for non-ASCII table/column names."""

    __requires__ = ("unicode_ddl",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # NOTE(review): tables are published as module globals rather than
        # through cls.tables — presumably historical; the test methods below
        # rely on these globals.
        global t1, t2, t3

        t1 = Table(
            "unitable1",
            metadata,
            Column("méil", Integer, primary_key=True),
            Column("\u6e2c\u8a66", Integer),
            test_needs_fk=True,
        )
        # column keys "a"/"b" give ASCII names for parameter targeting
        t2 = Table(
            "Unitéble2",
            metadata,
            Column("méil", Integer, primary_key=True, key="a"),
            Column(
                "\u6e2c\u8a66",
                Integer,
                ForeignKey("unitable1.méil"),
                key="b",
            ),
            test_needs_fk=True,
        )

        # Few DBs support Unicode foreign keys
        if testing.against("sqlite"):
            # full unicode FK graph, including a self-referential FK
            t3 = Table(
                "\u6e2c\u8a66",
                metadata,
                Column(
                    "\u6e2c\u8a66_id",
                    Integer,
                    primary_key=True,
                    autoincrement=False,
                ),
                Column(
                    "unitable1_\u6e2c\u8a66",
                    Integer,
                    ForeignKey("unitable1.\u6e2c\u8a66"),
                ),
                Column("Unitéble2_b", Integer, ForeignKey("Unitéble2.b")),
                Column(
                    "\u6e2c\u8a66_self",
                    Integer,
                    ForeignKey("\u6e2c\u8a66.\u6e2c\u8a66_id"),
                ),
                test_needs_fk=True,
            )
        else:
            # same shape without the FK constraints
            t3 = Table(
                "\u6e2c\u8a66",
                metadata,
                Column(
                    "\u6e2c\u8a66_id",
                    Integer,
                    primary_key=True,
                    autoincrement=False,
                ),
                Column("unitable1_\u6e2c\u8a66", Integer),
                Column("Unitéble2_b", Integer),
                Column("\u6e2c\u8a66_self", Integer),
                test_needs_fk=True,
            )

    def test_insert(self, connection):
        # insert via unicode column names (t1, t3) and via keys (t2)
        connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
        connection.execute(t2.insert(), {"a": 1, "b": 1})
        connection.execute(
            t3.insert(),
            {
                "\u6e2c\u8a66_id": 1,
                "unitable1_\u6e2c\u8a66": 5,
                "Unitéble2_b": 1,
                "\u6e2c\u8a66_self": 1,
            },
        )

        eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
        eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
        eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])

    def test_col_targeting(self, connection):
        # row._mapping lookups keyed by Column objects with unicode names
        connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
        connection.execute(t2.insert(), {"a": 1, "b": 1})
        connection.execute(
            t3.insert(),
            {
                "\u6e2c\u8a66_id": 1,
                "unitable1_\u6e2c\u8a66": 5,
                "Unitéble2_b": 1,
                "\u6e2c\u8a66_self": 1,
            },
        )

        row = connection.execute(t1.select()).first()
        eq_(row._mapping[t1.c["méil"]], 1)
        eq_(row._mapping[t1.c["\u6e2c\u8a66"]], 5)

        row = connection.execute(t2.select()).first()
        eq_(row._mapping[t2.c["a"]], 1)
        eq_(row._mapping[t2.c["b"]], 1)

        row = connection.execute(t3.select()).first()
        eq_(row._mapping[t3.c["\u6e2c\u8a66_id"]], 1)
        eq_(row._mapping[t3.c["unitable1_\u6e2c\u8a66"]], 5)
        eq_(row._mapping[t3.c["Unitéble2_b"]], 1)
        eq_(row._mapping[t3.c["\u6e2c\u8a66_self"]], 1)

    def test_reflect(self, connection):
        # seed one row through the original Table objects ...
        connection.execute(t1.insert(), {"méil": 2, "\u6e2c\u8a66": 7})
        connection.execute(t2.insert(), {"a": 2, "b": 2})
        connection.execute(
            t3.insert(),
            {
                "\u6e2c\u8a66_id": 2,
                "unitable1_\u6e2c\u8a66": 7,
                "Unitéble2_b": 2,
                "\u6e2c\u8a66_self": 2,
            },
        )

        # ... then reflect fresh Table objects from the database and
        # insert a second row through those
        meta = MetaData()
        tt1 = Table(t1.name, meta, autoload_with=connection)
        tt2 = Table(t2.name, meta, autoload_with=connection)
        tt3 = Table(t3.name, meta, autoload_with=connection)

        # reflected tables have no custom keys, so tt2 uses real column
        # names rather than "a"/"b"
        connection.execute(tt1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
        connection.execute(tt2.insert(), {"méil": 1, "\u6e2c\u8a66": 1})
        connection.execute(
            tt3.insert(),
            {
                "\u6e2c\u8a66_id": 1,
                "unitable1_\u6e2c\u8a66": 5,
                "Unitéble2_b": 1,
                "\u6e2c\u8a66_self": 1,
            },
        )

        # both rows visible through the reflected tables, ordered by the
        # unicode-named column
        eq_(
            connection.execute(tt1.select().order_by(desc("méil"))).fetchall(),
            [(2, 7), (1, 5)],
        )
        eq_(
            connection.execute(tt2.select().order_by(desc("méil"))).fetchall(),
            [(2, 2), (1, 1)],
        )
        eq_(
            connection.execute(
                tt3.select().order_by(desc("\u6e2c\u8a66_id"))
            ).fetchall(),
            [(2, 7, 2, 2), (1, 5, 1, 1)],
        )

    def test_repr(self):
        # repr() of unicode-named Table/Column must not raise and renders
        # the names as-is
        meta = MetaData()
        t = Table("\u6e2c\u8a66", meta, Column("\u6e2c\u8a66_id", Integer))
        eq_(
            repr(t),
            (
                "Table('測試', MetaData(), "
                "Column('測試_id', Integer(), "
                "table=<測試>), "
                "schema=None)"
            ),
        )

View File

@ -0,0 +1,139 @@
# testing/suite/test_update_delete.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
from ... import Integer
from ... import String
from ... import testing
class SimpleUpdateDeleteTest(fixtures.TablesTest):
    """Basic UPDATE/DELETE round trips, with and without RETURNING."""

    run_deletes = "each"
    __requires__ = ("sane_rowcount",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "plain_pk",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )

    @classmethod
    def insert_data(cls, connection):
        # fixed three-row fixture; id=2 is the row targeted by the tests
        connection.execute(
            cls.tables.plain_pk.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

    def test_update(self, connection):
        t = self.tables.plain_pk
        r = connection.execute(
            t.update().where(t.c.id == 2), dict(data="d2_new")
        )
        # plain UPDATE: no RETURNING rows, rowcount reflects the one match
        assert not r.is_insert
        assert not r.returns_rows
        assert r.rowcount == 1

        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            [(1, "d1"), (2, "d2_new"), (3, "d3")],
        )

    def test_delete(self, connection):
        t = self.tables.plain_pk
        r = connection.execute(t.delete().where(t.c.id == 2))
        assert not r.is_insert
        assert not r.returns_rows
        assert r.rowcount == 1
        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            [(1, "d1"), (3, "d3")],
        )

    @testing.variation("criteria", ["rows", "norows", "emptyin"])
    @testing.requires.update_returning
    def test_update_returning(self, connection, criteria):
        # RETURNING with criteria matching one row, no rows, or an
        # always-empty IN () predicate
        t = self.tables.plain_pk

        stmt = t.update().returning(t.c.id, t.c.data)

        if criteria.norows:
            stmt = stmt.where(t.c.id == 10)
        elif criteria.rows:
            stmt = stmt.where(t.c.id == 2)
        elif criteria.emptyin:
            stmt = stmt.where(t.c.id.in_([]))
        else:
            criteria.fail()

        r = connection.execute(stmt, dict(data="d2_new"))
        assert not r.is_insert
        assert r.returns_rows
        eq_(r.keys(), ["id", "data"])

        if criteria.rows:
            eq_(r.all(), [(2, "d2_new")])
        else:
            eq_(r.all(), [])

        # table contents only change in the matching-row case
        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            (
                [(1, "d1"), (2, "d2_new"), (3, "d3")]
                if criteria.rows
                else [(1, "d1"), (2, "d2"), (3, "d3")]
            ),
        )

    @testing.variation("criteria", ["rows", "norows", "emptyin"])
    @testing.requires.delete_returning
    def test_delete_returning(self, connection, criteria):
        # same three criteria shapes as test_update_returning
        t = self.tables.plain_pk

        stmt = t.delete().returning(t.c.id, t.c.data)

        if criteria.norows:
            stmt = stmt.where(t.c.id == 10)
        elif criteria.rows:
            stmt = stmt.where(t.c.id == 2)
        elif criteria.emptyin:
            stmt = stmt.where(t.c.id.in_([]))
        else:
            criteria.fail()

        r = connection.execute(stmt)
        assert not r.is_insert
        assert r.returns_rows
        eq_(r.keys(), ["id", "data"])

        if criteria.rows:
            eq_(r.all(), [(2, "d2")])
        else:
            eq_(r.all(), [])

        eq_(
            connection.execute(t.select().order_by(t.c.id)).fetchall(),
            (
                [(1, "d1"), (3, "d3")]
                if criteria.rows
                else [(1, "d1"), (2, "d2"), (3, "d3")]
            ),
        )
# restrict wildcard imports of this module to the suite test class
__all__ = ("SimpleUpdateDeleteTest",)