Update 2025-04-24_11:44:19

This commit is contained in:
oib
2025-04-24 11:44:23 +02:00
commit e748c737f4
3408 changed files with 717481 additions and 0 deletions

View File

@ -0,0 +1,6 @@
# ext/mypy/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

View File

@ -0,0 +1,324 @@
# ext/mypy/apply.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from mypy.nodes import ARG_NAMED_OPT
from mypy.nodes import Argument
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import MDEF
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneTyp
from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
from . import infer
from . import util
from .names import expr_to_mapped_constructor
from .names import NAMED_TYPE_SQLA_MAPPED
def apply_mypy_mapped_attr(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    item: Union[NameExpr, StrExpr],
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Apply typing to one attribute named in a ``_mypy_mapped_attrs`` list.

    ``item`` is a single element of the list, given either as a name
    reference or a string literal.  The matching assignment statement in
    the class body must carry an explicit annotation; that annotation is
    recorded in ``attributes`` and re-applied as ``Mapped[<type>]``.
    """
    if isinstance(item, NameExpr):
        name = item.name
    elif isinstance(item, StrExpr):
        name = item.value
    else:
        return None

    # locate the assignment statement in the class body whose left hand
    # side matches the requested attribute name
    target: Optional[AssignmentStmt] = None
    for body_stmt in cls.defs.body:
        if (
            isinstance(body_stmt, AssignmentStmt)
            and isinstance(body_stmt.lvalues[0], NameExpr)
            and body_stmt.lvalues[0].name == name
        ):
            target = body_stmt
            break

    if target is None:
        util.fail(api, f"Can't find mapped attribute {name}", cls)
        return None

    if target.type is None:
        util.fail(
            api,
            "Statement linked from _mypy_mapped_attrs has no "
            "typing information",
            target,
        )
        return None

    explicit_type = get_proper_type(target.type)
    assert isinstance(explicit_type, (Instance, UnionType, UnboundType))

    attributes.append(
        util.SQLAlchemyAttribute(
            name=name,
            line=item.line,
            column=item.column,
            typ=explicit_type,
            info=cls.info,
        )
    )
    apply_type_to_mapped_statement(
        api, target, target.lvalues[0], explicit_type, None
    )
def re_apply_declarative_assignments(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """For multiple class passes, re-apply our left-hand side types as mypy
    seems to reset them in place.

    :param cls: class definition being re-analyzed
    :param api: mypy semantic analyzer plugin interface
    :param attributes: attributes collected by a previous scan of this
        class; entries may be updated in place with better type
        information discovered during this pass
    """
    mapped_attr_lookup = {attr.name: attr for attr in attributes}
    update_cls_metadata = False
    for stmt in cls.defs.body:
        # for a re-apply, all of our statements are AssignmentStmt;
        # @declared_attr calls will have been converted and this
        # currently seems to be preserved by mypy (but who knows if this
        # will change).
        if (
            isinstance(stmt, AssignmentStmt)
            and isinstance(stmt.lvalues[0], NameExpr)
            and stmt.lvalues[0].name in mapped_attr_lookup
            and isinstance(stmt.lvalues[0].node, Var)
        ):
            left_node = stmt.lvalues[0].node
            python_type_for_type = mapped_attr_lookup[
                stmt.lvalues[0].name
            ].type
            left_node_proper_type = get_proper_type(left_node.type)
            # if we have scanned an UnboundType and now there's a more
            # specific type than UnboundType, call the re-scan so we
            # can get that set up correctly
            if (
                isinstance(python_type_for_type, UnboundType)
                and not isinstance(left_node_proper_type, UnboundType)
                and (
                    # the rvalue must look exactly like the
                    # _sa_Mapped._empty_constructor(<CallExpr>) form that
                    # apply_type_to_mapped_statement produced earlier
                    isinstance(stmt.rvalue, CallExpr)
                    and isinstance(stmt.rvalue.callee, MemberExpr)
                    and isinstance(stmt.rvalue.callee.expr, NameExpr)
                    and stmt.rvalue.callee.expr.node is not None
                    and stmt.rvalue.callee.expr.node.fullname
                    == NAMED_TYPE_SQLA_MAPPED
                    and stmt.rvalue.callee.name == "_empty_constructor"
                    and isinstance(stmt.rvalue.args[0], CallExpr)
                    and isinstance(stmt.rvalue.args[0].callee, RefExpr)
                )
            ):
                new_python_type_for_type = (
                    infer.infer_type_from_right_hand_nameexpr(
                        api,
                        stmt,
                        left_node,
                        left_node_proper_type,
                        stmt.rvalue.args[0].callee,
                    )
                )
                if new_python_type_for_type is not None and not isinstance(
                    new_python_type_for_type, UnboundType
                ):
                    python_type_for_type = new_python_type_for_type
                    # update the SQLAlchemyAttribute with the better
                    # information
                    mapped_attr_lookup[stmt.lvalues[0].name].type = (
                        python_type_for_type
                    )
                    update_cls_metadata = True
            # re-apply the Mapped[] type to the left hand side unless it
            # is already a Mapped instance
            if (
                not isinstance(left_node.type, Instance)
                or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED
            ):
                assert python_type_for_type is not None
                left_node.type = api.named_type(
                    NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
                )
    # persist improved attribute types so later passes see them
    if update_cls_metadata:
        util.set_mapped_attributes(cls.info, attributes)
def apply_type_to_mapped_statement(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    lvalue: NameExpr,
    left_hand_explicit_type: Optional[ProperType],
    python_type_for_type: Optional[ProperType],
) -> None:
    """Apply the Mapped[<type>] annotation and right hand object to a
    declarative assignment statement.

    This converts a declarative class statement such as::

        class User(Base):
            # ...

            attrname = Column(Integer)

    into the form describing the final runtime behavior to mypy::

        class User(Base):
            # ...

            attrname: Mapped[Optional[int]] = <meaningless temp node>
    """
    left_node = lvalue.node
    assert isinstance(left_node, Var)

    # NOTE(review): the original author noted that the relationship
    # between left_node.type and stmt.type is not fully understood and
    # was arrived at empirically; the assignment ordering below is
    # preserved exactly.
    lvalue.is_inferred_def = False
    if left_hand_explicit_type is not None:
        mapped_args = [left_hand_explicit_type]
    elif python_type_for_type is not None:
        mapped_args = [python_type_for_type]
    else:
        mapped_args = [AnyType(TypeOfAny.special_form)]
    left_node.type = api.named_type(NAMED_TYPE_SQLA_MAPPED, mapped_args)

    # rather than discarding the right hand side entirely (e.g. with a
    # TempNode), wrap the original rvalue in
    # _sa_Mapped._empty_constructor(<original CallExpr>) so the original
    # Column / relationship() call still gets type checked internally
    stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue)

    if stmt.type is not None and python_type_for_type is not None:
        stmt.type = python_type_for_type
def add_additional_orm_attributes(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Apply __init__, __table__ and other attributes to the mapped class.

    :param cls: the mapped class definition
    :param api: mypy semantic analyzer plugin interface
    :param attributes: mapped attributes collected for this class, used
        to synthesize keyword-only ``__init__`` parameters
    """
    info = util.info_for_cls(cls, api)
    if info is None:
        return
    is_base = util.get_is_base(info)

    # synthesize an __init__ with one named-optional argument per mapped
    # attribute, including those inherited from mapped superclasses
    if "__init__" not in info.names and not is_base:
        mapped_attr_names = {attr.name: attr.type for attr in attributes}
        for base in info.mro[1:-1]:
            # bugfix: the original tested ``info.metadata`` here, which is
            # loop-invariant over the MRO walk; the intent is to skip
            # superclasses that this plugin never processed, so test the
            # base's own metadata
            if "sqlalchemy" not in base.metadata:
                continue
            base_cls_attributes = util.get_mapped_attributes(base, api)
            if base_cls_attributes is None:
                continue
            for attr in base_cls_attributes:
                # setdefault: attributes on more-derived classes win
                mapped_attr_names.setdefault(attr.name, attr.type)
        arguments = []
        for name, typ in mapped_attr_names.items():
            if typ is None:
                typ = AnyType(TypeOfAny.special_form)
            arguments.append(
                Argument(
                    variable=Var(name, typ),
                    type_annotation=typ,
                    initializer=TempNode(typ),
                    kind=ARG_NAMED_OPT,
                )
            )
        add_method_to_class(api, cls, "__init__", arguments, NoneTyp())

    # placeholder attributes generated by the ORM at runtime
    if "__table__" not in info.names and util.get_has_table(info):
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.sql.schema.Table", "__table__"
        )
    if not is_base:
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
        )
def _apply_placeholder_attr_to_class(
    api: SemanticAnalyzerPluginInterface,
    cls: ClassDef,
    qualified_name: str,
    attrname: str,
) -> None:
    """Install *attrname* on *cls* typed as *qualified_name*.

    Falls back to ``Any`` when the qualified name cannot be resolved.
    Used for runtime-generated ORM attributes such as ``__table__`` and
    ``__mapper__``.
    """
    sym = api.lookup_fully_qualified_or_none(qualified_name)

    placeholder_type: ProperType
    if not sym:
        placeholder_type = AnyType(TypeOfAny.special_form)
    else:
        assert isinstance(sym.node, TypeInfo)
        placeholder_type = Instance(sym.node, [])

    attr_var = Var(attrname)
    attr_var._fullname = cls.fullname + "." + attrname
    attr_var.info = cls.info
    attr_var.type = placeholder_type
    cls.info.names[attrname] = SymbolTableNode(MDEF, attr_var)

View File

@ -0,0 +1,515 @@
# ext/mypy/decl_class.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import LambdaExpr
from mypy.nodes import ListExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import PlaceholderNode
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolNode
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import AnyType
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import ProperType
from mypy.types import Type
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
from . import apply
from . import infer
from . import names
from . import util
def scan_declarative_assignments_and_apply_types(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    is_mixin_scan: bool = False,
) -> Optional[List[util.SQLAlchemyAttribute]]:
    """Scan a declarative class body, collect its mapped attributes and
    rewrite their statements into ``Mapped[<type>]`` form.

    Returns the collected attributes, or ``None`` when the class has no
    type info yet or is a builtin.  If the class was already scanned on a
    previous pass, the cached attributes are re-applied instead of
    re-scanned.

    :param is_mixin_scan: True when called from ``_scan_for_mapped_bases``
        for a superclass; in that case class-level additions such as the
        synthesized ``__init__`` are skipped
    """
    info = util.info_for_cls(cls, api)
    if info is None:
        # this can occur during cached passes
        return None
    elif cls.fullname.startswith("builtins"):
        return None
    mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = (
        util.get_mapped_attributes(info, api)
    )
    # used by assign.add_additional_orm_attributes among others
    util.establish_as_sqlalchemy(info)
    if mapped_attributes is not None:
        # ensure that a class that's mapped is always picked up by
        # its mapped() decorator or declarative metaclass before
        # it would be detected as an unmapped mixin class
        if not is_mixin_scan:
            # mypy can call us more than once. it then *may* have reset the
            # left hand side of everything, but not the right that we removed,
            # removing our ability to re-scan. but we have the types
            # here, so lets re-apply them, or if we have an UnboundType,
            # we can re-scan
            apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
        return mapped_attributes
    mapped_attributes = []
    if not cls.defs.body:
        # when we get a mixin class from another file, the body is
        # empty (!) but the names are in the symbol table. so use that.
        for sym_name, sym in info.names.items():
            _scan_symbol_table_entry(
                cls, api, sym_name, sym, mapped_attributes
            )
    else:
        for stmt in util.flatten_typechecking(cls.defs.body):
            if isinstance(stmt, AssignmentStmt):
                _scan_declarative_assignment_stmt(
                    cls, api, stmt, mapped_attributes
                )
            elif isinstance(stmt, Decorator):
                _scan_declarative_decorator_stmt(
                    cls, api, stmt, mapped_attributes
                )
    # also scan superclasses that may be unmapped mixins
    _scan_for_mapped_bases(cls, api)
    if not is_mixin_scan:
        apply.add_additional_orm_attributes(cls, api, mapped_attributes)
    util.set_mapped_attributes(info, mapped_attributes)
    return mapped_attributes
def _scan_symbol_table_entry(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    name: str,
    value: SymbolTableNode,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Extract mapping information from a SymbolTableNode that's in the
    type.names dictionary.

    Used when a class body is empty (e.g. a mixin imported from another
    file) but its names are available in the symbol table.  Recognized
    mapped constructs are appended to ``attributes``.
    """
    value_type = get_proper_type(value.type)
    if not isinstance(value_type, Instance):
        return
    left_hand_explicit_type = None
    type_id = names.type_id_for_named_node(value_type.type)
    # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
    err = False
    # TODO: this is nearly the same logic as that of
    # _scan_declarative_decorator_stmt, likely can be merged
    if type_id in {
        names.MAPPED,
        names.RELATIONSHIP,
        names.COMPOSITE_PROPERTY,
        names.MAPPER_PROPERTY,
        names.SYNONYM_PROPERTY,
        names.COLUMN_PROPERTY,
    }:
        if value_type.args:
            # e.g. Mapped[X] / relationship[X]: the first type argument
            # is the mapped Python type
            left_hand_explicit_type = get_proper_type(value_type.args[0])
        else:
            err = True
    elif type_id is names.COLUMN:
        if not value_type.args:
            err = True
        else:
            # Column[SomeTypeEngine]: resolve the TypeEngine subclass
            # and derive the Python type from it
            typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
                value_type.args[0]
            )
            if isinstance(typeengine_arg, Instance):
                typeengine_arg = typeengine_arg.type
            if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
                sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
                if sym is not None and isinstance(sym.node, TypeInfo):
                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
                        # the union with None reflects that the column is
                        # treated as nullable here
                        left_hand_explicit_type = UnionType(
                            [
                                infer.extract_python_type_from_typeengine(
                                    api, sym.node, []
                                ),
                                NoneType(),
                            ]
                        )
                    else:
                        util.fail(
                            api,
                            "Column type should be a TypeEngine "
                            "subclass not '{}'".format(sym.node.fullname),
                            value_type,
                        )
    if err:
        msg = (
            "Can't infer type from attribute {} on class {}. "
            "please specify a return type from this function that is "
            "one of: Mapped[<python type>], relationship[<target class>], "
            "Column[<TypeEngine>], MapperProperty[<python type>]"
        )
        util.fail(api, msg.format(name, cls.name), cls)
        left_hand_explicit_type = AnyType(TypeOfAny.special_form)
    if left_hand_explicit_type is not None:
        assert value.node is not None
        attributes.append(
            util.SQLAlchemyAttribute(
                name=name,
                line=value.node.line,
                column=value.node.column,
                typ=left_hand_explicit_type,
                info=cls.info,
            )
        )
def _scan_declarative_decorator_stmt(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    stmt: Decorator,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Extract mapping information from a @declared_attr in a declarative
    class.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            @declared_attr
            def updated_at(cls) -> Column[DateTime]:
                return Column(DateTime)

    Will resolve in mypy as::

        @reg.mapped
        class MyClass:
            # ...

            updated_at: Mapped[Optional[datetime.datetime]]

    The decorated function is rewritten in place as an assignment
    statement; recognized attributes are appended to ``attributes``.
    """
    for dec in stmt.decorators:
        if (
            isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
            and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
        ):
            break
    else:
        # not decorated with @declared_attr; nothing to do
        return
    dec_index = cls.defs.body.index(stmt)
    left_hand_explicit_type: Optional[ProperType] = None
    if util.name_is_dunder(stmt.name):
        # for dunder names like __table_args__, __tablename__,
        # __mapper_args__ etc., rewrite these as simple assignment
        # statements; otherwise mypy doesn't like if the decorated
        # function has an annotation like ``cls: Type[Foo]`` because
        # it isn't @classmethod
        any_ = AnyType(TypeOfAny.special_form)
        left_node = NameExpr(stmt.var.name)
        left_node.node = stmt.var
        new_stmt = AssignmentStmt([left_node], TempNode(any_))
        new_stmt.type = left_node.node.type
        cls.defs.body[dec_index] = new_stmt
        return
    elif isinstance(stmt.func.type, CallableType):
        func_type = stmt.func.type.ret_type
        if isinstance(func_type, UnboundType):
            type_id = names.type_id_for_unbound_type(func_type, cls, api)
        else:
            # this does not seem to occur unless the type argument is
            # incorrect
            return
        if (
            type_id
            in {
                names.MAPPED,
                names.RELATIONSHIP,
                names.COMPOSITE_PROPERTY,
                names.MAPPER_PROPERTY,
                names.SYNONYM_PROPERTY,
                names.COLUMN_PROPERTY,
            }
            and func_type.args
        ):
            # return annotation like Mapped[X]: first type argument is
            # the mapped Python type
            left_hand_explicit_type = get_proper_type(func_type.args[0])
        elif type_id is names.COLUMN and func_type.args:
            # return annotation Column[SomeTypeEngine]: resolve the
            # TypeEngine subclass to a Python type, as Optional
            typeengine_arg = func_type.args[0]
            if isinstance(typeengine_arg, UnboundType):
                sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
                if sym is not None and isinstance(sym.node, TypeInfo):
                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
                        left_hand_explicit_type = UnionType(
                            [
                                infer.extract_python_type_from_typeengine(
                                    api, sym.node, []
                                ),
                                NoneType(),
                            ]
                        )
                    else:
                        util.fail(
                            api,
                            "Column type should be a TypeEngine "
                            "subclass not '{}'".format(sym.node.fullname),
                            func_type,
                        )
    if left_hand_explicit_type is None:
        # no type on the decorated function. our option here is to
        # dig into the function body and get the return type, but they
        # should just have an annotation.
        msg = (
            "Can't infer type from @declared_attr on function '{}'; "
            "please specify a return type from this function that is "
            "one of: Mapped[<python type>], relationship[<target class>], "
            "Column[<TypeEngine>], MapperProperty[<python type>]"
        )
        util.fail(api, msg.format(stmt.var.name), stmt)
        left_hand_explicit_type = AnyType(TypeOfAny.special_form)
    left_node = NameExpr(stmt.var.name)
    left_node.node = stmt.var
    # totally feeling around in the dark here as I don't totally understand
    # the significance of UnboundType. It seems to be something that is
    # not going to do what's expected when it is applied as the type of
    # an AssignmentStatement. So do a feeling-around-in-the-dark version
    # of converting it to the regular Instance/TypeInfo/UnionType structures
    # we see everywhere else.
    if isinstance(left_hand_explicit_type, UnboundType):
        left_hand_explicit_type = get_proper_type(
            util.unbound_to_instance(api, left_hand_explicit_type)
        )
    left_node.node.type = api.named_type(
        names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
    )
    # this will ignore the rvalue entirely
    # rvalue = TempNode(AnyType(TypeOfAny.special_form))
    # rewrite the node as:
    # <attr> : Mapped[<typ>] =
    # _sa_Mapped._empty_constructor(lambda: <function body>)
    # the function body is maintained so it gets type checked internally
    rvalue = names.expr_to_mapped_constructor(
        LambdaExpr(stmt.func.arguments, stmt.func.body)
    )
    new_stmt = AssignmentStmt([left_node], rvalue)
    new_stmt.type = left_node.node.type
    attributes.append(
        util.SQLAlchemyAttribute(
            name=left_node.name,
            line=stmt.line,
            column=stmt.column,
            typ=left_hand_explicit_type,
            info=cls.info,
        )
    )
    cls.defs.body[dec_index] = new_stmt
def _scan_declarative_assignment_stmt(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Extract mapping information from an assignment statement in a
    declarative class.

    Handles special class-level names (``__abstract__``, ``__tablename__``,
    ``_mypy_mapped_attrs``) first, then infers a ``Mapped[<type>]`` type
    for ordinary attribute assignments and rewrites the statement via
    ``apply.apply_type_to_mapped_statement``.
    """
    lvalue = stmt.lvalues[0]
    if not isinstance(lvalue, NameExpr):
        return
    sym = cls.info.names.get(lvalue.name)
    # this establishes that semantic analysis has taken place, which
    # means the nodes are populated and we are called from an appropriate
    # hook.
    assert sym is not None
    node = sym.node
    if isinstance(node, PlaceholderNode):
        return
    assert node is lvalue.node
    assert isinstance(node, Var)
    if node.name == "__abstract__":
        if api.parse_bool(stmt.rvalue) is True:
            util.set_is_base(cls.info)
        return
    elif node.name == "__tablename__":
        util.set_has_table(cls.info)
    elif node.name.startswith("__"):
        # any other dunder attribute is not a mapped attribute
        return
    elif node.name == "_mypy_mapped_attrs":
        if not isinstance(stmt.rvalue, ListExpr):
            util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
        else:
            for item in stmt.rvalue.items:
                if isinstance(item, (NameExpr, StrExpr)):
                    apply.apply_mypy_mapped_attr(cls, api, item, attributes)
    left_hand_mapped_type: Optional[Type] = None
    left_hand_explicit_type: Optional[ProperType] = None
    if node.is_inferred or node.type is None:
        if isinstance(stmt.type, UnboundType):
            # look for an explicit Mapped[] type annotation on the left
            # side with nothing on the right
            # print(stmt.type)
            # Mapped?[Optional?[A?]]
            left_hand_explicit_type = stmt.type
            if stmt.type.name == "Mapped":
                mapped_sym = api.lookup_qualified("Mapped", cls)
                if (
                    mapped_sym is not None
                    and mapped_sym.node is not None
                    and names.type_id_for_named_node(mapped_sym.node)
                    is names.MAPPED
                ):
                    left_hand_explicit_type = get_proper_type(
                        stmt.type.args[0]
                    )
                    left_hand_mapped_type = stmt.type
            # TODO: do we need to convert from unbound for this case?
            # left_hand_explicit_type = util._unbound_to_instance(
            #     api, left_hand_explicit_type
            # )
    else:
        node_type = get_proper_type(node.type)
        if (
            isinstance(node_type, Instance)
            and names.type_id_for_named_node(node_type.type) is names.MAPPED
        ):
            # print(node.type)
            # sqlalchemy.orm.attributes.Mapped[<python type>]
            left_hand_explicit_type = get_proper_type(node_type.args[0])
            left_hand_mapped_type = node_type
        else:
            # print(node.type)
            # <python type>
            left_hand_explicit_type = node_type
            left_hand_mapped_type = None
    if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
        # annotation without assignment and Mapped is present
        # as type annotation
        # equivalent to using _infer_type_from_left_hand_type_only.
        python_type_for_type = left_hand_explicit_type
    elif isinstance(stmt.rvalue, CallExpr) and isinstance(
        stmt.rvalue.callee, RefExpr
    ):
        # infer the type from the right hand side construct, e.g.
        # Column(...), relationship(...), etc.
        python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
            api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
        )
        if python_type_for_type is None:
            return
    else:
        return
    assert python_type_for_type is not None
    attributes.append(
        util.SQLAlchemyAttribute(
            name=node.name,
            line=stmt.line,
            column=stmt.column,
            typ=python_type_for_type,
            info=cls.info,
        )
    )
    apply.apply_type_to_mapped_statement(
        api,
        stmt,
        lvalue,
        left_hand_explicit_type,
        python_type_for_type,
    )
def _scan_for_mapped_bases(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
) -> None:
    """Given a class, iterate through its superclass hierarchy to find
    all other classes that are considered as ORM-significant.

    Locates non-mapped mixins and scans them for mapped attributes to be
    applied to subclasses.
    """
    info = util.info_for_cls(cls, api)
    if info is None:
        return

    # walk the MRO, skipping the class itself and any builtins entries
    candidate_bases = (
        base_info
        for base_info in info.mro[1:-1]
        if not base_info.fullname.startswith("builtins")
    )
    for base_info in candidate_bases:
        # scan each base for mapped attributes. if they are not already
        # scanned (but have all their type info), that means they are
        # unmapped mixins
        scan_declarative_assignments_and_apply_types(
            base_info.defn, api, is_mixin_scan=True
        )

View File

@ -0,0 +1,590 @@
# ext/mypy/infer.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Optional
from typing import Sequence
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import LambdaExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.subtypes import is_subtype
from mypy.types import AnyType
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnionType
from . import names
from . import util
def infer_type_from_right_hand_nameexpr(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
    infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
    """Dispatch to the appropriate inference routine based on which
    SQLAlchemy construct the right hand side refers to.

    Returns the inferred Python type of the mapped attribute, or
    ``None`` when the right hand side is not a recognized construct.
    """
    type_id = names.type_id_for_callee(infer_from_right_side)
    if type_id is None:
        return None

    if type_id is names.MAPPED:
        return _infer_type_from_mapped(
            api, stmt, node, left_hand_explicit_type, infer_from_right_side
        )
    if type_id is names.COLUMN:
        return _infer_type_from_decl_column(
            api, stmt, node, left_hand_explicit_type
        )
    if type_id is names.RELATIONSHIP:
        return _infer_type_from_relationship(
            api, stmt, node, left_hand_explicit_type
        )
    if type_id is names.COLUMN_PROPERTY:
        return _infer_type_from_decl_column_property(
            api, stmt, node, left_hand_explicit_type
        )
    if type_id is names.SYNONYM_PROPERTY:
        # synonyms contribute no type of their own; use the left side
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    if type_id is names.COMPOSITE_PROPERTY:
        return _infer_type_from_decl_composite_property(
            api, stmt, node, left_hand_explicit_type
        )
    return None
def _infer_type_from_relationship(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Infer the type of mapping from a relationship.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            addresses = relationship(Address, uselist=True)

            order: Mapped["Order"] = relationship("Order")

    Will resolve in mypy as::

        @reg.mapped
        class MyClass:
            # ...

            addresses: Mapped[List[Address]]

            order: Mapped["Order"]
    """
    assert isinstance(stmt.rvalue, CallExpr)
    target_cls_arg = stmt.rvalue.args[0]
    python_type_for_type: Optional[ProperType] = None
    if isinstance(target_cls_arg, NameExpr) and isinstance(
        target_cls_arg.node, TypeInfo
    ):
        # type
        related_object_type = target_cls_arg.node
        python_type_for_type = Instance(related_object_type, [])
    # other cases not covered - an error message directs the user
    # to set an explicit type annotation
    #
    # node.type == str, it's a string
    # if isinstance(target_cls_arg, NameExpr) and isinstance(
    #     target_cls_arg.node, Var
    # )
    # points to a type
    # isinstance(target_cls_arg, NameExpr) and isinstance(
    #     target_cls_arg.node, TypeAlias
    # )
    # string expression
    # isinstance(target_cls_arg, StrExpr)
    uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
    collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
        stmt.rvalue, "collection_class"
    )
    type_is_a_collection = False
    # this can be used to determine Optional for a many-to-one
    # in the same way nullable=False could be used, if we start supporting
    # that.
    # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
    if (
        uselist_arg is not None
        and api.parse_bool(uselist_arg) is True
        and collection_cls_arg is None
    ):
        # uselist=True with no collection_class: default to a builtin
        # list of the target type
        type_is_a_collection = True
        if python_type_for_type is not None:
            python_type_for_type = api.named_type(
                names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
            )
    elif (
        uselist_arg is None or api.parse_bool(uselist_arg) is True
    ) and collection_cls_arg is not None:
        # an explicit collection_class was given, either as a class
        # reference or a callable returning a class
        type_is_a_collection = True
        if isinstance(collection_cls_arg, CallExpr):
            collection_cls_arg = collection_cls_arg.callee
        if isinstance(collection_cls_arg, NameExpr) and isinstance(
            collection_cls_arg.node, TypeInfo
        ):
            if python_type_for_type is not None:
                # this can still be overridden by the left hand side
                # within _infer_Type_from_left_and_inferred_right
                python_type_for_type = Instance(
                    collection_cls_arg.node, [python_type_for_type]
                )
        elif (
            isinstance(collection_cls_arg, NameExpr)
            and isinstance(collection_cls_arg.node, FuncDef)
            and collection_cls_arg.node.type is not None
        ):
            if python_type_for_type is not None:
                # this can still be overridden by the left hand side
                # within _infer_Type_from_left_and_inferred_right
                # TODO: handle mypy.types.Overloaded
                if isinstance(collection_cls_arg.node.type, CallableType):
                    rt = get_proper_type(collection_cls_arg.node.type.ret_type)
                    if isinstance(rt, CallableType):
                        callable_ret_type = get_proper_type(rt.ret_type)
                        if isinstance(callable_ret_type, Instance):
                            python_type_for_type = Instance(
                                callable_ret_type.type,
                                [python_type_for_type],
                            )
        else:
            util.fail(
                api,
                "Expected Python collection type for "
                "collection_class parameter",
                stmt.rvalue,
            )
            python_type_for_type = None
    elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
        # uselist=False: scalar relationship, typed as Optional
        if collection_cls_arg is not None:
            util.fail(
                api,
                "Sending uselist=False and collection_class at the same time "
                "does not make sense",
                stmt.rvalue,
            )
        if python_type_for_type is not None:
            python_type_for_type = UnionType(
                [python_type_for_type, NoneType()]
            )
    else:
        if left_hand_explicit_type is None:
            msg = (
                "Can't infer scalar or collection for ORM mapped expression "
                "assigned to attribute '{}' if both 'uselist' and "
                "'collection_class' arguments are absent from the "
                "relationship(); please specify a "
                "type annotation on the left hand side."
            )
            util.fail(api, msg.format(node.name), node)
    if python_type_for_type is None:
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    elif left_hand_explicit_type is not None:
        if type_is_a_collection:
            assert isinstance(left_hand_explicit_type, Instance)
            assert isinstance(python_type_for_type, Instance)
            return _infer_collection_type_from_left_and_inferred_right(
                api, node, left_hand_explicit_type, python_type_for_type
            )
        else:
            return _infer_type_from_left_and_inferred_right(
                api,
                node,
                left_hand_explicit_type,
                python_type_for_type,
            )
    else:
        return python_type_for_type
def _infer_type_from_decl_composite_property(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Infer the type of mapping from a Composite."""
    assert isinstance(stmt.rvalue, CallExpr)
    composite_cls_arg = stmt.rvalue.args[0]

    # the composite class can only be inferred when the first argument
    # is a direct reference to a class
    inferred_type: Optional[ProperType] = None
    if isinstance(composite_cls_arg, NameExpr) and isinstance(
        composite_cls_arg.node, TypeInfo
    ):
        inferred_type = Instance(composite_cls_arg.node, [])

    if inferred_type is None:
        # nothing inferable on the right; the left hand annotation, if
        # any, is all we have
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    if left_hand_explicit_type is not None:
        # reconcile the inferred right side with the explicit annotation
        return _infer_type_from_left_and_inferred_right(
            api, node, left_hand_explicit_type, inferred_type
        )
    return inferred_type
def _infer_type_from_mapped(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
    infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
    """Infer the type of mapping from a right side expression
    that returns Mapped.

    Currently only the left hand annotation contributes to the result;
    the right side's type is looked up but not yet used (see TODO below).
    """
    assert isinstance(stmt.rvalue, CallExpr)
    # (Pdb) print(stmt.rvalue.callee)
    # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501
    # (Pdb) stmt.rvalue.callee.node
    # <mypy.nodes.FuncDef object at 0x7f8d92fb5940>
    # (Pdb) stmt.rvalue.callee.node.type
    # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501
    # sqlalchemy.orm.base.Mapped[_T`-1]
    # the_mapped_type = stmt.rvalue.callee.node.type.ret_type
    # TODO: look at generic ref and either use that,
    # or reconcile w/ what's present, etc.
    the_mapped_type = util.type_for_callee(infer_from_right_side)  # noqa
    return infer_type_from_left_hand_type_only(
        api, node, left_hand_explicit_type
    )
def _infer_type_from_decl_column_property(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Infer the type of mapping from a ColumnProperty.

    This includes mappings against ``column_property()`` as well as the
    ``deferred()`` function.
    """
    assert isinstance(stmt.rvalue, CallExpr)

    # column_property(Column(...)) / deferred(Column(...)): infer from
    # the Column passed as the first positional argument
    if stmt.rvalue.args:
        first_prop_arg = stmt.rvalue.args[0]
        if (
            isinstance(first_prop_arg, CallExpr)
            and names.type_id_for_callee(first_prop_arg.callee)
            is names.COLUMN
        ):
            return _infer_type_from_decl_column(
                api,
                stmt,
                node,
                left_hand_explicit_type,
                right_hand_expression=first_prop_arg,
            )

    # query_expression() and other no-argument column property objects;
    # this is probably not strictly necessary as we have to use the left
    # hand type for query expression in any case
    if (
        isinstance(stmt.rvalue, CallExpr)
        and names.type_id_for_callee(stmt.rvalue.callee)
        is names.QUERY_EXPRESSION
    ):
        return _infer_type_from_decl_column(
            api,
            stmt,
            node,
            left_hand_explicit_type,
        )

    return infer_type_from_left_hand_type_only(
        api, node, left_hand_explicit_type
    )
def _infer_type_from_decl_column(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
    right_hand_expression: Optional[CallExpr] = None,
) -> Optional[ProperType]:
    """Infer the type of mapping from a Column.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            a = Column(Integer)
            b = Column("b", String)
            c: Mapped[int] = Column(Integer)
            d: bool = Column(Boolean)

    Will resolve in MyPy as::

        @reg.mapped
        class MyClass:
            # ...

            a: Mapped[int]
            b: Mapped[str]
            c: Mapped[int]
            d: Mapped[bool]

    Returns None when no Column call (or no usable type argument) can be
    located in the right-hand expression.
    """
    assert isinstance(node, Var)

    callee = None

    if right_hand_expression is None:
        if not isinstance(stmt.rvalue, CallExpr):
            return None

        right_hand_expression = stmt.rvalue

    # scan only the first two positional arguments of Column(); the type
    # argument may be first, or second when a string name is given first
    for column_arg in right_hand_expression.args[0:2]:
        if isinstance(column_arg, CallExpr):
            if isinstance(column_arg.callee, RefExpr):
                # x = Column(String(50))
                callee = column_arg.callee
                type_args: Sequence[Expression] = column_arg.args
                break
        elif isinstance(column_arg, (NameExpr, MemberExpr)):
            if isinstance(column_arg.node, TypeInfo):
                # x = Column(String)
                callee = column_arg
                type_args = ()
                break
            else:
                # x = Column(some_name, String), go to next argument
                continue
        elif isinstance(column_arg, (StrExpr,)):
            # x = Column("name", String), go to next argument
            continue
        elif isinstance(column_arg, (LambdaExpr,)):
            # x = Column("name", String, default=lambda: uuid.uuid4())
            # go to next argument
            continue
        else:
            # unexpected argument form; fail loudly so new cases get added
            assert False

    if callee is None:
        return None

    if isinstance(callee.node, TypeInfo) and names.mro_has_id(
        callee.node.mro, names.TYPEENGINE
    ):
        python_type_for_type = extract_python_type_from_typeengine(
            api, callee.node, type_args
        )

        if left_hand_explicit_type is not None:
            # annotation present: validate it against the inferred type
            return _infer_type_from_left_and_inferred_right(
                api, node, left_hand_explicit_type, python_type_for_type
            )

        else:
            # Column() with no explicit annotation is nullable by default,
            # hence Optional[<python type>]
            return UnionType([python_type_for_type, NoneType()])
    else:
        # it's not TypeEngine, it's typically implicitly typed
        # like ForeignKey.  we can't infer from the right side.
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
def _infer_type_from_left_and_inferred_right(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: ProperType,
    python_type_for_type: ProperType,
    orig_left_hand_type: Optional[ProperType] = None,
    orig_python_type_for_type: Optional[ProperType] = None,
) -> Optional[ProperType]:
    """Validate an explicit left-hand annotation against a type inferred
    from the right-hand expression::

        attrname: SomeType = Column(SomeDBType)

    An error is emitted when the annotation is not a subtype of the
    inferred type; the (original) left-hand type is returned either way
    so that analysis can proceed.
    """
    orig_left_hand_type = (
        left_hand_explicit_type
        if orig_left_hand_type is None
        else orig_left_hand_type
    )
    orig_python_type_for_type = (
        python_type_for_type
        if orig_python_type_for_type is None
        else orig_python_type_for_type
    )

    if is_subtype(left_hand_explicit_type, python_type_for_type):
        return orig_left_hand_type

    # incompatible annotation; report against the Mapped[...] form of
    # the inferred type so the message matches what a user would write
    effective_type = api.named_type(
        names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
    )
    lhs_fmt = util.format_type(orig_left_hand_type, api.options)
    rhs_fmt = util.format_type(effective_type, api.options)
    util.fail(
        api,
        f"Left hand assignment '{node.name}: {lhs_fmt}' not compatible "
        f"with ORM mapped expression of type {rhs_fmt}",
        node,
    )
    return orig_left_hand_type
def _infer_collection_type_from_left_and_inferred_right(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: Instance,
    python_type_for_type: Instance,
) -> Optional[ProperType]:
    """Validate a collection-style annotation (e.g. ``List[X]``) against
    the inferred right-hand collection type, comparing element types
    when the annotation is parameterized."""
    if left_hand_explicit_type.args:
        # both sides parameterized: compare the first type argument
        lhs_member = get_proper_type(left_hand_explicit_type.args[0])
        rhs_member = get_proper_type(python_type_for_type.args[0])
    else:
        # unparameterized annotation: compare the collections themselves
        lhs_member = left_hand_explicit_type
        rhs_member = python_type_for_type

    assert isinstance(lhs_member, (Instance, UnionType))
    assert isinstance(rhs_member, (Instance, UnionType))
    return _infer_type_from_left_and_inferred_right(
        api,
        node,
        lhs_member,
        rhs_member,
        orig_left_hand_type=left_hand_explicit_type,
        orig_python_type_for_type=python_type_for_type,
    )
def infer_type_from_left_hand_type_only(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Use the explicit left-hand annotation as the mapped type.

    When no annotation is present, an error is emitted and a
    ``Mapped[Any]`` placeholder is returned so analysis can continue.
    """
    if left_hand_explicit_type is not None:
        # an annotation exists; trust it as-is
        return left_hand_explicit_type

    util.fail(
        api,
        "Can't infer type from ORM mapped expression "
        "assigned to attribute '{}'; please specify a "
        "Python type or "
        "Mapped[<python type>] on the left hand side.".format(node.name),
        node,
    )

    return api.named_type(
        names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
    )
def extract_python_type_from_typeengine(
    api: SemanticAnalyzerPluginInterface,
    node: TypeInfo,
    type_args: Sequence[Expression],
) -> ProperType:
    """Given a ``TypeEngine`` subclass, return the Python type its
    values take at runtime (e.g. ``Integer`` -> ``int``), read from the
    last type argument of its ``TypeEngine[...]`` base.
    """
    # special case: Enum(SomePythonEnum) maps to that enum class;
    # Enum("a", "b", ...) maps to str
    if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
        first_arg = type_args[0]
        if isinstance(first_arg, RefExpr) and isinstance(
            first_arg.node, TypeInfo
        ):
            for base_ in first_arg.node.mro:
                if base_.fullname == "enum.Enum":
                    return Instance(first_arg.node, [])
            # TODO: support other pep-435 types here
        else:
            return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])

    assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
        "could not extract Python type from node: %s" % node
    )

    type_engine_sym = api.lookup_fully_qualified_or_none(
        "sqlalchemy.sql.type_api.TypeEngine"
    )

    assert type_engine_sym is not None and isinstance(
        type_engine_sym.node, TypeInfo
    )

    # map the concrete subclass onto its TypeEngine[_T] base and read
    # the bound _T parameter
    type_engine = map_instance_to_supertype(
        Instance(node, []),
        type_engine_sym.node,
    )
    return get_proper_type(type_engine.args[-1])

View File

@ -0,0 +1,335 @@
# ext/mypy/names.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
from mypy.nodes import ARG_POS
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import OverloadedFuncDef
from mypy.nodes import SymbolNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import UnboundType
from ... import util
# Integer symbols identifying each category of SQLAlchemy construct that
# the plugin recognizes; the plugin hooks compare these by identity
# ("is"), treating util.symbol() values as singletons.
COLUMN: int = util.symbol("COLUMN")
RELATIONSHIP: int = util.symbol("RELATIONSHIP")
REGISTRY: int = util.symbol("REGISTRY")
COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY")
# fixed: the symbol name was previously misspelled as "TYPEENGNE"
TYPEENGINE: int = util.symbol("TYPEENGINE")
MAPPED: int = util.symbol("MAPPED")
DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE")
DECLARATIVE_META: int = util.symbol("DECLARATIVE_META")
MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR")
SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY")
COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY")
DECLARED_ATTR: int = util.symbol("DECLARED_ATTR")
MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY")
AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE")
AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE")
DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN")
QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION")

# names that must succeed with mypy.api.named_type
NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
NAMED_TYPE_BUILTINS_STR = "builtins.str"
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped"

# fullnames under which relationship() constructs may be imported
_RelFullNames = {
    "sqlalchemy.orm.relationships.Relationship",
    "sqlalchemy.orm.relationships.RelationshipProperty",
    "sqlalchemy.orm.relationships._RelationshipDeclared",
    "sqlalchemy.orm.Relationship",
    "sqlalchemy.orm.RelationshipProperty",
}
# Several constructs are importable under more than one public name; the
# aliased dictionary keys below share one common set of fullnames, the
# same pattern already used by _RelFullNames.
_ColumnPropFullNames: Set[str] = {
    "sqlalchemy.orm.properties.MappedSQLExpression",
    "sqlalchemy.orm.MappedSQLExpression",
    "sqlalchemy.orm.properties.ColumnProperty",
    "sqlalchemy.orm.ColumnProperty",
}
_SynonymFullNames: Set[str] = {
    "sqlalchemy.orm.descriptor_props.Synonym",
    "sqlalchemy.orm.Synonym",
    "sqlalchemy.orm.descriptor_props.SynonymProperty",
    "sqlalchemy.orm.SynonymProperty",
}
_CompositeFullNames: Set[str] = {
    "sqlalchemy.orm.descriptor_props.Composite",
    "sqlalchemy.orm.Composite",
    "sqlalchemy.orm.descriptor_props.CompositeProperty",
    "sqlalchemy.orm.CompositeProperty",
}

# maps a construct's short name to (plugin type id, set of fullnames
# under which that construct may be imported)
_lookup: Dict[str, Tuple[int, Set[str]]] = {
    "Column": (
        COLUMN,
        {
            "sqlalchemy.sql.schema.Column",
            "sqlalchemy.sql.Column",
        },
    ),
    "Relationship": (RELATIONSHIP, _RelFullNames),
    "RelationshipProperty": (RELATIONSHIP, _RelFullNames),
    "_RelationshipDeclared": (RELATIONSHIP, _RelFullNames),
    "registry": (
        REGISTRY,
        {
            "sqlalchemy.orm.decl_api.registry",
            "sqlalchemy.orm.registry",
        },
    ),
    "ColumnProperty": (COLUMN_PROPERTY, _ColumnPropFullNames),
    "MappedSQLExpression": (COLUMN_PROPERTY, _ColumnPropFullNames),
    "Synonym": (SYNONYM_PROPERTY, _SynonymFullNames),
    "SynonymProperty": (SYNONYM_PROPERTY, _SynonymFullNames),
    "Composite": (COMPOSITE_PROPERTY, _CompositeFullNames),
    "CompositeProperty": (COMPOSITE_PROPERTY, _CompositeFullNames),
    "MapperProperty": (
        MAPPER_PROPERTY,
        {
            "sqlalchemy.orm.interfaces.MapperProperty",
            "sqlalchemy.orm.MapperProperty",
        },
    ),
    "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
    "Mapped": (MAPPED, {NAMED_TYPE_SQLA_MAPPED}),
    "declarative_base": (
        DECLARATIVE_BASE,
        {
            "sqlalchemy.ext.declarative.declarative_base",
            "sqlalchemy.orm.declarative_base",
            "sqlalchemy.orm.decl_api.declarative_base",
        },
    ),
    "DeclarativeMeta": (
        DECLARATIVE_META,
        {
            "sqlalchemy.ext.declarative.DeclarativeMeta",
            "sqlalchemy.orm.DeclarativeMeta",
            "sqlalchemy.orm.decl_api.DeclarativeMeta",
        },
    ),
    "mapped": (
        MAPPED_DECORATOR,
        {
            "sqlalchemy.orm.decl_api.registry.mapped",
            "sqlalchemy.orm.registry.mapped",
        },
    ),
    "as_declarative": (
        AS_DECLARATIVE,
        {
            "sqlalchemy.ext.declarative.as_declarative",
            "sqlalchemy.orm.decl_api.as_declarative",
            "sqlalchemy.orm.as_declarative",
        },
    ),
    "as_declarative_base": (
        AS_DECLARATIVE_BASE,
        {
            "sqlalchemy.orm.decl_api.registry.as_declarative_base",
            "sqlalchemy.orm.registry.as_declarative_base",
        },
    ),
    "declared_attr": (
        DECLARED_ATTR,
        {
            "sqlalchemy.orm.decl_api.declared_attr",
            "sqlalchemy.orm.declared_attr",
        },
    ),
    "declarative_mixin": (
        DECLARATIVE_MIXIN,
        {
            "sqlalchemy.orm.decl_api.declarative_mixin",
            "sqlalchemy.orm.declarative_mixin",
        },
    ),
    "query_expression": (
        QUERY_EXPRESSION,
        {
            "sqlalchemy.orm.query_expression",
            "sqlalchemy.orm._orm_constructors.query_expression",
        },
    ),
}
def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
    """Return True if any class in *info*'s MRO corresponds to the given
    plugin type id, matched by both short name and fullname.

    Previously this duplicated the body of :func:`mro_has_id`; it now
    delegates to it so the matching rules live in one place.
    """
    return mro_has_id(info.mro, type_id)
def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
    """Return True if any entry in *mro* matches the given plugin type
    id by both its short name and its fully qualified name."""
    for candidate in mro:
        mapped_id, fullnames = _lookup.get(candidate.name, (None, None))
        if mapped_id == type_id:
            # short name matched; confirm the fully qualified name too
            return fullnames is not None and candidate.fullname in fullnames
    return False
def type_id_for_unbound_type(
    type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[int]:
    """Resolve an UnboundType annotation to one of the plugin's type
    ids, following type aliases through to their targets."""
    sym = api.lookup_qualified(type_.name, type_)
    if sym is None:
        return None

    target = sym.node
    if isinstance(target, TypeAlias):
        # follow the alias to the aliased class, if it is one
        aliased = get_proper_type(target.target)
        if isinstance(aliased, Instance):
            return type_id_for_named_node(aliased.type)
    elif isinstance(target, TypeInfo):
        return type_id_for_named_node(target)

    return None
def type_id_for_callee(callee: Expression) -> Optional[int]:
    """Determine the plugin type id for the object being called,
    resolving through decorated functions, overloads, plain functions,
    type aliases and classes to a known SQLAlchemy fullname.
    """
    if isinstance(callee, (MemberExpr, NameExpr)):
        if isinstance(callee.node, Decorator) and isinstance(
            callee.node.func, FuncDef
        ):
            # decorated function: match on the declared return type
            if callee.node.func.type and isinstance(
                callee.node.func.type, CallableType
            ):
                ret_type = get_proper_type(callee.node.func.type.ret_type)

                if isinstance(ret_type, Instance):
                    return type_id_for_fullname(ret_type.type.fullname)

            return None
        elif isinstance(callee.node, OverloadedFuncDef):
            # overloaded function: match on the implementation's
            # declared return type
            if (
                callee.node.impl
                and callee.node.impl.type
                and isinstance(callee.node.impl.type, CallableType)
            ):
                ret_type = get_proper_type(callee.node.impl.type.ret_type)

                if isinstance(ret_type, Instance):
                    return type_id_for_fullname(ret_type.type.fullname)

            return None
        elif isinstance(callee.node, FuncDef):
            if callee.node.type and isinstance(callee.node.type, CallableType):
                ret_type = get_proper_type(callee.node.type.ret_type)

                if isinstance(ret_type, Instance):
                    return type_id_for_fullname(ret_type.type.fullname)

            return None
        elif isinstance(callee.node, TypeAlias):
            target_type = get_proper_type(callee.node.target)
            if isinstance(target_type, Instance):
                return type_id_for_fullname(target_type.type.fullname)
        elif isinstance(callee.node, TypeInfo):
            # a class being called; match on the expression's own
            # name / fullname
            return type_id_for_named_node(callee)
    return None
def type_id_for_named_node(
    node: Union[NameExpr, MemberExpr, SymbolNode]
) -> Optional[int]:
    """Map a named node to its plugin type id; both the short name and
    the fully qualified name must be recognized."""
    type_id, fullnames = _lookup.get(node.name, (None, None))

    if (
        type_id is not None
        and fullnames is not None
        and node.fullname in fullnames
    ):
        return type_id
    return None
def type_id_for_fullname(fullname: str) -> Optional[int]:
    """Map a dotted fullname to its plugin type id, keying the lookup
    table by the name's last component."""
    last_component = fullname.rsplit(".", 1)[-1]
    type_id, fullnames = _lookup.get(last_component, (None, None))

    if type_id is None or fullnames is None:
        return None
    # the short name is known; the complete path must also match
    return type_id if fullname in fullnames else None
def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
    """Wrap *expr* in a synthesized call to
    ``__sa_Mapped._empty_constructor(expr)``, using the module-global
    alias installed by the plugin."""
    mapped_ref = NameExpr("__sa_Mapped")
    mapped_ref.fullname = NAMED_TYPE_SQLA_MAPPED

    constructor_ref = MemberExpr(mapped_ref, "_empty_constructor")
    return CallExpr(constructor_ref, [expr], [ARG_POS], ["arg1"])

View File

@ -0,0 +1,303 @@
# ext/mypy/plugin.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
Mypy plugin for SQLAlchemy ORM.
"""
from __future__ import annotations
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type as TypingType
from typing import Union
from mypy import nodes
from mypy.mro import calculate_mro
from mypy.mro import MroError
from mypy.nodes import Block
from mypy.nodes import ClassDef
from mypy.nodes import GDEF
from mypy.nodes import MypyFile
from mypy.nodes import NameExpr
from mypy.nodes import SymbolTable
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeInfo
from mypy.plugin import AttributeContext
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import Plugin
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import Type
from . import decl_class
from . import names
from . import util
try:
    # PEP 561 stub-only packages install a directory literally named
    # "sqlalchemy-stubs"; __import__ accepts the hyphenated name (which
    # an ``import`` statement could not express).
    # NOTE(review): assumes the stubs directory is importable as a
    # namespace package when installed -- confirm against the stub
    # distributions.
    __import__("sqlalchemy-stubs")
except ImportError:
    # no third-party SQLAlchemy stubs installed; nothing to do
    pass
else:
    # stubs found: they would shadow SQLAlchemy's inline typing and
    # conflict with this plugin, so refuse to proceed
    raise ImportError(
        "The SQLAlchemy mypy plugin in SQLAlchemy "
        "2.0 does not work with sqlalchemy-stubs or "
        "sqlalchemy2-stubs installed, as well as with any other third party "
        "SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs "
        "packages."
    )
class SQLAlchemyPlugin(Plugin):
    """Entry point class for mypy's plugin system.

    Each ``get_*_hook`` method either returns a callback for mypy to
    invoke when it encounters the named construct, or None to decline.
    """

    def get_dynamic_class_hook(
        self, fullname: str
    ) -> Optional[Callable[[DynamicClassDefContext], None]]:
        # intercept declarative_base() calls so the returned class gets
        # a real TypeInfo
        if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
            return _dynamic_class_hook
        return None

    def get_customize_class_mro_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:
        # applied unconditionally to every class; see _fill_in_decorators
        return _fill_in_decorators

    def get_class_decorator_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:
        sym = self.lookup_fully_qualified(fullname)

        if sym is not None and sym.node is not None:
            # dispatch on which SQLAlchemy decorator this is
            type_id = names.type_id_for_named_node(sym.node)
            if type_id is names.MAPPED_DECORATOR:
                return _cls_decorator_hook
            elif type_id in (
                names.AS_DECLARATIVE,
                names.AS_DECLARATIVE_BASE,
            ):
                return _base_cls_decorator_hook
            elif type_id is names.DECLARATIVE_MIXIN:
                return _declarative_mixin_hook

        return None

    def get_metaclass_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:
        if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
            # Set any classes that explicitly have metaclass=DeclarativeMeta
            # as declarative so the check in `get_base_class_hook()` works
            return _metaclass_cls_hook

        return None

    def get_base_class_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:
        # subclasses of a declarative base are scanned for mapped
        # attributes
        sym = self.lookup_fully_qualified(fullname)

        if (
            sym
            and isinstance(sym.node, TypeInfo)
            and util.has_declarative_base(sym.node)
        ):
            return _base_cls_hook

        return None

    def get_attribute_hook(
        self, fullname: str
    ) -> Optional[Callable[[AttributeContext], Type]]:
        if fullname.startswith(
            "sqlalchemy.orm.attributes.QueryableAttribute."
        ):
            return _queryable_getattr_hook

        return None

    def get_additional_deps(
        self, file: MypyFile
    ) -> List[Tuple[int, str, int]]:
        # modules every analyzed file is made to depend on, so the
        # symbols the plugin injects are always available.
        # NOTE(review): tuples appear to be (priority, module fullname,
        # line) per mypy's get_additional_deps contract -- confirm.
        return [
            #
            (10, "sqlalchemy.orm", -1),
            (10, "sqlalchemy.orm.attributes", -1),
            (10, "sqlalchemy.orm.decl_api", -1),
        ]
def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
    """Plugin entry point called by mypy with its version string; the
    version argument does not vary the behavior here."""
    return SQLAlchemyPlugin
def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
    """Generate a declarative Base class when the declarative_base() function
    is encountered."""

    _add_globals(ctx)

    # synthesize a ClassDef / TypeInfo pair for the generated Base
    cls = ClassDef(ctx.name, Block([]))
    cls.fullname = ctx.api.qualified_name(ctx.name)

    info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
    cls.info = info
    _set_declarative_metaclass(ctx.api, cls)

    # declarative_base(cls=SomeClass): the given class becomes the base
    # of the generated class and is scanned like a mixin
    cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
    if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
        util.set_is_base(cls_arg.node)
        decl_class.scan_declarative_assignments_and_apply_types(
            cls_arg.node.defn, ctx.api, is_mixin_scan=True
        )
        info.bases = [Instance(cls_arg.node, [])]
    else:
        # no cls= argument; base the generated class on object
        obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)

        info.bases = [obj]

    try:
        calculate_mro(info)
    except MroError:
        # fall back to an Any-based class rather than crashing analysis
        util.fail(
            ctx.api, "Not able to calculate MRO for declarative base", ctx.call
        )
        obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
        info.bases = [obj]
        info.fallback_to_any = True

    ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
    util.set_is_base(info)
def _fill_in_decorators(ctx: ClassDefContext) -> None:
    """Give ``registry.mapped`` / ``registry.as_declarative_base``
    class decorators a ``fullname`` so mypy will route them to our
    class decorator hook."""
    for decorator in ctx.cls.decorators:
        # set the ".fullname" attribute of a class decorator
        # that is a MemberExpr.  This causes the logic in
        # semanal.py->apply_class_plugin_hooks to invoke the
        # get_class_decorator_hook for our "registry.map_class()"
        # and "registry.as_declarative_base()" methods.
        # this seems like a bug in mypy that these decorators are otherwise
        # skipped.

        if (
            isinstance(decorator, nodes.CallExpr)
            and isinstance(decorator.callee, nodes.MemberExpr)
            and decorator.callee.name == "as_declarative_base"
        ):
            target = decorator.callee
        elif (
            isinstance(decorator, nodes.MemberExpr)
            and decorator.name == "mapped"
        ):
            target = decorator
        else:
            continue

        # the decorator must be an attribute access on a plain name
        # (e.g. "my_registry.mapped") for us to resolve the registry
        if isinstance(target.expr, NameExpr):
            sym = ctx.api.lookup_qualified(
                target.expr.name, target, suppress_errors=True
            )
        else:
            continue

        if sym and sym.node:
            sym_type = get_proper_type(sym.type)
            if isinstance(sym_type, Instance):
                target.fullname = f"{sym_type.type.fullname}.{target.name}"
            else:
                # if the registry is in the same file as where the
                # decorator is used, it might not have semantic
                # symbols applied and we can't get a fully qualified
                # name or an inferred type, so we are actually going to
                # flag an error in this case that they need to annotate
                # it. The "registry" is declared just
                # once (or few times), so they have to just not use
                # type inference for its assignment in this one case.
                util.fail(
                    ctx.api,
                    "Class decorator called %s(), but we can't "
                    "tell if it's from an ORM registry. Please "
                    "annotate the registry assignment, e.g. "
                    "my_registry: registry = registry()" % target.name,
                    sym.node,
                )
def _cls_decorator_hook(ctx: ClassDefContext) -> None:
    """Handle a class decorated with ``@<registry>.mapped``."""
    _add_globals(ctx)
    # this hook is only registered for the registry.mapped method, so
    # the decorator expression must be a member access on a Var whose
    # inferred type is the registry
    assert isinstance(ctx.reason, nodes.MemberExpr)
    expr = ctx.reason.expr

    assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)

    node_type = get_proper_type(expr.node.type)

    assert (
        isinstance(node_type, Instance)
        and names.type_id_for_named_node(node_type.type) is names.REGISTRY
    )

    decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
    """Handle a class decorated with ``as_declarative()`` or
    ``<registry>.as_declarative_base()``: mark it as a declarative base
    and scan it like a mixin."""
    _add_globals(ctx)

    target = ctx.cls
    _set_declarative_metaclass(ctx.api, target)

    util.set_is_base(target.info)
    decl_class.scan_declarative_assignments_and_apply_types(
        target, ctx.api, is_mixin_scan=True
    )
def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
    """Handle a class decorated with ``@declarative_mixin``: treat it as
    a declarative base for inheritance purposes and scan its mapped
    attributes."""
    _add_globals(ctx)
    util.set_is_base(ctx.cls.info)
    decl_class.scan_declarative_assignments_and_apply_types(
        ctx.cls, ctx.api, is_mixin_scan=True
    )
def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
    """Mark a class using ``metaclass=DeclarativeMeta`` as a declarative
    base so ``get_base_class_hook()`` recognizes its subclasses."""
    util.set_is_base(ctx.cls.info)
def _base_cls_hook(ctx: ClassDefContext) -> None:
    """Scan a class that inherits from a declarative base, applying
    Mapped[] types to its attributes."""
    _add_globals(ctx)
    decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
    """Attribute hook for QueryableAttribute; currently a pass-through
    that returns the default attribute type unchanged."""
    # how do I....tell it it has no attribute of a certain name?
    # can't find any Type that seems to match that
    return ctx.default_attr_type
def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
    """Add the ``__sa_Mapped`` alias for ``sqlalchemy.orm.Mapped`` to
    the current module's global symbol table, so code synthesized by the
    plugin can reference Mapped regardless of what the user imported.
    """
    util.add_global(ctx, "sqlalchemy.orm", "Mapped", "__sa_Mapped")
def _set_declarative_metaclass(
    api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
) -> None:
    """Assign ``DeclarativeMeta`` as both the declared and effective
    metaclass of *target_cls*."""
    meta_sym = api.lookup_fully_qualified_or_none(
        "sqlalchemy.orm.decl_api.DeclarativeMeta"
    )
    assert meta_sym is not None and isinstance(meta_sym.node, TypeInfo)

    metaclass_instance = Instance(meta_sym.node, [])
    info = target_cls.info
    info.declared_metaclass = metaclass_instance
    info.metaclass_type = metaclass_instance

View File

@ -0,0 +1,357 @@
# ext/mypy/util.py
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import re
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Tuple
from typing import Type as TypingType
from typing import TypeVar
from typing import Union
from mypy import version
from mypy.messages import format_type as _mypy_format_type
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
from mypy.options import Options
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
from mypy.typeops import map_type_from_supertype
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import Type
from mypy.types import TypeVarType
from mypy.types import UnboundType
from mypy.types import UnionType
# numeric components of the running mypy version, e.g. (1, 4, 1);
# non-numeric qualifier segments are skipped
_vers = tuple(
    int(part)
    for part in version.__version__.split(".")
    if re.match(r"^\d+$", part)
)
# mypy 1.4 changed messages.format_type() to require an Options argument;
# see format_type() below
mypy_14 = _vers >= (1, 4)

_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
class SQLAlchemyAttribute:
    """A single mapped attribute discovered on a declarative class, in a
    form that can round-trip through mypy's plugin metadata cache via
    :meth:`serialize` / :meth:`deserialize`."""

    def __init__(
        self,
        name: str,
        line: int,
        column: int,
        typ: Optional[Type],
        info: TypeInfo,
    ) -> None:
        self.name = name
        self.line = line
        self.column = column
        self.type = typ
        self.info = info

    def serialize(self) -> JsonDict:
        """Render this attribute as a JSON-compatible dict; the type
        must already be resolved."""
        assert self.type
        return {
            "name": self.name,
            "line": self.line,
            "column": self.column,
            "type": serialize_type(self.type),
        }

    def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
        """Expands type vars in the context of a subtype when an attribute is
        inherited from a generic super type.
        """
        if not isinstance(self.type, TypeVarType):
            return

        self.type = map_type_from_supertype(self.type, sub_type, self.info)

    @classmethod
    def deserialize(
        cls,
        info: TypeInfo,
        data: JsonDict,
        api: SemanticAnalyzerPluginInterface,
    ) -> SQLAlchemyAttribute:
        """Inverse of :meth:`serialize`; rebuilds the attribute with its
        type fixed up through the semantic analyzer API."""
        data = data.copy()
        typ = deserialize_and_fixup_type(data.pop("type"), api)
        return cls(typ=typ, info=info, **data)
def name_is_dunder(name: str) -> bool:
    """Return True for ``__dunder__`` style names: two leading and two
    trailing underscores surrounding at least one character."""
    return re.match(r"^__.+?__$", name) is not None
def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
info.metadata.setdefault("sqlalchemy", {})[key] = data
def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
return info.metadata.get("sqlalchemy", {}).get(key, None)
def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
    """Search *info*'s MRO for the first class carrying a non-None value
    under *key* in its "sqlalchemy" metadata namespace."""
    if not info.mro:
        return None
    for base in info.mro:
        found = _get_info_metadata(base, key)
        if found is not None:
            return found
    return None
def establish_as_sqlalchemy(info: TypeInfo) -> None:
    """Ensure the "sqlalchemy" namespace exists in the TypeInfo's plugin
    metadata, marking the class as known to this plugin."""
    if "sqlalchemy" not in info.metadata:
        info.metadata["sqlalchemy"] = {}
def set_is_base(info: TypeInfo) -> None:
    """Record in plugin metadata that this class is a declarative base
    (or mixin treated as one)."""
    info.metadata.setdefault("sqlalchemy", {})["is_base"] = True
def get_is_base(info: TypeInfo) -> bool:
    """Return True if this class itself was marked as a declarative
    base via :func:`set_is_base`."""
    return info.metadata.get("sqlalchemy", {}).get("is_base", None) is True
def has_declarative_base(info: TypeInfo) -> bool:
    """Return True if this class, or any class in its MRO, was marked as
    a declarative base."""
    return _get_info_mro_metadata(info, "is_base") is True
def set_has_table(info: TypeInfo) -> None:
    """Record in plugin metadata that this class has a mapped table."""
    info.metadata.setdefault("sqlalchemy", {})["has_table"] = True
def get_has_table(info: TypeInfo) -> bool:
    """Return True if this class was marked via :func:`set_has_table`."""
    return info.metadata.get("sqlalchemy", {}).get("has_table", None) is True
def get_mapped_attributes(
    info: TypeInfo, api: SemanticAnalyzerPluginInterface
) -> Optional[List[SQLAlchemyAttribute]]:
    """Deserialize the mapped attributes stored in *info*'s plugin
    metadata, expanding inherited type variables against *info*; returns
    None when no attributes were recorded."""
    serialized: Optional[List[JsonDict]] = _get_info_metadata(
        info, "mapped_attributes"
    )
    if serialized is None:
        return None

    result: List[SQLAlchemyAttribute] = []
    for entry in serialized:
        attr = SQLAlchemyAttribute.deserialize(info, entry, api)
        attr.expand_typevar_from_subtype(info)
        result.append(attr)
    return result
def format_type(typ_: Type, options: Options) -> str:
    """Render a type for an error message, papering over the mypy 1.4
    signature change in ``messages.format_type()``."""
    if not mypy_14:
        # pre-1.4 signature does not take Options
        return _mypy_format_type(typ_)  # type: ignore
    return _mypy_format_type(typ_, options)
def set_mapped_attributes(
    info: TypeInfo, attributes: List[SQLAlchemyAttribute]
) -> None:
    """Serialize *attributes* into *info*'s plugin metadata so they can
    be recovered by :func:`get_mapped_attributes` in later runs."""
    serialized = [attribute.serialize() for attribute in attributes]
    _set_info_metadata(info, "mapped_attributes", serialized)
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
    """Emit an error through the plugin API, tagged so users can tell it
    came from the SQLAlchemy plugin."""
    return api.fail(f"[SQLAlchemy Mypy plugin] {msg}", ctx)
def add_global(
    ctx: Union[ClassDefContext, DynamicClassDefContext],
    module: str,
    symbol_name: str,
    asname: str,
) -> None:
    """Install *symbol_name* from *module* into the current module's
    global symbol table under the alias *asname*, if not already there.
    """
    module_globals = ctx.api.modules[ctx.api.cur_mod_id].names

    if asname not in module_globals:
        # reuse the source module's SymbolTableNode directly
        lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
            symbol_name
        ]

        module_globals[asname] = lookup_sym
@overload
def get_callexpr_kwarg(
    callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]: ...


@overload
def get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Tuple[TypingType[_TArgType], ...],
) -> Optional[_TArgType]: ...


def get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Optional[Tuple[TypingType[Any], ...]] = None,
) -> Optional[Any]:
    """Return the keyword argument *name* from a call expression, or
    None when it is absent or not one of the accepted expression types
    (NameExpr / CallExpr by default).
    """
    try:
        arg_idx = callexpr.arg_names.index(name)
    except ValueError:
        # keyword was not passed in this call
        return None

    kwarg = callexpr.args[arg_idx]
    if isinstance(
        kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
    ):
        return kwarg

    return None
def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
    """Yield statements in order, hoisting the contents of any
    ``if TYPE_CHECKING:`` block up to the surrounding level."""
    for stmt in stmts:
        is_type_checking_block = (
            isinstance(stmt, IfStmt)
            and isinstance(stmt.expr[0], NameExpr)
            and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
        )
        if is_type_checking_block:
            # only the if-branch body is relevant for analysis
            yield from stmt.body[0].body
        else:
            yield stmt
def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]:
    """Resolve the type produced by calling *callee*: the declared
    return type for a plain function, the alias target for a type alias,
    or the TypeInfo itself when the callee is a class.
    """
    if isinstance(callee, (MemberExpr, NameExpr)):
        if isinstance(callee.node, FuncDef):
            if callee.node.type and isinstance(callee.node.type, CallableType):
                ret_type = get_proper_type(callee.node.type.ret_type)

                if isinstance(ret_type, Instance):
                    return ret_type

            # untyped function, or a non-Instance return type
            return None
        elif isinstance(callee.node, TypeAlias):
            target_type = get_proper_type(callee.node.target)
            if isinstance(target_type, Instance):
                return target_type
        elif isinstance(callee.node, TypeInfo):
            return callee.node
    return None
def unbound_to_instance(
    api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
    """Take the UnboundType that we seem to get as the ret_type from a FuncDef
    and convert it into an Instance/TypeInfo kind of structure that seems
    to work as the left-hand type of an AssignmentStatement.
    """

    if not isinstance(typ, UnboundType):
        return typ

    # TODO: figure out a more robust way to check this.  The node is some
    # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
    # but I can't figure out how to get them to match up
    if typ.name == "Optional":
        # convert from "Optional?" to the more familiar
        # UnionType[..., NoneType()]
        return unbound_to_instance(
            api,
            UnionType(
                [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
                + [NoneType()]
            ),
        )

    node = api.lookup_qualified(typ.name, typ)

    if (
        node is not None
        and isinstance(node, SymbolTableNode)
        and isinstance(node.node, TypeInfo)
    ):
        bound_type = node.node

        # recursively convert any unbound type arguments as well
        return Instance(
            bound_type,
            [
                (
                    unbound_to_instance(api, arg)
                    if isinstance(arg, UnboundType)
                    else arg
                )
                for arg in typ.args
            ],
        )
    else:
        # name could not be resolved to a class; leave the type as-is
        return typ
def info_for_cls(
    cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[TypeInfo]:
    """Return the TypeInfo for a ClassDef, resolving through the symbol
    table when the definition has not yet been assigned one."""
    if cls.info is not CLASSDEF_NO_INFO:
        return cls.info

    sym = api.lookup_qualified(cls.name, cls)
    if sym is None:
        return None
    assert sym and isinstance(sym.node, TypeInfo)
    return sym.node
def serialize_type(typ: Type) -> Union[str, JsonDict]:
    """Serialize a mypy type for storage in plugin metadata.

    NOTE(review): the fallback path presumably guards against types that
    still contain unresolved string annotations -- confirm against the
    mypy versions this targets.
    """
    try:
        return typ.serialize()
    except Exception:
        # deliberate best-effort: fall through to resolve string
        # annotations below and retry; a second failure propagates
        pass
    if hasattr(typ, "args"):
        # NOTE: mutates the type's args in place before retrying
        typ.args = tuple(
            (
                a.resolve_string_annotation()
                if hasattr(a, "resolve_string_annotation")
                else a
            )
            for a in typ.args
        )
    elif hasattr(typ, "resolve_string_annotation"):
        typ = typ.resolve_string_annotation()
    return typ.serialize()