Update 2025-04-13_16:43:49

Commit 5b46114a61 by root, 2025-04-13 16:43:50 +02:00
2244 changed files with 407391 additions and 0 deletions


@@ -0,0 +1,373 @@
from __future__ import annotations as _annotations
import warnings
from contextlib import contextmanager
from re import Pattern
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
cast,
)
from pydantic_core import core_schema
from typing_extensions import Self
from ..aliases import AliasGenerator
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince210
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
if TYPE_CHECKING:
from .._internal._schema_generation_shared import GenerateSchema
from ..fields import ComputedFieldInfo, FieldInfo
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
class ConfigWrapper:
"""Internal wrapper for Config which exposes ConfigDict items as attributes."""
__slots__ = ('config_dict',)
config_dict: ConfigDict
# all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they
# stop matching
title: str | None
str_to_lower: bool
str_to_upper: bool
str_strip_whitespace: bool
str_min_length: int
str_max_length: int | None
extra: ExtraValues | None
frozen: bool
populate_by_name: bool
use_enum_values: bool
validate_assignment: bool
arbitrary_types_allowed: bool
from_attributes: bool
# whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
# to construct error `loc`s, default `True`
loc_by_alias: bool
alias_generator: Callable[[str], str] | AliasGenerator | None
model_title_generator: Callable[[type], str] | None
field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
ignored_types: tuple[type, ...]
allow_inf_nan: bool
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
json_encoders: dict[type[object], JsonEncoder] | None
# new in V2
strict: bool
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
revalidate_instances: Literal['always', 'never', 'subclass-instances']
ser_json_timedelta: Literal['iso8601', 'float']
ser_json_bytes: Literal['utf8', 'base64', 'hex']
val_json_bytes: Literal['utf8', 'base64', 'hex']
ser_json_inf_nan: Literal['null', 'constants', 'strings']
# whether to validate default values during validation, default False
validate_default: bool
validate_return: bool
protected_namespaces: tuple[str | Pattern[str], ...]
hide_input_in_errors: bool
defer_build: bool
plugin_settings: dict[str, object] | None
schema_generator: type[GenerateSchema] | None
json_schema_serialization_defaults_required: bool
json_schema_mode_override: Literal['validation', 'serialization', None]
coerce_numbers_to_str: bool
regex_engine: Literal['rust-regex', 'python-re']
validation_error_cause: bool
use_attribute_docstrings: bool
cache_strings: bool | Literal['all', 'keys', 'none']
validate_by_alias: bool
validate_by_name: bool
serialize_by_alias: bool
def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
if check:
self.config_dict = prepare_config(config)
else:
self.config_dict = cast(ConfigDict, config)
@classmethod
def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self:
"""Build a new `ConfigWrapper` instance for a `BaseModel`.
The config wrapper is built based on (in descending order of priority):
- options from `kwargs`
- options from the `namespace`
- options from the base classes (`bases`)
Args:
bases: A tuple of base classes.
namespace: The namespace of the class being created.
kwargs: The kwargs passed to the class being created.
Returns:
A `ConfigWrapper` instance for `BaseModel`.
"""
config_new = ConfigDict()
for base in bases:
config = getattr(base, 'model_config', None)
if config:
config_new.update(config.copy())
config_class_from_namespace = namespace.get('Config')
config_dict_from_namespace = namespace.get('model_config')
raw_annotations = namespace.get('__annotations__', {})
if raw_annotations.get('model_config') and config_dict_from_namespace is None:
raise PydanticUserError(
'`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
code='model-config-invalid-field-name',
)
if config_class_from_namespace and config_dict_from_namespace:
raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')
config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)
config_new.update(config_from_namespace)
for k in list(kwargs.keys()):
if k in config_keys:
config_new[k] = kwargs.pop(k)
return cls(config_new)
# we don't show `__getattr__` to type checkers so missing attributes cause errors
if not TYPE_CHECKING: # pragma: no branch
def __getattr__(self, name: str) -> Any:
try:
return self.config_dict[name]
except KeyError:
try:
return config_defaults[name]
except KeyError:
raise AttributeError(f'Config has no attribute {name!r}') from None
def core_config(self, title: str | None) -> core_schema.CoreConfig:
"""Create a pydantic-core config.
We don't use getattr here since we don't want to populate with defaults.
Args:
title: The title to use if not set in config.
Returns:
A `CoreConfig` object created from config.
"""
config = self.config_dict
if config.get('schema_generator') is not None:
warnings.warn(
'The `schema_generator` setting has been deprecated since v2.10. This setting no longer has any effect.',
PydanticDeprecatedSince210,
stacklevel=2,
)
if (populate_by_name := config.get('populate_by_name')) is not None:
# We include this patch for backwards compatibility purposes, but this config setting will be deprecated in v3.0, and likely removed in v4.0.
# Thus, the above warning and this patch can be removed then as well.
if config.get('validate_by_name') is None:
config['validate_by_alias'] = True
config['validate_by_name'] = populate_by_name
# We dynamically patch validate_by_name to be True if validate_by_alias is set to False
# and validate_by_name is not explicitly set.
if config.get('validate_by_alias') is False and config.get('validate_by_name') is None:
config['validate_by_name'] = True
if (not config.get('validate_by_alias', True)) and (not config.get('validate_by_name', False)):
raise PydanticUserError(
'At least one of `validate_by_alias` or `validate_by_name` must be set to True.',
code='validate-by-alias-and-name-false',
)
return core_schema.CoreConfig(
**{ # pyright: ignore[reportArgumentType]
k: v
for k, v in (
('title', config.get('title') or title or None),
('extra_fields_behavior', config.get('extra')),
('allow_inf_nan', config.get('allow_inf_nan')),
('str_strip_whitespace', config.get('str_strip_whitespace')),
('str_to_lower', config.get('str_to_lower')),
('str_to_upper', config.get('str_to_upper')),
('strict', config.get('strict')),
('ser_json_timedelta', config.get('ser_json_timedelta')),
('ser_json_bytes', config.get('ser_json_bytes')),
('val_json_bytes', config.get('val_json_bytes')),
('ser_json_inf_nan', config.get('ser_json_inf_nan')),
('from_attributes', config.get('from_attributes')),
('loc_by_alias', config.get('loc_by_alias')),
('revalidate_instances', config.get('revalidate_instances')),
('validate_default', config.get('validate_default')),
('str_max_length', config.get('str_max_length')),
('str_min_length', config.get('str_min_length')),
('hide_input_in_errors', config.get('hide_input_in_errors')),
('coerce_numbers_to_str', config.get('coerce_numbers_to_str')),
('regex_engine', config.get('regex_engine')),
('validation_error_cause', config.get('validation_error_cause')),
('cache_strings', config.get('cache_strings')),
('validate_by_alias', config.get('validate_by_alias')),
('validate_by_name', config.get('validate_by_name')),
('serialize_by_alias', config.get('serialize_by_alias')),
)
if v is not None
}
)
def __repr__(self):
c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
return f'ConfigWrapper({c})'
class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""
def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]
@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]
@contextmanager
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
if config_wrapper is None:
yield
return
if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()
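
A minimal sketch of how the stack behaves, assuming the private import path `pydantic._internal._config` (internal, and may change between releases); the config values are illustrative:

```python
from pydantic import ConfigDict
from pydantic._internal._config import ConfigWrapper, ConfigWrapperStack

stack = ConfigWrapperStack(ConfigWrapper(ConfigDict(strict=False)))
print(stack.tail.strict)        # False

with stack.push(ConfigDict(strict=True)):
    print(stack.tail.strict)    # True while the override is on the stack

print(stack.tail.strict)        # back to False after the context exits
```

The `check=False` path in `push` skips `prepare_config`, which is why a raw `ConfigDict` override is accepted as-is.
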
config_defaults = ConfigDict(
title=None,
str_to_lower=False,
str_to_upper=False,
str_strip_whitespace=False,
str_min_length=0,
str_max_length=None,
# let the model / dataclass decide how to handle it
extra=None,
frozen=False,
populate_by_name=False,
use_enum_values=False,
validate_assignment=False,
arbitrary_types_allowed=False,
from_attributes=False,
loc_by_alias=True,
alias_generator=None,
model_title_generator=None,
field_title_generator=None,
ignored_types=(),
allow_inf_nan=True,
json_schema_extra=None,
strict=False,
revalidate_instances='never',
ser_json_timedelta='iso8601',
ser_json_bytes='utf8',
val_json_bytes='utf8',
ser_json_inf_nan='null',
validate_default=False,
validate_return=False,
protected_namespaces=('model_validate', 'model_dump'),
hide_input_in_errors=False,
json_encoders=None,
defer_build=False,
schema_generator=None,
plugin_settings=None,
json_schema_serialization_defaults_required=False,
json_schema_mode_override=None,
coerce_numbers_to_str=False,
regex_engine='rust-regex',
validation_error_cause=False,
use_attribute_docstrings=False,
cache_strings=True,
validate_by_alias=True,
validate_by_name=False,
serialize_by_alias=False,
)
def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict:
"""Create a `ConfigDict` instance from an existing dict, a class (e.g. old class-based config) or None.
Args:
config: The input config.
Returns:
A ConfigDict object created from config.
"""
if config is None:
return ConfigDict()
if not isinstance(config, dict):
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
config = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
config_dict = cast(ConfigDict, config)
check_deprecated(config_dict)
return config_dict
config_keys = set(ConfigDict.__annotations__.keys())
V2_REMOVED_KEYS = {
'allow_mutation',
'error_msg_templates',
'fields',
'getter_dict',
'smart_union',
'underscore_attrs_are_private',
'json_loads',
'json_dumps',
'copy_on_model_validation',
'post_init_call',
}
V2_RENAMED_KEYS = {
'allow_population_by_field_name': 'validate_by_name',
'anystr_lower': 'str_to_lower',
'anystr_strip_whitespace': 'str_strip_whitespace',
'anystr_upper': 'str_to_upper',
'keep_untouched': 'ignored_types',
'max_anystr_length': 'str_max_length',
'min_anystr_length': 'str_min_length',
'orm_mode': 'from_attributes',
'schema_extra': 'json_schema_extra',
'validate_all': 'validate_default',
}
def check_deprecated(config_dict: ConfigDict) -> None:
"""Check for deprecated config keys and warn the user.
Args:
config_dict: The input config.
"""
deprecated_removed_keys = V2_REMOVED_KEYS & config_dict.keys()
deprecated_renamed_keys = V2_RENAMED_KEYS.keys() & config_dict.keys()
if deprecated_removed_keys or deprecated_renamed_keys:
renamings = {k: V2_RENAMED_KEYS[k] for k in sorted(deprecated_renamed_keys)}
renamed_bullets = [f'* {k!r} has been renamed to {v!r}' for k, v in renamings.items()]
removed_bullets = [f'* {k!r} has been removed' for k in sorted(deprecated_removed_keys)]
message = '\n'.join(['Valid config keys have changed in V2:'] + renamed_bullets + removed_bullets)
warnings.warn(message, UserWarning)
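
To see how `prepare_config` and `check_deprecated` surface to users, here is a minimal sketch using only the public `BaseModel` API; the model name and legacy key are illustrative:

```python
import warnings

from pydantic import BaseModel

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')

    class LegacyModel(BaseModel):
        x: int

        class Config:        # class-based config triggers DEPRECATION_MESSAGE
            orm_mode = True  # renamed key triggers the "Valid config keys have changed in V2" UserWarning

for w in caught:
    print(f'{w.category.__name__}: {w.message}')
```

Both warnings are emitted at class-creation time, when `ConfigWrapper.for_model` runs `prepare_config` on the class namespace.
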


@@ -0,0 +1,97 @@
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, TypedDict, cast
from warnings import warn
if TYPE_CHECKING:
from ..config import JsonDict, JsonSchemaExtraCallable
from ._schema_generation_shared import (
GetJsonSchemaFunction,
)
class CoreMetadata(TypedDict, total=False):
"""A `TypedDict` for holding the metadata dict of the schema.
Attributes:
pydantic_js_functions: List of JSON schema functions that resolve refs during application.
pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application.
pydantic_js_prefer_positional_arguments: Whether the JSON schema generator will
prefer positional over keyword arguments for an 'arguments' schema.
pydantic_js_updates: key / value pair updates to apply to the JSON schema for a type.
pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable.
pydantic_internal_union_tag_key: Used internally by the `Tag` metadata to specify the tag used for a discriminated union.
pydantic_internal_union_discriminator: Used internally to specify the discriminator value for a discriminated union
when the discriminator was applied to a `'definition-ref'` schema, and that reference was missing at the time
of the annotation application.
TODO: Perhaps we should move this structure to pydantic-core. At the moment, though,
it's easier to iterate on if we leave it in pydantic until we feel there is a semi-stable API.
TODO: It's unfortunate how functionally oriented JSON schema generation is, especially that which occurs during
the core schema generation process. It's inevitable that we need to store some json schema related information
on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related
issues is quite difficult when JSON schema information is disguised via dynamically defined functions.
"""
pydantic_js_functions: list[GetJsonSchemaFunction]
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
pydantic_js_prefer_positional_arguments: bool
pydantic_js_updates: JsonDict
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable
pydantic_internal_union_tag_key: str
pydantic_internal_union_discriminator: str
def update_core_metadata(
core_metadata: Any,
/,
*,
pydantic_js_functions: list[GetJsonSchemaFunction] | None = None,
pydantic_js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
pydantic_js_updates: JsonDict | None = None,
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable | None = None,
) -> None:
"""Update CoreMetadata instance in place. When we make modifications in this function, they
take effect on the `core_metadata` reference passed in as the first (and only) positional argument.
First, cast to `CoreMetadata`, then finish with a cast to `dict[str, Any]` for core schema compatibility.
We do this here, instead of before / after each call to this function, so that this typing hack
can be easily removed if/when we move `CoreMetadata` to `pydantic-core`.
For parameter descriptions, see `CoreMetadata` above.
"""
from ..json_schema import PydanticJsonSchemaWarning
core_metadata = cast(CoreMetadata, core_metadata)
if pydantic_js_functions:
core_metadata.setdefault('pydantic_js_functions', []).extend(pydantic_js_functions)
if pydantic_js_annotation_functions:
core_metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions)
if pydantic_js_updates:
if (existing_updates := core_metadata.get('pydantic_js_updates')) is not None:
core_metadata['pydantic_js_updates'] = {**existing_updates, **pydantic_js_updates}
else:
core_metadata['pydantic_js_updates'] = pydantic_js_updates
if pydantic_js_extra is not None:
existing_pydantic_js_extra = core_metadata.get('pydantic_js_extra')
if existing_pydantic_js_extra is None:
core_metadata['pydantic_js_extra'] = pydantic_js_extra
if isinstance(existing_pydantic_js_extra, dict):
if isinstance(pydantic_js_extra, dict):
core_metadata['pydantic_js_extra'] = {**existing_pydantic_js_extra, **pydantic_js_extra}
if callable(pydantic_js_extra):
warn(
'Composing `dict` and `callable` type `json_schema_extra` is not supported. '
'The `callable` type is being ignored. '
"If you'd like support for this behavior, please open an issue on pydantic.",
PydanticJsonSchemaWarning,
)
if callable(existing_pydantic_js_extra):
# if ever there's a case of a callable, we'll just keep the last json schema extra spec
core_metadata['pydantic_js_extra'] = pydantic_js_extra
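
A standalone sketch (plain dicts, hypothetical values) mirroring how `pydantic_js_updates` entries are merged above, with new keys winning on conflict:

```python
from typing import Any

metadata: dict[str, Any] = {'pydantic_js_updates': {'title': 'User'}}
new_updates = {'description': 'A user record'}

# Mirror of the merge branch: combine key-by-key, keep existing keys, let new keys win.
existing = metadata.get('pydantic_js_updates')
metadata['pydantic_js_updates'] = {**existing, **new_updates} if existing is not None else new_updates

print(metadata['pydantic_js_updates'])
# {'title': 'User', 'description': 'A user record'}
```
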


@@ -0,0 +1,182 @@
from __future__ import annotations
import inspect
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union
from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeGuard, get_args, get_origin
from typing_inspection import typing_objects
from . import _repr
from ._typing_extra import is_generic_alias
if TYPE_CHECKING:
from rich.console import Console
AnyFunctionSchema = Union[
core_schema.AfterValidatorFunctionSchema,
core_schema.BeforeValidatorFunctionSchema,
core_schema.WrapValidatorFunctionSchema,
core_schema.PlainValidatorFunctionSchema,
]
FunctionSchemaWithInnerSchema = Union[
core_schema.AfterValidatorFunctionSchema,
core_schema.BeforeValidatorFunctionSchema,
core_schema.WrapValidatorFunctionSchema,
]
CoreSchemaField = Union[
core_schema.ModelField, core_schema.DataclassField, core_schema.TypedDictField, core_schema.ComputedField
]
CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
def is_core_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchema]:
return schema['type'] not in _CORE_SCHEMA_FIELD_TYPES
def is_core_schema_field(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchemaField]:
return schema['type'] in _CORE_SCHEMA_FIELD_TYPES
def is_function_with_inner_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
return schema['type'] in _FUNCTION_WITH_INNER_SCHEMA_TYPES
def is_list_like_schema_with_items_schema(
schema: CoreSchema,
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
def get_type_ref(type_: Any, args_override: tuple[type[Any], ...] | None = None) -> str:
"""Produces the ref to be used for this type by pydantic_core's core schemas.
This `args_override` argument was added for the purpose of creating valid recursive references
when creating generic models without needing to create a concrete class.
"""
origin = get_origin(type_) or type_
args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
if generic_metadata:
origin = generic_metadata['origin'] or origin
args = generic_metadata['args'] or args
module_name = getattr(origin, '__module__', '<No __module__>')
if typing_objects.is_typealiastype(origin):
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
else:
try:
qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
except Exception:
qualname = getattr(origin, '__qualname__', '<No __qualname__>')
type_ref = f'{module_name}.{qualname}:{id(origin)}'
arg_refs: list[str] = []
for arg in args:
if isinstance(arg, str):
# Handle string literals as a special case; we may be able to remove this special handling if we
# wrap them in a ForwardRef at some point.
arg_ref = f'{arg}:str-{id(arg)}'
else:
arg_ref = f'{_repr.display_as_type(arg)}:{id(arg)}'
arg_refs.append(arg_ref)
if arg_refs:
type_ref = f'{type_ref}[{",".join(arg_refs)}]'
return type_ref
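
A hedged sketch of the refs this produces, assuming the private import path `pydantic._internal._core_utils`; the numeric parts are `id()` values and vary from run to run:

```python
from pydantic._internal._core_utils import get_type_ref

# For a parametrized generic the ref combines module, qualname, id() of the origin,
# plus a bracketed list of argument refs.
print(get_type_ref(list[int]))
# e.g. 'builtins.list:139673...[int:94532...]'
```
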
def get_ref(s: core_schema.CoreSchema) -> None | str:
"""Get the ref from the schema if it has one.
This exists just for type checking to work correctly.
"""
return s.get('ref', None)
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
if os.getenv('PYDANTIC_VALIDATE_CORE_SCHEMAS'):
return _validate_core_schema(schema)
return schema
def _clean_schema_for_pretty_print(obj: Any, strip_metadata: bool = True) -> Any: # pragma: no cover
"""A utility function to remove irrelevant information from a core schema."""
if isinstance(obj, Mapping):
new_dct = {}
for k, v in obj.items():
if k == 'metadata' and strip_metadata:
new_metadata = {}
for meta_k, meta_v in v.items():
if meta_k in ('pydantic_js_functions', 'pydantic_js_annotation_functions'):
new_metadata['js_metadata'] = '<stripped>'
else:
new_metadata[meta_k] = _clean_schema_for_pretty_print(meta_v, strip_metadata=strip_metadata)
if list(new_metadata.keys()) == ['js_metadata']:
new_metadata = {'<stripped>'}
new_dct[k] = new_metadata
# Remove some defaults:
elif k in ('custom_init', 'root_model') and not v:
continue
else:
new_dct[k] = _clean_schema_for_pretty_print(v, strip_metadata=strip_metadata)
return new_dct
elif isinstance(obj, Sequence) and not isinstance(obj, str):
return [_clean_schema_for_pretty_print(v, strip_metadata=strip_metadata) for v in obj]
else:
return obj
def pretty_print_core_schema(
val: Any,
*,
console: Console | None = None,
max_depth: int | None = None,
strip_metadata: bool = True,
) -> None: # pragma: no cover
"""Pretty-print a core schema using the `rich` library.
Args:
val: The core schema to print, or a Pydantic model/dataclass/type adapter
(in which case the cached core schema is fetched and printed).
console: A rich console to use when printing. Defaults to the global rich console instance.
max_depth: The number of nesting levels which may be printed.
strip_metadata: Whether to strip metadata in the output. If `True` any known core metadata
attributes will be stripped (but custom attributes are kept). Defaults to `True`.
"""
# lazy import:
from rich.pretty import pprint
# circ. imports:
from pydantic import BaseModel, TypeAdapter
from pydantic.dataclasses import is_pydantic_dataclass
if (inspect.isclass(val) and issubclass(val, BaseModel)) or is_pydantic_dataclass(val):
val = val.__pydantic_core_schema__
if isinstance(val, TypeAdapter):
val = val.core_schema
cleaned_schema = _clean_schema_for_pretty_print(val, strip_metadata=strip_metadata)
pprint(cleaned_schema, console=console, max_depth=max_depth)
pps = pretty_print_core_schema
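
A hedged usage sketch for the debugging helper above, assuming the private import path `pydantic._internal._core_utils` and that `rich` is installed; the model is hypothetical:

```python
from pydantic import BaseModel
from pydantic._internal._core_utils import pretty_print_core_schema


class User(BaseModel):
    id: int
    name: str = 'Jane'


# Passing the model class directly: the helper reads `__pydantic_core_schema__`
# off the class and pretty-prints the cleaned core schema via rich.
pretty_print_core_schema(User, strip_metadata=True)
```
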


@@ -0,0 +1,235 @@
"""Private logic for creating pydantic dataclasses."""
from __future__ import annotations as _annotations
import dataclasses
import typing
import warnings
from functools import partial, wraps
from typing import Any, ClassVar
from pydantic_core import (
ArgsKwargs,
SchemaSerializer,
SchemaValidator,
core_schema,
)
from typing_extensions import TypeGuard
from ..errors import PydanticUndefinedAnnotation
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema, InvalidSchemaError
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._namespace_utils import NsResolver
from ._signature import generate_pydantic_signature
from ._utils import LazyClassAttribute
if typing.TYPE_CHECKING:
from _typeshed import DataclassInstance as StandardDataclass
from ..config import ConfigDict
from ..fields import FieldInfo
class PydanticDataclass(StandardDataclass, typing.Protocol):
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
Attributes:
__pydantic_config__: Pydantic-specific configuration settings for the dataclass.
__pydantic_complete__: Whether dataclass building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
__pydantic_decorators__: Metadata containing the decorators defined on the dataclass.
__pydantic_fields__: Metadata about the fields defined on the dataclass.
__pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the dataclass.
__pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the dataclass.
"""
__pydantic_config__: ClassVar[ConfigDict]
__pydantic_complete__: ClassVar[bool]
__pydantic_core_schema__: ClassVar[core_schema.CoreSchema]
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
__pydantic_serializer__: ClassVar[SchemaSerializer]
__pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
def set_dataclass_fields(
cls: type[StandardDataclass],
ns_resolver: NsResolver | None = None,
config_wrapper: _config.ConfigWrapper | None = None,
) -> None:
"""Collect and set `cls.__pydantic_fields__`.
Args:
cls: The class.
ns_resolver: Namespace resolver to use when getting dataclass annotations.
config_wrapper: The config wrapper instance, defaults to `None`.
"""
typevars_map = get_standard_typevars_map(cls)
fields = collect_dataclass_fields(
cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper
)
cls.__pydantic_fields__ = fields # type: ignore
def complete_dataclass(
cls: type[Any],
config_wrapper: _config.ConfigWrapper,
*,
raise_errors: bool = True,
ns_resolver: NsResolver | None = None,
_force_build: bool = False,
) -> bool:
"""Finish building a pydantic dataclass.
This logic is called on a class which has already been wrapped in `dataclasses.dataclass()`.
This is somewhat analogous to `pydantic._internal._model_construction.complete_model_class`.
Args:
cls: The class.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors, defaults to `True`.
ns_resolver: The namespace resolver instance to use when collecting dataclass fields
and during schema building.
_force_build: Whether to force building the dataclass, no matter if
[`defer_build`][pydantic.config.ConfigDict.defer_build] is set.
Returns:
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
Raises:
PydanticUndefinedAnnotation: If `raise_errors` is `True` and there are undefined annotations.
"""
original_init = cls.__init__
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied,
# and so that the mock validator is used if building was deferred:
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
__init__.__qualname__ = f'{cls.__qualname__}.__init__'
cls.__init__ = __init__ # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
set_dataclass_fields(cls, ns_resolver, config_wrapper=config_wrapper)
if not _force_build and config_wrapper.defer_build:
set_dataclass_mocks(cls)
return False
if hasattr(cls, '__post_init_post_parse__'):
warnings.warn(
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
)
typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
ns_resolver=ns_resolver,
typevars_map=typevars_map,
)
# set __signature__ attr only for the class, but not for its instances
# (because instances can define `__call__`, and `inspect.signature` shouldn't
# use the `__signature__` attribute and instead generate from `__call__`).
cls.__signature__ = LazyClassAttribute(
'__signature__',
partial(
generate_pydantic_signature,
# It's important that we reference the `original_init` here,
# as it is the one synthesized by the stdlib `dataclass` module:
init=original_init,
fields=cls.__pydantic_fields__, # type: ignore
validate_by_name=config_wrapper.validate_by_name,
extra=config_wrapper.extra,
is_dataclass=True,
),
)
try:
schema = gen_schema.generate_schema(cls)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_dataclass_mocks(cls, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(title=cls.__name__)
try:
schema = gen_schema.clean_schema(schema)
except InvalidSchemaError:
set_dataclass_mocks(cls)
return False
# We are about to set all the remaining required properties expected for this cast;
# __pydantic_decorators__ and __pydantic_fields__ should already be set
cls = typing.cast('type[PydanticDataclass]', cls)
# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
if config_wrapper.validate_assignment:
@wraps(cls.__setattr__)
def validated_setattr(instance: Any, field: str, value: str, /) -> None:
validator.validate_assignment(instance, field, value)
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
cls.__pydantic_complete__ = True
return True
def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
We check that
- `_cls` is a dataclass
- `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
- `_cls` does not have any annotations that are not dataclass fields
e.g.
```python
import dataclasses
import pydantic.dataclasses
@dataclasses.dataclass
class A:
x: int
@pydantic.dataclasses.dataclass
class B(A):
y: int
```
In this case, when we first check `B`, we make an extra check and look at its annotations ('y');
the dataclass fields at that point (only the stdlib fields, i.e. 'x') are not a superset of those annotations, so `B` is not treated as a stdlib dataclass.
Args:
_cls: The class.
Returns:
`True` if the class is a stdlib dataclass, `False` otherwise.
"""
return (
dataclasses.is_dataclass(_cls)
and not hasattr(_cls, '__pydantic_validator__')
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
)
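
A hedged sketch exercising `is_builtin_dataclass`, assuming the private import path `pydantic._internal._dataclasses`; the sample classes are hypothetical:

```python
import dataclasses

import pydantic.dataclasses
from pydantic._internal._dataclasses import is_builtin_dataclass


@dataclasses.dataclass
class Plain:
    x: int


@pydantic.dataclasses.dataclass
class Validated:
    x: int


print(is_builtin_dataclass(Plain))      # True: stdlib dataclass, no __pydantic_validator__
print(is_builtin_dataclass(Validated))  # False: pydantic attached __pydantic_validator__
```
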


@@ -0,0 +1,838 @@
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
from __future__ import annotations as _annotations
import types
from collections import deque
from collections.abc import Iterable
from dataclasses import dataclass, field
from functools import cached_property, partial, partialmethod
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
from itertools import islice
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union
from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
from typing_extensions import TypeAlias, is_typeddict
from ..errors import PydanticUserError
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._namespace_utils import GlobalsNamespace, MappingNamespace
from ._typing_extra import get_function_type_hints
from ._utils import can_be_positional
if TYPE_CHECKING:
from ..fields import ComputedFieldInfo
from ..functional_validators import FieldValidatorModes
@dataclass(**slots_true)
class ValidatorDecoratorInfo:
"""A container for data from `@validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@validator'.
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
each_item: For complex objects (sets, lists etc.) whether to validate individual
elements rather than the whole object.
always: Whether this method and other validators should be called even if the value is missing.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@validator'
fields: tuple[str, ...]
mode: Literal['before', 'after']
each_item: bool
always: bool
check_fields: bool | None
@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
"""A container for data from `@field_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@field_validator'.
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
check_fields: Whether to check that the fields actually exist on the model.
json_schema_input_type: The input type of the function. This is only used to generate
the appropriate JSON Schema (in validation mode) and can only be specified
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
"""
decorator_repr: ClassVar[str] = '@field_validator'
fields: tuple[str, ...]
mode: FieldValidatorModes
check_fields: bool | None
json_schema_input_type: Any
@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
"""A container for data from `@root_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@root_validator'.
mode: The proposed validator mode.
"""
decorator_repr: ClassVar[str] = '@root_validator'
mode: Literal['before', 'after']
@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
"""A container for data from `@field_serializer` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@field_serializer'.
fields: A tuple of field names the serializer should be called on.
mode: The proposed serializer mode.
return_type: The type of the serializer's return value.
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
and `'json-unless-none'`.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@field_serializer'
fields: tuple[str, ...]
mode: Literal['plain', 'wrap']
return_type: Any
when_used: core_schema.WhenUsed
check_fields: bool | None
@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
"""A container for data from `@model_serializer` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
mode: The proposed serializer mode.
return_type: The type of the serializer's return value.
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
and `'json-unless-none'`.
"""
decorator_repr: ClassVar[str] = '@model_serializer'
mode: Literal['plain', 'wrap']
return_type: Any
when_used: core_schema.WhenUsed
@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
"""A container for data from `@model_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@model_validator'.
mode: The proposed validator mode.
"""
decorator_repr: ClassVar[str] = '@model_validator'
mode: Literal['wrap', 'before', 'after']
DecoratorInfo: TypeAlias = """Union[
ValidatorDecoratorInfo,
FieldValidatorDecoratorInfo,
RootValidatorDecoratorInfo,
FieldSerializerDecoratorInfo,
ModelSerializerDecoratorInfo,
ModelValidatorDecoratorInfo,
ComputedFieldInfo,
]"""
ReturnType = TypeVar('ReturnType')
DecoratedType: TypeAlias = (
'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)
@dataclass # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
"""Wrap a classmethod, staticmethod, property or unbound function
and act as a descriptor that allows us to detect decorated items
from the class' attributes.
This class' __get__ returns the wrapped item's __get__ result,
which makes it transparent for classmethods and staticmethods.
Attributes:
wrapped: The decorator that has to be wrapped.
decorator_info: The decorator info.
shim: A wrapper function to wrap a V1 style function.
"""
wrapped: DecoratedType[ReturnType]
decorator_info: DecoratorInfo
shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None
def __post_init__(self):
for attr in 'setter', 'deleter':
if hasattr(self.wrapped, attr):
f = partial(self._call_wrapped_attr, name=attr)
setattr(self, attr, f)
def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
self.wrapped = getattr(self.wrapped, name)(func)
if isinstance(self.wrapped, property):
# update ComputedFieldInfo.wrapped_property
from ..fields import ComputedFieldInfo
if isinstance(self.decorator_info, ComputedFieldInfo):
self.decorator_info.wrapped_property = self.wrapped
return self
def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
try:
return self.wrapped.__get__(obj, obj_type)
except AttributeError:
# not a descriptor, e.g. a partial object
return self.wrapped # type: ignore[return-value]
def __set_name__(self, instance: Any, name: str) -> None:
if hasattr(self.wrapped, '__set_name__'):
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]
def __getattr__(self, name: str, /) -> Any:
"""Forward checks for __isabstractmethod__ and such."""
return getattr(self.wrapped, name)
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
"""A generic container class to join together the decorator metadata
(metadata from decorator itself, which we have when the
decorator is called but not when we are building the core-schema)
and the bound function (which we have after the class itself is created).
Attributes:
cls_ref: The class ref.
cls_var_name: The decorated function name.
func: The decorated function.
shim: A wrapper function to wrap a V1 style function.
info: The decorator info.
"""
cls_ref: str
cls_var_name: str
func: Callable[..., Any]
shim: Callable[[Any], Any] | None
info: DecoratorInfoType
@staticmethod
def build(
cls_: Any,
*,
cls_var_name: str,
shim: Callable[[Any], Any] | None,
info: DecoratorInfoType,
) -> Decorator[DecoratorInfoType]:
"""Build a new decorator.
Args:
cls_: The class.
cls_var_name: The decorated function name.
shim: A wrapper function to wrap a V1 style function.
info: The decorator info.
Returns:
The new decorator instance.
"""
func = get_attribute_from_bases(cls_, cls_var_name)
if shim is not None:
func = shim(func)
func = unwrap_wrapped_function(func, unwrap_partial=False)
if not callable(func):
# This branch will get hit for classmethod properties
attribute = get_attribute_from_base_dicts(cls_, cls_var_name) # prevents the binding call to `__get__`
if isinstance(attribute, PydanticDescriptorProxy):
func = unwrap_wrapped_function(attribute.wrapped)
return Decorator(
cls_ref=get_type_ref(cls_),
cls_var_name=cls_var_name,
func=func,
shim=shim,
info=info,
)
def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
"""Bind the decorator to a class.
Args:
cls: the class.
Returns:
The new decorator instance.
"""
return self.build(
cls,
cls_var_name=self.cls_var_name,
shim=self.shim,
info=self.info,
)
def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
"""Get the base classes of a class or typeddict.
Args:
tp: The type or class to get the bases.
Returns:
The base classes.
"""
if is_typeddict(tp):
return tp.__orig_bases__ # type: ignore
try:
return tp.__bases__
except AttributeError:
return ()
def mro(tp: type[Any]) -> tuple[type[Any], ...]:
"""Calculate the Method Resolution Order of bases using the C3 algorithm.
See https://www.python.org/download/releases/2.3/mro/
"""
# try to use the existing mro, for performance mainly
# but also because it helps verify the implementation below
if not is_typeddict(tp):
try:
return tp.__mro__
except AttributeError:
# GenericAlias and some other cases
pass
bases = get_bases(tp)
return (tp,) + mro_for_bases(bases)
def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
while True:
non_empty = [seq for seq in seqs if seq]
if not non_empty:
# Nothing left to process, we're done.
return
candidate: type[Any] | None = None
for seq in non_empty: # Find merge candidates among seq heads.
candidate = seq[0]
not_head = [s for s in non_empty if candidate in islice(s, 1, None)]
if not_head:
# Reject the candidate.
candidate = None
else:
break
if not candidate:
raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
yield candidate
for seq in non_empty:
# Remove candidate.
if seq[0] == candidate:
seq.popleft()
seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
return tuple(merge_seqs(seqs))
_sentinel = object()
def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
"""Get the attribute from the next class in the MRO that has it,
aiming to simulate calling the method on the actual class.
The reason for iterating over the mro instead of just getting
the attribute (which would do that for us) is to support TypedDict,
which lacks a real __mro__, but can have a virtual one constructed
from its bases (as done here).
Args:
tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
name: The name of the attribute to retrieve.
Returns:
Any: The attribute value, if found.
Raises:
AttributeError: If the attribute is not found in any class in the MRO.
"""
if isinstance(tp, tuple):
for base in mro_for_bases(tp):
attribute = base.__dict__.get(name, _sentinel)
if attribute is not _sentinel:
attribute_get = getattr(attribute, '__get__', None)
if attribute_get is not None:
return attribute_get(None, tp)
return attribute
raise AttributeError(f'{name} not found in {tp}')
else:
try:
return getattr(tp, name)
except AttributeError:
return get_attribute_from_bases(mro(tp), name)
def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
"""Get an attribute out of the `__dict__` following the MRO.
This prevents the call to `__get__` on the descriptor, and allows
us to get the original function for classmethod properties.
Args:
tp: The type or class to search for the attribute.
name: The name of the attribute to retrieve.
Returns:
Any: The attribute value, if found.
Raises:
KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
"""
for base in reversed(mro(tp)):
if name in base.__dict__:
return base.__dict__[name]
return tp.__dict__[name] # raise the error
@dataclass(**slots_true)
class DecoratorInfos:
"""Mapping of name in the class namespace to decorator info.
Note that the name in the class namespace is the function or attribute name,
not the field name!
"""
validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)
@staticmethod
def build(model_dc: type[Any]) -> DecoratorInfos: # noqa: C901 (ignore complexity)
"""We want to collect all DecFunc instances that exist as
attributes in the namespace of the class (a BaseModel or dataclass)
that called us.
But we want to collect these in the order of the bases,
so instead of getting them all from the leaf class (the class that called us),
we traverse the bases from root (the oldest ancestor class) to leaf
and collect all of the instances as we go, taking care to replace
any duplicate ones with the last one we see to mimic how function overriding
works with inheritance.
If we do replace any functions we put the replacement into the position
the replaced function was in; that is, we maintain the order.
"""
# reminder: dicts are ordered and replacement does not alter the order
res = DecoratorInfos()
for base in reversed(mro(model_dc)[1:]):
existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
if existing is None:
existing = DecoratorInfos.build(base)
res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})
to_replace: list[tuple[str, Any]] = []
for var_name, var_value in vars(model_dc).items():
if isinstance(var_value, PydanticDescriptorProxy):
info = var_value.decorator_info
if isinstance(info, ValidatorDecoratorInfo):
res.validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, FieldValidatorDecoratorInfo):
res.field_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, RootValidatorDecoratorInfo):
res.root_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, FieldSerializerDecoratorInfo):
# check whether a serializer function is already registered for fields
for field_serializer_decorator in res.field_serializers.values():
# check that each field has at most one serializer function.
# serializer functions for the same field in subclasses are allowed,
# and are treated as overrides
if field_serializer_decorator.cls_var_name == var_name:
continue
for f in info.fields:
if f in field_serializer_decorator.info.fields:
raise PydanticUserError(
'Multiple field serializer functions were defined '
f'for field {f!r}, this is not allowed.',
code='multiple-field-serializers',
)
res.field_serializers[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, ModelValidatorDecoratorInfo):
res.model_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, ModelSerializerDecoratorInfo):
res.model_serializers[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
else:
from ..fields import ComputedFieldInfo
# At this point `info` must be a `ComputedFieldInfo`; all other decorator info types are handled above.
assert isinstance(info, ComputedFieldInfo)
res.computed_fields[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=None, info=info
)
to_replace.append((var_name, var_value.wrapped))
if to_replace:
# If we can save `__pydantic_decorators__` on the class we'll be able to check for it above
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
# and replace them with the thing they are wrapping (see the other setattr call below)
# which allows validator class methods to also function as regular class methods
model_dc.__pydantic_decorators__ = res
for name, value in to_replace:
setattr(model_dc, name, value)
return res
def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
"""Look at a field or model validator function and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
validator: The validator function to inspect.
mode: The proposed validator mode.
Returns:
Whether the validator takes an info argument.
"""
try:
sig = signature(validator)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present:
return False
n_positional = count_positional_required_params(sig)
if mode == 'wrap':
if n_positional == 3:
return True
elif n_positional == 2:
return False
else:
assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
if n_positional == 2:
return True
elif n_positional == 1:
return False
raise PydanticUserError(
f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
code='validator-signature',
)
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
"""Look at a field serializer function and determine if it is a field serializer,
and whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to inspect.
mode: The serializer mode, either 'plain' or 'wrap'.
Returns:
Tuple of (is_field_serializer, info_arg).
"""
try:
sig = signature(serializer)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present and this is not a method:
return (False, False)
first = next(iter(sig.parameters.values()), None)
is_field_serializer = first is not None and first.name == 'self'
n_positional = count_positional_required_params(sig)
if is_field_serializer:
# -1 to correct for self parameter
info_arg = _serializer_info_arg(mode, n_positional - 1)
else:
info_arg = _serializer_info_arg(mode, n_positional)
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
return is_field_serializer, info_arg
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
"""Look at a serializer function used via `Annotated` and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to check.
mode: The serializer mode, either 'plain' or 'wrap'.
Returns:
info_arg
"""
try:
sig = signature(serializer)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present:
return False
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
else:
return info_arg
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
"""Look at a model serializer function and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to check.
mode: The serializer mode, either 'plain' or 'wrap'.
Returns:
`info_arg` - whether the function expects an info argument.
"""
if isinstance(serializer, (staticmethod, classmethod)) or not is_instance_method_from_sig(serializer):
raise PydanticUserError(
'`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
)
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='model-serializer-signature',
)
else:
return info_arg
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
if mode == 'plain':
if n_positional == 1:
# (input_value: Any, /) -> Any
return False
elif n_positional == 2:
# (model: Any, input_value: Any, /) -> Any
return True
else:
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
if n_positional == 2:
# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
return False
elif n_positional == 3:
# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
return True
return None
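
The mapping above is easiest to read alongside a concrete check; a hedged sketch, assuming the private import path `pydantic._internal._decorators` and hypothetical serializer functions:

```python
from pydantic._internal._decorators import inspect_annotated_serializer


def plain_no_info(value):
    return value


def plain_with_info(value, info):
    return value


print(inspect_annotated_serializer(plain_no_info, 'plain'))    # False: one positional, no info parameter
print(inspect_annotated_serializer(plain_with_info, 'plain'))  # True: two positionals, info parameter expected
```
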
AnyDecoratorCallable: TypeAlias = (
'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
"""Whether the function is an instance method.
It will consider a function as instance method if the first parameter of
function is `self`.
Args:
function: The function to check.
Returns:
`True` if the function is an instance method, `False` otherwise.
"""
sig = signature(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'self':
return True
return False
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
"""Apply the `@classmethod` decorator on the function.
Args:
function: The function to apply the decorator on.
Returns:
The `@classmethod` decorator applied function.
"""
if not isinstance(
unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
) and _is_classmethod_from_sig(function):
return classmethod(function) # type: ignore[arg-type]
return function
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
sig = signature(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'cls':
return True
return False
def unwrap_wrapped_function(
func: Any,
*,
unwrap_partial: bool = True,
unwrap_class_static_method: bool = True,
) -> Any:
"""Recursively unwraps a wrapped function until the underlying function is reached.
This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod.
Args:
func: The function to unwrap.
unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
decorators. If False, only unwrap partial and partialmethod decorators.
Returns:
The underlying function of the wrapped function.
"""
# Define the types we want to check against as a single tuple.
unwrap_types = (
(property, cached_property)
+ ((partial, partialmethod) if unwrap_partial else ())
+ ((staticmethod, classmethod) if unwrap_class_static_method else ())
)
while isinstance(func, unwrap_types):
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
func = func.__func__
elif isinstance(func, (partial, partialmethod)):
func = func.func
elif isinstance(func, property):
func = func.fget # arbitrary choice, convenient for computed fields
else:
# Make coverage happy as it can only get here in the last possible case
assert isinstance(func, cached_property)
func = func.func # type: ignore
return func
_function_like = (
partial,
partialmethod,
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
types.WrapperDescriptorType,
types.MethodWrapperType,
types.MemberDescriptorType,
)
def get_callable_return_type(
callable_obj: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
"""Get the callable return type.
Args:
callable_obj: The callable to analyze.
globalns: The globals namespace to use during type annotation evaluation.
localns: The locals namespace to use during type annotation evaluation.
Returns:
The function return type.
"""
if isinstance(callable_obj, type):
# types are callables, and we assume the return type
# is the type itself (e.g. `int()` results in an instance of `int`).
return callable_obj
if not isinstance(callable_obj, _function_like):
call_func = getattr(type(callable_obj), '__call__', None) # noqa: B004
if call_func is not None:
callable_obj = call_func
hints = get_function_type_hints(
unwrap_wrapped_function(callable_obj),
include_keys={'return'},
globalns=globalns,
localns=localns,
)
return hints.get('return', PydanticUndefined)
def count_positional_required_params(sig: Signature) -> int:
"""Get the number of positional (required) arguments of a signature.
This function should only be used to inspect signatures of validation and serialization functions.
The first argument (the value being serialized or validated) is counted as a required argument
even if a default value exists.
Returns:
The number of required positional parameters of the signature.
"""
parameters = list(sig.parameters.values())
return sum(
1
for param in parameters
if can_be_positional(param)
# First argument is the value being validated/serialized, and can have a default value
# (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg
# for instance) should be required, and thus without any default value.
and (param.default is Parameter.empty or param is parameters[0])
)
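# Editorial illustration (not part of the original module): a minimal sketch of
# `count_positional_required_params` on hypothetical validator-like signatures.
def _example_count_positional_required_params() -> None:
    def with_info(value, info):
        return value
    def value_only(value, *, strict=False):
        return value
    assert count_positional_required_params(signature(with_info)) == 2
    # Keyword-only parameters (and their defaults) are not counted:
    assert count_positional_required_params(signature(value_only)) == 1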
def ensure_property(f: Any) -> Any:
"""Ensure that a function is a `property` or `cached_property`, or is a valid descriptor.
Args:
f: The function to check.
Returns:
The function, or a `property` or `cached_property` instance wrapping the function.
"""
if ismethoddescriptor(f) or isdatadescriptor(f):
return f
else:
return property(f)

View File

@ -0,0 +1,174 @@
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""
from __future__ import annotations as _annotations
from inspect import Parameter, signature
from typing import Any, Union, cast
from pydantic_core import core_schema
from typing_extensions import Protocol
from ..errors import PydanticUserError
from ._utils import can_be_positional
class V1OnlyValueValidator(Protocol):
"""A simple validator, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any) -> Any: ...
class V1ValidatorWithValues(Protocol):
"""A validator with `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ...
class V1ValidatorWithValuesKwOnly(Protocol):
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ...
class V1ValidatorWithKwargs(Protocol):
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, **kwargs: Any) -> Any: ...
class V1ValidatorWithValuesAndKwargs(Protocol):
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
V1Validator = Union[
V1ValidatorWithValues, V1ValidatorWithValuesKwOnly, V1ValidatorWithKwargs, V1ValidatorWithValuesAndKwargs
]
def can_be_keyword(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithInfoValidatorFunction:
"""Wrap a V1 style field validator for V2 compatibility.
Args:
validator: The V1 style field validator.
Returns:
A wrapped V2 style field validator.
Raises:
PydanticUserError: If the signature is not supported or the parameters are
not available in Pydantic V2.
"""
sig = signature(validator)
needs_values_kw = False
for param_num, (param_name, parameter) in enumerate(sig.parameters.items()):
if can_be_keyword(parameter) and param_name in ('field', 'config'):
raise PydanticUserError(
'The `field` and `config` parameters are not available in Pydantic V2, '
'please use the `info` parameter instead.',
code='validator-field-config-info',
)
if parameter.kind is Parameter.VAR_KEYWORD:
needs_values_kw = True
elif can_be_keyword(parameter) and param_name == 'values':
needs_values_kw = True
elif can_be_positional(parameter) and param_num == 0:
# value
continue
elif parameter.default is Parameter.empty: # ignore params with defaults e.g. bound by functools.partial
raise PydanticUserError(
f'Unsupported signature for V1 style validator {validator}: {sig} is not supported.',
code='validator-v1-signature',
)
if needs_values_kw:
# (v, **kwargs), (v, values, **kwargs), (v, *, values, **kwargs) or (v, *, values)
val1 = cast(V1ValidatorWithValues, validator)
def wrapper1(value: Any, info: core_schema.ValidationInfo) -> Any:
return val1(value, values=info.data)
return wrapper1
else:
val2 = cast(V1OnlyValueValidator, validator)
def wrapper2(value: Any, _: core_schema.ValidationInfo) -> Any:
return val2(value)
return wrapper2
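# Editorial illustration (not part of the original module): a minimal sketch of wrapping a
# hypothetical V1 value-only validator; the wrapper simply ignores the V2 `ValidationInfo`.
def _example_make_generic_v1_field_validator() -> None:
    def strip_whitespace(value: str) -> str:
        return value.strip()
    wrapped = make_generic_v1_field_validator(strip_whitespace)
    # The info argument is unused for value-only validators, so `None` is fine here:
    assert wrapped(' hello ', None) == 'hello'  # type: ignore[arg-type]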
RootValidatorValues = dict[str, Any]
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
RootValidatorFieldsTuple = tuple[Any, ...]
class V1RootValidatorFunction(Protocol):
"""A simple root validator, supported for V1 validators and V2 validators."""
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ...
class V2CoreBeforeRootValidator(Protocol):
"""V2 validator with mode='before'."""
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ...
class V2CoreAfterRootValidator(Protocol):
"""V2 validator with mode='after'."""
def __call__(
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
) -> RootValidatorFieldsTuple: ...
def make_v1_generic_root_validator(
validator: V1RootValidatorFunction, pre: bool
) -> V2CoreBeforeRootValidator | V2CoreAfterRootValidator:
"""Wrap a V1 style root validator for V2 compatibility.
Args:
validator: The V1 style field validator.
pre: Whether the validator is a pre validator.
Returns:
A wrapped V2 style validator.
"""
if pre is True:
# mode='before' for pydantic-core
def _wrapper1(values: RootValidatorValues, _: core_schema.ValidationInfo) -> RootValidatorValues:
return validator(values)
return _wrapper1
# mode='after' for pydantic-core
def _wrapper2(fields_tuple: RootValidatorFieldsTuple, _: core_schema.ValidationInfo) -> RootValidatorFieldsTuple:
if len(fields_tuple) == 2:
# dataclass, this is easy
values, init_vars = fields_tuple
values = validator(values)
return values, init_vars
else:
# ugly hack: to match v1 behaviour, we merge values and model_extra, then split them up based on fields
# afterwards
model_dict, model_extra, fields_set = fields_tuple
if model_extra:
fields = set(model_dict.keys())
model_dict.update(model_extra)
model_dict_new = validator(model_dict)
for k in list(model_dict_new.keys()):
if k not in fields:
model_extra[k] = model_dict_new.pop(k)
else:
model_dict_new = validator(model_dict)
return model_dict_new, model_extra, fields_set
return _wrapper2
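# Editorial illustration (not part of the original module): a minimal sketch of wrapping a
# hypothetical V1 `pre=True` root validator into a V2 mode='before' callable.
def _example_make_v1_generic_root_validator() -> None:
    def set_default_name(values: dict[str, Any]) -> dict[str, Any]:
        values.setdefault('name', 'unknown')
        return values
    wrapped = make_v1_generic_root_validator(set_default_name, pre=True)
    # The `ValidationInfo` argument is ignored by the wrapper, so `None` is fine here:
    assert wrapped({'age': 3}, None) == {'age': 3, 'name': 'unknown'}  # type: ignore[arg-type]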

View File

@ -0,0 +1,479 @@
from __future__ import annotations as _annotations
from collections.abc import Hashable, Sequence
from typing import TYPE_CHECKING, Any, cast
from pydantic_core import CoreSchema, core_schema
from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
CoreSchemaField,
)
if TYPE_CHECKING:
from ..types import Discriminator
from ._core_metadata import CoreMetadata
class MissingDefinitionForUnionRef(Exception):
"""Raised when applying a discriminated union discriminator to a schema
requires a definition that is not yet defined
"""
def __init__(self, ref: str) -> None:
self.ref = ref
super().__init__(f'Missing definition for ref {self.ref!r}')
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
metadata = cast('CoreMetadata', schema.setdefault('metadata', {}))
metadata['pydantic_internal_union_discriminator'] = discriminator
def apply_discriminator(
schema: core_schema.CoreSchema,
discriminator: str | Discriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.
Args:
schema: The input schema.
discriminator: The name of the field which will serve as the discriminator, or a `Discriminator` instance.
definitions: A mapping of schema ref to schema.
Returns:
The new core schema.
Raises:
TypeError:
- If `discriminator` is used with invalid union variant.
- If `discriminator` is used with `Union` type with one variant.
- If `discriminator` value mapped to multiple choices.
MissingDefinitionForUnionRef:
If the definition for ref is missing.
PydanticUserError:
- If a model in union doesn't have a discriminator field.
- If discriminator field has a non-string alias.
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
from ..types import Discriminator
if isinstance(discriminator, Discriminator):
if isinstance(discriminator.discriminator, str):
discriminator = discriminator.discriminator
else:
return discriminator._convert_schema(schema)
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
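# Editorial illustration (not part of the original module): a minimal sketch of
# `apply_discriminator` turning a plain union of two hypothetical typed-dict schemas
# into a tagged union keyed by their `'kind'` literals.
def _example_apply_discriminator() -> None:
    cat = core_schema.typed_dict_schema(
        {'kind': core_schema.typed_dict_field(core_schema.literal_schema(['cat']))}
    )
    dog = core_schema.typed_dict_schema(
        {'kind': core_schema.typed_dict_field(core_schema.literal_schema(['dog']))}
    )
    tagged = apply_discriminator(core_schema.union_schema([cat, dog]), 'kind')
    assert tagged['type'] == 'tagged-union'
    assert set(tagged['choices']) == {'cat', 'dog'}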
class _ApplyInferredDiscriminator:
"""This class is used to convert an input schema containing a union schema into one where that union is
replaced with a tagged-union, with all the associated debugging and performance benefits.
This is done by:
* Validating that the input schema is compatible with the provided discriminator
* Introspecting the schema to determine which discriminator values should map to which union choices
* Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
I have chosen to implement the conversion algorithm in this class, rather than a function,
to make it easier to maintain state while recursively walking the provided CoreSchema.
"""
def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
# `discriminator` should be the name of the field which will serve as the discriminator.
# It must be the python name of the field, and *not* the field's alias. Note that as of now,
# all members of a discriminated union _must_ use a field with the same name as the discriminator.
# This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
self.discriminator = discriminator
# `definitions` should contain a mapping of schema ref to schema for all schemas which might
# be referenced by some choice
self.definitions = definitions
# `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
#
# Note: following the v1 implementation, we currently disallow the use of different aliases
# for different choices. This is not a limitation of pydantic_core, but if we try to handle
# this, the inference logic gets complicated very quickly, and could result in confusing
# debugging challenges for users making subtle mistakes.
#
# Rather than trying to do the most powerful inference possible, I think we should eventually
# expose a way to more-manually control the way the TaggedUnionSchema is constructed through
# the use of a new type which would be placed as an Annotation on the Union type. This would
# provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
# more complex cases, without over-complicating the inference logic for the common cases.
self._discriminator_alias: str | None = None
# `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
# If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
# constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
# Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
# that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
# python side, but resolves the issue that `None` cannot correspond to any discriminator values.
self._should_be_nullable = False
# `_is_nullable` is used to track if the final produced schema will definitely be nullable;
# we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
# as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
# the final value in another nullable schema.
#
# This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
# to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
self._is_nullable = False
# `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
# from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
# algorithm may add more choices to this stack as (nested) unions are encountered.
self._choices_to_handle: list[core_schema.CoreSchema] = []
# `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
# in the output TaggedUnionSchema that will replace the union from the input schema
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
# `_used` is changed to True after applying the discriminator to prevent accidental reuse
self._used = False
def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
to this class.
Args:
schema: The input schema.
Returns:
The new core schema.
Raises:
TypeError:
- If `discriminator` is used with invalid union variant.
- If `discriminator` is used with `Union` type with one variant.
- If `discriminator` value mapped to multiple choices.
MissingDefinitionForUnionRef:
If the definition for ref is missing.
PydanticUserError:
- If a model in union doesn't have a discriminator field.
- If discriminator field has a non-string alias.
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
return schema
def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""This method handles the outer-most stage of recursion over the input schema:
unwrapping nullable or definitions schemas, and calling the `_handle_choice`
method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
"""
if schema['type'] == 'nullable':
self._is_nullable = True
wrapped = self._apply_to_root(schema['schema'])
nullable_wrapper = schema.copy()
nullable_wrapper['schema'] = wrapped
return nullable_wrapper
if schema['type'] == 'definitions':
wrapped = self._apply_to_root(schema['schema'])
definitions_wrapper = schema.copy()
definitions_wrapper['schema'] = wrapped
return definitions_wrapper
if schema['type'] != 'union':
# If the schema is not a union, it probably means it just had a single member and
# was flattened by pydantic_core.
# However, it still may make sense to apply the discriminator to this schema,
# as a way to get discriminated-union-style error messages, so we allow this here.
schema = core_schema.union_schema([schema])
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
while self._choices_to_handle:
choice = self._choices_to_handle.pop()
self._handle_choice(choice)
if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
# * We need to annotate `discriminator` as a union here to handle both branches of this conditional
# * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
# invariance of list, and because list[list[str | int]] is the type of the discriminator argument
# to tagged_union_schema below
# * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
# interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
# is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
else:
discriminator = self.discriminator
return core_schema.tagged_union_schema(
choices=self._tagged_union_choices,
discriminator=discriminator,
custom_error_type=schema.get('custom_error_type'),
custom_error_message=schema.get('custom_error_message'),
custom_error_context=schema.get('custom_error_context'),
strict=False,
from_attributes=True,
ref=schema.get('ref'),
metadata=schema.get('metadata'),
serialization=schema.get('serialization'),
)
def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
"""This method handles the "middle" stage of recursion over the input schema.
Specifically, it is responsible for handling each choice of the outermost union
(and any "coalesced" choices obtained from inner unions).
Here, "handling" entails:
* Coalescing nested unions and compatible tagged-unions
* Tracking the presence of 'none' and 'nullable' schemas occurring as choices
* Validating that each allowed discriminator value maps to a unique choice
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
"""
if choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
if choice['type'] == 'none':
self._should_be_nullable = True
elif choice['type'] == 'definitions':
self._handle_choice(choice['schema'])
elif choice['type'] == 'nullable':
self._should_be_nullable = True
self._handle_choice(choice['schema']) # unwrap the nullable schema
elif choice['type'] == 'union':
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] not in {
'model',
'typed-dict',
'tagged-union',
'lax-or-strict',
'dataclass',
'dataclass-args',
'definition-ref',
} and not _core_utils.is_function_with_inner_schema(choice):
# We should eventually handle 'definition-ref' as well
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
if choice['type'] == 'list':
err_str += (
' If you are making use of a list of union types, make sure the discriminator is applied to the '
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
)
raise TypeError(err_str)
else:
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
# In this case, this inner tagged-union is compatible with the outer tagged-union,
# and its choices can be coalesced into the outer TaggedUnionSchema.
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
# Reverse the choices list before extending the stack so that they get handled in the order they occur
self._choices_to_handle.extend(subchoices[::-1])
return
inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
self._set_unique_choice_for_values(choice, inferred_discriminator_values)
def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
"""This method returns a boolean indicating whether the discriminator for the `choice`
is the same as that being used for the outermost tagged union. This is used to
determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
or whether it should be treated as a separate (nested) choice.
"""
inner_discriminator = choice['discriminator']
return inner_discriminator == self.discriminator or (
isinstance(inner_discriminator, list)
and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
)
def _infer_discriminator_values_for_choice( # noqa C901
self, choice: core_schema.CoreSchema, source_name: str | None
) -> list[str | int]:
"""This function recurses over `choice`, extracting all discriminator values that should map to this choice.
`source_name` is accepted for the purpose of producing useful error messages.
"""
if choice['type'] == 'definitions':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif _core_utils.is_function_with_inner_schema(choice):
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'lax-or-strict':
return sorted(
set(
self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
+ self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
)
)
elif choice['type'] == 'tagged-union':
values: list[str | int] = []
# Ignore str/int "choices" since these are just references to other choices
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
for subchoice in subchoices:
subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
values.extend(subchoice_values)
return values
elif choice['type'] == 'union':
values = []
for subchoice in choice['choices']:
subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
values.extend(subchoice_values)
return values
elif choice['type'] == 'nullable':
self._should_be_nullable = True
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)
elif choice['type'] == 'model':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
elif choice['type'] == 'dataclass':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
elif choice['type'] == 'model-fields':
return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)
elif choice['type'] == 'dataclass-args':
return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)
elif choice['type'] == 'typed-dict':
return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)
elif choice['type'] == 'definition-ref':
schema_ref = choice['schema_ref']
if schema_ref not in self.definitions:
raise MissingDefinitionForUnionRef(schema_ref)
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
else:
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
if choice['type'] == 'list':
err_str += (
' If you are making use of a list of union types, make sure the discriminator is applied to the '
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
)
raise TypeError(err_str)
def _infer_discriminator_values_for_typed_dict_choice(
self, choice: core_schema.TypedDictSchema, source_name: str | None = None
) -> list[str | int]:
"""This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
for the sake of readability.
"""
source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
field = choice['fields'].get(self.discriminator)
if field is None:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_model_choice(
self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
) -> list[str | int]:
source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
field = choice['fields'].get(self.discriminator)
if field is None:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_dataclass_choice(
self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
) -> list[str | int]:
source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
for field in choice['fields']:
if field['name'] == self.discriminator:
break
else:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
if field['type'] == 'computed-field':
# This should never occur as a discriminator, as it is only relevant to serialization
return []
alias = field.get('validation_alias', self.discriminator)
if not isinstance(alias, str):
raise PydanticUserError(
f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
)
if self._discriminator_alias is None:
self._discriminator_alias = alias
elif self._discriminator_alias != alias:
raise PydanticUserError(
f'Aliases for discriminator {self.discriminator!r} must be the same '
f'(got {alias}, {self._discriminator_alias})',
code='discriminator-alias',
)
return self._infer_discriminator_values_for_inner_schema(field['schema'], source)
def _infer_discriminator_values_for_inner_schema(
self, schema: core_schema.CoreSchema, source: str
) -> list[str | int]:
"""When inferring discriminator values for a field, we typically extract the expected values from a literal
schema. This function does that, but also handles nested unions and defaults.
"""
if schema['type'] == 'literal':
return schema['expected']
elif schema['type'] == 'union':
# Generally when multiple values are allowed they should be placed in a single `Literal`, but
# we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
# For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
values: list[Any] = []
for choice in schema['choices']:
choice_schema = choice[0] if isinstance(choice, tuple) else choice
choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
values.extend(choice_values)
return values
elif schema['type'] == 'default':
# This will happen if the field has a default value; we ignore it while extracting the discriminator values
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
elif schema['type'] == 'function-after':
# After validators don't affect the discriminator values
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
validator_type = repr(schema['type'].split('-')[1])
raise PydanticUserError(
f'Cannot use a mode={validator_type} validator in the'
f' discriminator field {self.discriminator!r} of {source}',
code='discriminator-validator',
)
else:
raise PydanticUserError(
f'{source} needs field {self.discriminator!r} to be of type `Literal`',
code='discriminator-needs-literal',
)
def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
"""This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
provided `choice`, validating that none of these values already map to another (different) choice.
"""
for discriminator_value in values:
if discriminator_value in self._tagged_union_choices:
# It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
# Because tagged_union_choices may map values to other values, we need to walk the choices dict
# until we get to a "real" choice, and confirm that is equal to the one assigned.
existing_choice = self._tagged_union_choices[discriminator_value]
if existing_choice != choice:
raise TypeError(
f'Value {discriminator_value!r} for discriminator '
f'{self.discriminator!r} mapped to multiple choices'
)
else:
self._tagged_union_choices[discriminator_value] = choice

View File

@ -0,0 +1,108 @@
"""Utilities related to attribute docstring extraction."""
from __future__ import annotations
import ast
import inspect
import textwrap
from typing import Any
class DocstringVisitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.target: str | None = None
self.attrs: dict[str, str] = {}
self.previous_node_type: type[ast.AST] | None = None
def visit(self, node: ast.AST) -> Any:
node_result = super().visit(node)
self.previous_node_type = type(node)
return node_result
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
if isinstance(node.target, ast.Name):
self.target = node.target.id
def visit_Expr(self, node: ast.Expr) -> Any:
if (
isinstance(node.value, ast.Constant)
and isinstance(node.value.value, str)
and self.previous_node_type is ast.AnnAssign
):
docstring = inspect.cleandoc(node.value.value)
if self.target:
self.attrs[self.target] = docstring
self.target = None
def _dedent_source_lines(source: list[str]) -> str:
# Required for nested class definitions, e.g. in a function block
dedent_source = textwrap.dedent(''.join(source))
if dedent_source.startswith((' ', '\t')):
# We are in the case where there's a dedented (usually multiline) string
# at a lower indentation level than the class itself. We wrap our class
# in a function as a workaround.
dedent_source = f'def dedent_workaround():\n{dedent_source}'
return dedent_source
def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
frame = inspect.currentframe()
while frame:
if inspect.getmodule(frame) is inspect.getmodule(cls):
lnum = frame.f_lineno
try:
lines, _ = inspect.findsource(frame)
except OSError: # pragma: no cover
# Source can't be retrieved (maybe because running in an interactive terminal),
# we don't want to error here.
pass
else:
block_lines = inspect.getblock(lines[lnum - 1 :])
dedent_source = _dedent_source_lines(block_lines)
try:
block_tree = ast.parse(dedent_source)
except SyntaxError:
pass
else:
stmt = block_tree.body[0]
if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
# `_dedent_source_lines` wrapped the class around the workaround function
stmt = stmt.body[0]
if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
return block_lines
frame = frame.f_back
def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
"""Map model attributes and their corresponding docstring.
Args:
cls: The class of the Pydantic model to inspect.
use_inspect: Whether to skip usage of frames to find the object and use
the `inspect` module instead.
Returns:
A mapping containing attribute names and their corresponding docstring.
"""
if use_inspect:
# Might not work as expected if two classes have the same name in the same source file.
try:
source, _ = inspect.getsourcelines(cls)
except OSError: # pragma: no cover
return {}
else:
source = _extract_source_from_frame(cls)
if not source:
return {}
dedent_source = _dedent_source_lines(source)
visitor = DocstringVisitor()
visitor.visit(ast.parse(dedent_source))
return visitor.attrs
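# Editorial illustration (not part of the original module): a minimal sketch of
# `DocstringVisitor` mapping annotated assignments to the string literal that follows
# them. The class source below is hypothetical.
def _example_docstring_visitor() -> None:
    source = (
        'class Point:\n'
        '    x: int\n'
        '    """The horizontal coordinate."""\n'
        '    y: int = 0\n'
        '    """The vertical coordinate."""\n'
    )
    visitor = DocstringVisitor()
    visitor.visit(ast.parse(source))
    assert visitor.attrs == {'x': 'The horizontal coordinate.', 'y': 'The vertical coordinate.'}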

View File

@ -0,0 +1,459 @@
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
from __future__ import annotations as _annotations
import dataclasses
import warnings
from collections.abc import Mapping
from copy import copy
from functools import cache
from inspect import Parameter, ismethoddescriptor, signature
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from pydantic_core import PydanticUndefined
from typing_extensions import TypeIs, get_origin
from typing_inspection import typing_objects
from typing_inspection.introspection import AnnotationSource
from pydantic import PydanticDeprecatedSince211
from pydantic.errors import PydanticUserError
from . import _generics, _typing_extra
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._namespace_utils import NsResolver
from ._repr import Representation
from ._utils import can_be_positional
if TYPE_CHECKING:
from annotated_types import BaseMetadata
from ..fields import FieldInfo
from ..main import BaseModel
from ._dataclasses import StandardDataclass
from ._decorators import DecoratorInfos
class PydanticMetadata(Representation):
"""Base class for annotation markers like `Strict`."""
__slots__ = ()
def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
"""Create a new `_PydanticGeneralMetadata` class with the given metadata.
Args:
**metadata: The metadata to add.
Returns:
The new `_PydanticGeneralMetadata` class.
"""
return _general_metadata_cls()(metadata) # type: ignore
@cache
def _general_metadata_cls() -> type[BaseMetadata]:
"""Do it this way to avoid importing `annotated_types` at import time."""
from annotated_types import BaseMetadata
class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
"""Pydantic general metadata like `max_digits`."""
def __init__(self, metadata: Any):
self.__dict__ = metadata
return _PydanticGeneralMetadata # type: ignore
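# Editorial illustration (not part of the original module): a minimal sketch showing that
# the created metadata object simply stores the keyword arguments on its `__dict__`.
# The keys below are hypothetical.
def _example_pydantic_general_metadata() -> None:
    metadata = pydantic_general_metadata(max_digits=4, decimal_places=2)
    assert metadata.__dict__ == {'max_digits': 4, 'decimal_places': 2}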
def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], use_inspect: bool = False) -> None:
fields_docs = extract_docstrings_from_cls(cls, use_inspect=use_inspect)
for ann_name, field_info in fields.items():
if field_info.description is None and ann_name in fields_docs:
field_info.description = fields_docs[ann_name]
def collect_model_fields( # noqa: C901
cls: type[BaseModel],
config_wrapper: ConfigWrapper,
ns_resolver: NsResolver | None,
*,
typevars_map: Mapping[TypeVar, Any] | None = None,
) -> tuple[dict[str, FieldInfo], set[str]]:
"""Collect the fields and class variables names of a nascent Pydantic model.
The fields collection process is *lenient*, meaning it won't error if string annotations
fail to evaluate. If this happens, the original annotation (and assigned value, if any)
is stored on the created `FieldInfo` instance.
The `rebuild_model_fields()` function should be called at a later point (e.g. when rebuilding the model),
and will make use of these stored attributes.
Args:
cls: BaseModel or dataclass.
config_wrapper: The config wrapper instance.
ns_resolver: Namespace resolver to use when getting model annotations.
typevars_map: A dictionary mapping type variables to their concrete types.
Returns:
A two-tuple containing model fields and class variables names.
Raises:
NameError:
- If there is a conflict between a field name and protected namespaces.
- If there is a field other than `root` in `RootModel`.
- If a field shadows an attribute in the parent model.
"""
BaseModel = import_cached_base_model()
FieldInfo_ = import_cached_field_info()
bases = cls.__bases__
parent_fields_lookup: dict[str, FieldInfo] = {}
for base in reversed(bases):
if model_fields := getattr(base, '__pydantic_fields__', None):
parent_fields_lookup.update(model_fields)
type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver)
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
# annotations is only used for finding fields in parent classes
annotations = cls.__dict__.get('__annotations__', {})
fields: dict[str, FieldInfo] = {}
class_vars: set[str] = set()
for ann_name, (ann_type, evaluated) in type_hints.items():
if ann_name == 'model_config':
# We never want to treat `model_config` as a field
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
# protected namespaces (where `model_config` might be allowed as a field name)
continue
for protected_namespace in config_wrapper.protected_namespaces:
ns_violation: bool = False
if isinstance(protected_namespace, Pattern):
ns_violation = protected_namespace.match(ann_name) is not None
elif isinstance(protected_namespace, str):
ns_violation = ann_name.startswith(protected_namespace)
if ns_violation:
for b in bases:
if hasattr(b, ann_name):
if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})):
raise NameError(
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
f' of protected namespace "{protected_namespace}".'
)
else:
valid_namespaces = ()
for pn in config_wrapper.protected_namespaces:
if isinstance(pn, Pattern):
if not pn.match(ann_name):
valid_namespaces += (f're.compile({pn.pattern})',)
else:
if not ann_name.startswith(pn):
valid_namespaces += (pn,)
warnings.warn(
f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".'
'\n\nYou may be able to resolve this warning by setting'
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
UserWarning,
)
if _typing_extra.is_classvar_annotation(ann_type):
class_vars.add(ann_name)
continue
assigned_value = getattr(cls, ann_name, PydanticUndefined)
if not is_valid_field_name(ann_name):
continue
if cls.__pydantic_root_model__ and ann_name != 'root':
raise NameError(
f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`"
)
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
# "... shadows an attribute" warnings
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
for base in bases:
dataclass_fields = {
field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ())
}
if hasattr(base, ann_name):
if base is generic_origin:
# Don't warn when "shadowing" of attributes in parametrized generics
continue
if ann_name in dataclass_fields:
# Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
# on the class instance.
continue
if ann_name not in annotations:
# Don't warn when a field exists in a parent class but has not been defined in the current class
continue
warnings.warn(
f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
f'"{base.__qualname__}"',
UserWarning,
)
if assigned_value is PydanticUndefined: # no assignment, just a plain annotation
if ann_name in annotations:
# field is present in the current model's annotations (and *not* from parent classes)
field_info = FieldInfo_.from_annotation(ann_type, _source=AnnotationSource.CLASS)
elif ann_name in parent_fields_lookup:
# The field was present on one of the (possibly multiple) base classes
# copy the field to make sure typevar substitutions don't cause issues with the base classes
field_info = copy(parent_fields_lookup[ann_name])
else:
# The field was not found on any base classes; this seems to be caused by fields not getting
# generated thanks to models not being fully defined while initializing recursive models.
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
field_info = FieldInfo_.from_annotation(ann_type, _source=AnnotationSource.CLASS)
if not evaluated:
field_info._complete = False
# Store the original annotation that should be used to rebuild
# the field info later:
field_info._original_annotation = ann_type
else: # An assigned value is present (either the default value, or a `Field()` function)
_warn_on_nested_alias_in_annotation(ann_type, ann_name)
if isinstance(assigned_value, FieldInfo_) and ismethoddescriptor(assigned_value.default):
# `assigned_value` was fetched using `getattr`, which triggers a call to `__get__`
# for descriptors, so we do the same if the `= field(default=...)` form is used.
# Note that we only do this for method descriptors for now, we might want to
# extend this to any descriptor in the future (by simply checking for
# `hasattr(assigned_value.default, '__get__')`).
assigned_value.default = assigned_value.default.__get__(None, cls)
# The `from_annotated_attribute()` call below mutates the assigned `Field()`, so make a copy:
original_assignment = (
copy(assigned_value) if not evaluated and isinstance(assigned_value, FieldInfo_) else assigned_value
)
field_info = FieldInfo_.from_annotated_attribute(ann_type, assigned_value, _source=AnnotationSource.CLASS)
if not evaluated:
field_info._complete = False
# Store the original annotation and assignment value that should be used to rebuild
# the field info later:
field_info._original_annotation = ann_type
field_info._original_assignment = original_assignment
elif 'final' in field_info._qualifiers and not field_info.is_required():
warnings.warn(
f'Annotation {ann_name!r} is marked as final and has a default value. Pydantic treats {ann_name!r} as a '
'class variable, but it will be considered as a normal field in V3 to be aligned with dataclasses. If you '
f'still want {ann_name!r} to be considered as a class variable, annotate it as: `ClassVar[<type>] = <default>.`',
category=PydanticDeprecatedSince211,
# Incorrect when `create_model` is used, but the chance that final with a default is used is low in that case:
stacklevel=4,
)
class_vars.add(ann_name)
continue
# attributes which are fields are removed from the class namespace:
# 1. To match the behaviour of annotation-only fields
# 2. To avoid false positives in the NameError check above
try:
delattr(cls, ann_name)
except AttributeError:
pass # indicates the attribute was on a parent class
# Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__
# to make sure the decorators have already been built for this exact class
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
if ann_name in decorators.computed_fields:
raise TypeError(
f'Field {ann_name!r} of class {cls.__name__!r} overrides symbol of same name in a parent class. '
'This override with a computed_field is incompatible.'
)
fields[ann_name] = field_info
if typevars_map:
for field in fields.values():
if field._complete:
field.apply_typevars_map(typevars_map)
if config_wrapper.use_attribute_docstrings:
_update_fields_from_docstrings(cls, fields)
return fields, class_vars
def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None:
FieldInfo = import_cached_field_info()
args = getattr(ann_type, '__args__', None)
if args:
for anno_arg in args:
if typing_objects.is_annotated(get_origin(anno_arg)):
for anno_type_arg in _typing_extra.get_args(anno_arg):
if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
warnings.warn(
f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
UserWarning,
)
return
def rebuild_model_fields(
cls: type[BaseModel],
*,
ns_resolver: NsResolver,
typevars_map: Mapping[TypeVar, Any],
) -> dict[str, FieldInfo]:
"""Rebuild the (already present) model fields by trying to reevaluate annotations.
This function should be called whenever a model with incomplete fields is encountered.
Note:
This function *doesn't* mutate the model fields in place, as it can be called during
schema generation, where you don't want to mutate other model's fields.
"""
FieldInfo_ = import_cached_field_info()
rebuilt_fields: dict[str, FieldInfo] = {}
with ns_resolver.push(cls):
for f_name, field_info in cls.__pydantic_fields__.items():
if field_info._complete:
rebuilt_fields[f_name] = field_info
else:
ann = _typing_extra.eval_type(
field_info._original_annotation,
*ns_resolver.types_namespace,
)
ann = _generics.replace_types(ann, typevars_map)
if (assign := field_info._original_assignment) is PydanticUndefined:
rebuilt_fields[f_name] = FieldInfo_.from_annotation(ann, _source=AnnotationSource.CLASS)
else:
rebuilt_fields[f_name] = FieldInfo_.from_annotated_attribute(
ann, assign, _source=AnnotationSource.CLASS
)
return rebuilt_fields
def collect_dataclass_fields(
cls: type[StandardDataclass],
*,
ns_resolver: NsResolver | None = None,
typevars_map: dict[Any, Any] | None = None,
config_wrapper: ConfigWrapper | None = None,
) -> dict[str, FieldInfo]:
"""Collect the fields of a dataclass.
Args:
cls: dataclass.
ns_resolver: Namespace resolver to use when getting dataclass annotations.
Defaults to an empty instance.
typevars_map: A dictionary mapping type variables to their concrete types.
config_wrapper: The config wrapper instance.
Returns:
The dataclass fields.
"""
FieldInfo_ = import_cached_field_info()
fields: dict[str, FieldInfo] = {}
ns_resolver = ns_resolver or NsResolver()
dataclass_fields = cls.__dataclass_fields__
# The logic here is similar to `_typing_extra.get_cls_type_hints`,
# although we do it manually as stdlib dataclasses already have annotations
# collected in each class:
for base in reversed(cls.__mro__):
if not dataclasses.is_dataclass(base):
continue
with ns_resolver.push(base):
for ann_name, dataclass_field in dataclass_fields.items():
if ann_name not in base.__dict__.get('__annotations__', {}):
# `__dataclass_fields__` contains every field, even the ones from base classes.
# Only collect the ones defined on `base`.
continue
globalns, localns = ns_resolver.types_namespace
ann_type, _ = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
if _typing_extra.is_classvar_annotation(ann_type):
continue
if (
not dataclass_field.init
and dataclass_field.default is dataclasses.MISSING
and dataclass_field.default_factory is dataclasses.MISSING
):
# TODO: We should probably do something with this so that validate_assignment behaves properly
# Issue: https://github.com/pydantic/pydantic/issues/5470
continue
if isinstance(dataclass_field.default, FieldInfo_):
if dataclass_field.default.init_var:
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
)
# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo_.from_annotated_attribute(
ann_type, dataclass_field.default, _source=AnnotationSource.DATACLASS
)
else:
field_info = FieldInfo_.from_annotated_attribute(
ann_type, dataclass_field, _source=AnnotationSource.DATACLASS
)
fields[ann_name] = field_info
if field_info.default is not PydanticUndefined and isinstance(
getattr(cls, ann_name, field_info), FieldInfo_
):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)
if typevars_map:
for field in fields.values():
# We don't pass any ns, as `field.annotation`
# was already evaluated. TODO: is this method relevant?
# Can't we just use `_generics.replace_types`?
field.apply_typevars_map(typevars_map)
if config_wrapper is not None and config_wrapper.use_attribute_docstrings:
_update_fields_from_docstrings(
cls,
fields,
# We can't rely on the (more reliable) frame inspection method
# for stdlib dataclasses:
use_inspect=not hasattr(cls, '__is_pydantic_dataclass__'),
)
return fields
def is_valid_field_name(name: str) -> bool:
return not name.startswith('_')
def is_valid_privateattr_name(name: str) -> bool:
return name.startswith('_') and not name.startswith('__')
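# Editorial illustration (not part of the original module): a minimal sketch of the
# field/private-attribute naming rules above.
def _example_field_name_validity() -> None:
    assert is_valid_field_name('name') and not is_valid_field_name('_name')
    assert is_valid_privateattr_name('_name') and not is_valid_privateattr_name('__name')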
def takes_validated_data_argument(
default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any],
) -> TypeIs[Callable[[dict[str, Any]], Any]]:
"""Whether the provided default factory callable has a validated data parameter."""
try:
sig = signature(default_factory)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no data argument is present:
return False
parameters = list(sig.parameters.values())
return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty
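# Editorial illustration (not part of the original module): a minimal sketch of
# `takes_validated_data_argument` on two hypothetical default factories.
def _example_takes_validated_data_argument() -> None:
    def default_name() -> str:
        return 'unknown'
    def derived_name(data: dict[str, Any]) -> str:
        return data['first_name'].title()
    assert not takes_validated_data_argument(default_name)
    assert takes_validated_data_argument(derived_name)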

View File

@ -0,0 +1,23 @@
from __future__ import annotations as _annotations
from dataclasses import dataclass
from typing import Union
@dataclass
class PydanticRecursiveRef:
type_ref: str
__name__ = 'PydanticRecursiveRef'
__hash__ = object.__hash__
def __call__(self) -> None:
"""Defining __call__ is necessary for the `typing` module to let you use an instance of
this class as the result of resolving a standard ForwardRef.
"""
def __or__(self, other):
return Union[self, other] # type: ignore
def __ror__(self, other):
return Union[other, self] # type: ignore

File diff suppressed because it is too large

View File

@ -0,0 +1,547 @@
from __future__ import annotations
import sys
import types
import typing
from collections import ChainMap
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from itertools import zip_longest
from types import prepare_class
from typing import TYPE_CHECKING, Annotated, Any, TypeVar
from weakref import WeakValueDictionary
import typing_extensions
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin
from . import _typing_extra
from ._core_utils import get_type_ref
from ._forward_ref import PydanticRecursiveRef
from ._utils import all_identical, is_model_class
if sys.version_info >= (3, 10):
from typing import _UnionGenericAlias # type: ignore[attr-defined]
if TYPE_CHECKING:
from ..main import BaseModel
GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
# Right now, to handle recursive generics, some types must remain cached for brief periods without references.
# By chaining the WeakValuesDict with a LimitedDict, we have a way to retain caching for all types with references,
# while also retaining a limited number of types even without references. This is generally enough to build
# specific recursive generic models without losing required items out of the cache.
KT = TypeVar('KT')
VT = TypeVar('VT')
_LIMITED_DICT_SIZE = 100
class LimitedDict(dict[KT, VT]):
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE) -> None:
self.size_limit = size_limit
super().__init__()
def __setitem__(self, key: KT, value: VT, /) -> None:
super().__setitem__(key, value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for k in to_remove:
del self[k]
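# Editorial illustration (not part of the original module): a minimal sketch of the
# eviction behaviour of `LimitedDict`.
def _example_limited_dict() -> None:
    cache = LimitedDict(size_limit=10)
    for i in range(11):
        cache[i] = i
    # 11 entries exceed the limit of 10: excess = 11 - 10 + 10 // 10 == 2, so keys 0 and 1 are evicted.
    assert list(cache) == list(range(2, 11))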
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
if TYPE_CHECKING:
class DeepChainMap(ChainMap[KT, VT]): # type: ignore
...
else:
class DeepChainMap(ChainMap):
"""Variant of ChainMap that allows direct updates to inner scopes.
Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
with some light modifications for this use case.
"""
def clear(self) -> None:
for mapping in self.maps:
mapping.clear()
def __setitem__(self, key: KT, value: VT) -> None:
for mapping in self.maps:
mapping[key] = value
def __delitem__(self, key: KT) -> None:
hit = False
for mapping in self.maps:
if key in mapping:
del mapping[key]
hit = True
if not hit:
raise KeyError(key)
# Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it
# and discover later on that we need to re-add all this infrastructure...
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
_GENERIC_TYPES_CACHE: ContextVar[GenericTypesCache | None] = ContextVar('_GENERIC_TYPES_CACHE', default=None)
class PydanticGenericMetadata(typing_extensions.TypedDict):
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
parameters: tuple[TypeVar, ...] # analogous to typing.Generic.__parameters__
def create_generic_submodel(
model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
) -> type[BaseModel]:
"""Dynamically create a submodel of a provided (generic) BaseModel.
This is used when producing concrete parametrizations of generic models. This function
only *creates* the new subclass; the schema/validators/serialization must be updated to
reflect a concrete parametrization elsewhere.
Args:
model_name: The name of the newly created model.
origin: The base class for the new model to inherit from.
args: A tuple of generic metadata arguments.
params: A tuple of generic metadata parameters.
Returns:
The created submodel.
"""
namespace: dict[str, Any] = {'__module__': origin.__module__}
bases = (origin,)
meta, ns, kwds = prepare_class(model_name, bases)
namespace.update(ns)
created_model = meta(
model_name,
bases,
namespace,
__pydantic_generic_metadata__={
'origin': origin,
'args': args,
'parameters': params,
},
__pydantic_reset_parent_namespace__=False,
**kwds,
)
model_module, called_globally = _get_caller_frame_info(depth=3)
if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
reference_module_globals = sys.modules[created_model.__module__].__dict__
while object_by_reference is not created_model:
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
reference_name += '_'
return created_model
def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
"""Used inside a function to check whether it was called globally.
Args:
depth: The depth to get the frame.
Returns:
A tuple contains `module_name` and `called_globally`.
Raises:
RuntimeError: If the function is not called inside a function.
"""
try:
previous_caller_frame = sys._getframe(depth)
except ValueError as e:
raise RuntimeError('This function must be used inside another function') from e
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
return None, False
frame_globals = previous_caller_frame.f_globals
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
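# Editorial illustration (not part of the original module): a minimal sketch of
# `_get_caller_frame_info` as used by a hypothetical factory function; it reports on the
# *caller* of that factory.
def _example_get_caller_frame_info() -> None:
    def fake_model_factory() -> tuple[str | None, bool]:
        return _get_caller_frame_info()
    module_name, called_globally = fake_model_factory()
    assert module_name == __name__
    # The factory was called from inside this function, not at module level:
    assert called_globally is False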
DictValues: type[Any] = {}.values().__class__
def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
This is intended as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
"""
if isinstance(v, TypeVar):
yield v
elif is_model_class(v):
yield from v.__pydantic_generic_metadata__['parameters']
elif isinstance(v, (DictValues, list)):
for var in v:
yield from iter_contained_typevars(var)
else:
args = get_args(v)
for arg in args:
yield from iter_contained_typevars(arg)
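# Editorial illustration (not part of the original module): a minimal sketch of
# `iter_contained_typevars` on nested generic aliases. `T` is a hypothetical type variable.
def _example_iter_contained_typevars() -> None:
    T = TypeVar('T')
    assert list(iter_contained_typevars(dict[str, list[T]])) == [T]
    assert list(iter_contained_typevars(int)) == []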
def get_args(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('args')
return typing_extensions.get_args(v)
def get_origin(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('origin')
return typing_extensions.get_origin(v)
def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
"""
origin = get_origin(cls)
if origin is None:
return None
if not hasattr(origin, '__parameters__'):
return None
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
# So it is safe to access cls.__args__ and origin.__parameters__
args: tuple[Any, ...] = cls.__args__ # type: ignore
parameters: tuple[TypeVar, ...] = origin.__parameters__
return dict(zip(parameters, args))
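# Editorial illustration (not part of the original module): a minimal sketch of
# `get_standard_typevars_map` on a hypothetical standard generic class.
def _example_get_standard_typevars_map() -> None:
    T = TypeVar('T')
    class Box(typing.Generic[T]):
        pass
    assert get_standard_typevars_map(Box[int]) == {T: int}
    # An unparametrized class has no origin, so no map is produced:
    assert get_standard_typevars_map(Box) is None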
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
with the `replace_types` function.
Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
"""
# TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
# in the __origin__, __args__, and __parameters__ attributes of the model.
generic_metadata = cls.__pydantic_generic_metadata__
origin = generic_metadata['origin']
args = generic_metadata['args']
if not args:
# No need to go into `iter_contained_typevars`:
return {}
return dict(zip(iter_contained_typevars(origin), args))
def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
Args:
type_: The class or generic alias.
type_map: Mapping from `TypeVar` instance to concrete types.
Returns:
A new type representing the basic structure of `type_` with all
`typevar_map` keys recursively replaced.
Example:
```python
from typing import List, Union
from pydantic._internal._generics import replace_types
replace_types(tuple[str, Union[List[str], float]], {str: int})
#> tuple[int, Union[List[int], float]]
```
"""
if not type_map:
return type_
type_args = get_args(type_)
origin_type = get_origin(type_)
if typing_objects.is_annotated(origin_type):
annotated_type, *annotations = type_args
annotated_type = replace_types(annotated_type, type_map)
# TODO remove parentheses when we drop support for Python 3.10:
return Annotated[(annotated_type, *annotations)]
# Having type args is a good indicator that this is a typing special form
# instance or a generic alias of some sort.
if type_args:
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
if all_identical(type_args, resolved_type_args):
# If all arguments are the same, there is no need to modify the
# type or create a new object at all
return type_
if (
origin_type is not None
and isinstance(type_, _typing_extra.typing_base)
and not isinstance(origin_type, _typing_extra.typing_base)
and getattr(type_, '_name', None) is not None
):
# In python < 3.9 generic aliases don't exist so any of these like `list`,
# `type` or `collections.abc.Callable` need to be translated.
# See: https://www.python.org/dev/peps/pep-0585
origin_type = getattr(typing, type_._name)
assert origin_type is not None
if is_union_origin(origin_type):
if any(typing_objects.is_any(arg) for arg in resolved_type_args):
# `Any | T` ~ `Any`:
resolved_type_args = (Any,)
# `Never | T` ~ `T`:
resolved_type_args = tuple(
arg
for arg in resolved_type_args
if not (typing_objects.is_noreturn(arg) or typing_objects.is_never(arg))
)
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
# We also cannot use isinstance() since we have to compare types.
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
return _UnionGenericAlias(origin_type, resolved_type_args)
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
# We handle pydantic generic models separately as they don't have the same
# semantics as "typing" classes or generic aliases
if not origin_type and is_model_class(type_):
parameters = type_.__pydantic_generic_metadata__['parameters']
if not parameters:
return type_
resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
if all_identical(parameters, resolved_type_args):
return type_
return type_[resolved_type_args]
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, list):
resolved_list = [replace_types(element, type_map) for element in type_]
if all_identical(type_, resolved_list):
return type_
return resolved_list
# If all else fails, we try to resolve the type directly and otherwise just
# return the input with no modifications.
return type_map.get(type_, type_)
def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
"""Return a mapping between the parameters of a generic model and the provided arguments during parameterization.
Raises:
TypeError: If the number of arguments does not match the parameters (i.e. if providing too few or too many arguments).
Example:
```python {test="skip" lint="skip"}
class Model[T, U, V = int](BaseModel): ...
map_generic_model_arguments(Model, (str, bytes))
#> {T: str, U: bytes, V: int}
map_generic_model_arguments(Model, (str,))
#> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2
map_generic_model_arguments(Model, (str, bytes, int, complex))
#> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
```
Note:
This function is analogous to the private `typing._check_generic_specialization` function.
"""
parameters = cls.__pydantic_generic_metadata__['parameters']
expected_len = len(parameters)
typevars_map: dict[TypeVar, Any] = {}
_missing = object()
for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
if parameter is _missing:
raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')
if argument is _missing:
param = typing.cast(TypeVar, parameter)
try:
has_default = param.has_default()
except AttributeError:
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
has_default = False
if has_default:
# The default might refer to other type parameters. For an example, see:
# https://typing.readthedocs.io/en/latest/spec/generics.html#type-parameters-as-parameters-to-generics
typevars_map[param] = replace_types(param.__default__, typevars_map)
else:
expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters)
raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
else:
param = typing.cast(TypeVar, parameter)
typevars_map[param] = argument
return typevars_map
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
@contextmanager
def generic_recursion_self_type(
origin: type[BaseModel], args: tuple[Any, ...]
) -> Iterator[PydanticRecursiveRef | None]:
"""This contextmanager should be placed around the recursive calls used to build a generic type,
and accept as arguments the generic origin type and the type arguments being passed to it.
If the same origin and arguments are observed twice, it implies that a self-reference placeholder
can be used while building the core schema, and will produce a schema_ref that will be valid in the
final parent schema.
"""
previously_seen_type_refs = _generic_recursion_cache.get()
if previously_seen_type_refs is None:
previously_seen_type_refs = set()
token = _generic_recursion_cache.set(previously_seen_type_refs)
else:
token = None
try:
type_ref = get_type_ref(origin, args_override=args)
if type_ref in previously_seen_type_refs:
self_type = PydanticRecursiveRef(type_ref=type_ref)
yield self_type
else:
previously_seen_type_refs.add(type_ref)
yield
previously_seen_type_refs.remove(type_ref)
finally:
if token:
_generic_recursion_cache.reset(token)
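# Rough usage sketch (the surrounding schema-building logic is elided):
#
#     with generic_recursion_self_type(origin, args) as maybe_self_type:
#         if maybe_self_type is not None:
#             # the same origin/args pair is already being built higher up the stack,
#             # so a recursive reference placeholder is used instead
#             ...
#         else:
#             # first time seeing this parametrization: build the type as usual
#             ...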
def recursively_defined_type_refs() -> set[str]:
visited = _generic_recursion_cache.get()
if not visited:
return set() # not in a generic recursion, so there are no types
return visited.copy() # don't allow modifications
def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
"""The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
while still ensuring that certain alternative parametrizations ultimately resolve to the same type.
As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
The approach could be modified to not use two different cache keys at different points, but the
_early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
_late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
same after resolving the type arguments will always produce cache hits.
If we wanted to move to only using a single cache key per type, we would either need to always use the
slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
that Model[List[T]][int] is a different type than Model[List[int]]. Because we rely on subclass relationships
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
equal.
"""
generic_types_cache = _GENERIC_TYPES_CACHE.get()
if generic_types_cache is None:
generic_types_cache = GenericTypesCache()
_GENERIC_TYPES_CACHE.set(generic_types_cache)
return generic_types_cache.get(_early_cache_key(parent, typevar_values))
def get_cached_generic_type_late(
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
) -> type[BaseModel] | None:
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
generic_types_cache = _GENERIC_TYPES_CACHE.get()
if (
generic_types_cache is None
): # pragma: no cover (early cache is guaranteed to run first and initialize the cache)
generic_types_cache = GenericTypesCache()
_GENERIC_TYPES_CACHE.set(generic_types_cache)
cached = generic_types_cache.get(_late_cache_key(origin, args, typevar_values))
if cached is not None:
set_cached_generic_type(parent, typevar_values, cached, origin, args)
return cached
def set_cached_generic_type(
parent: type[BaseModel],
typevar_values: tuple[Any, ...],
type_: type[BaseModel],
origin: type[BaseModel] | None = None,
args: tuple[Any, ...] | None = None,
) -> None:
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
two different keys.
"""
generic_types_cache = _GENERIC_TYPES_CACHE.get()
if (
generic_types_cache is None
): # pragma: no cover (cache lookup is guaranteed to run first and initialize the cache)
generic_types_cache = GenericTypesCache()
_GENERIC_TYPES_CACHE.set(generic_types_cache)
generic_types_cache[_early_cache_key(parent, typevar_values)] = type_
if len(typevar_values) == 1:
generic_types_cache[_early_cache_key(parent, typevar_values[0])] = type_
if origin and args:
generic_types_cache[_late_cache_key(origin, args, typevar_values)] = type_
def _union_orderings_key(typevar_values: Any) -> Any:
"""This is intended to help differentiate between Union types with the same arguments in different order.
Thanks to caching internal to the `typing` module, it is not possible to distinguish between
List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
because `typing` considers Union[int, float] to be equal to Union[float, int].
However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
Because we parse items as the first Union type that is successful, we get slightly more consistent behavior
if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
(See https://github.com/python/cpython/issues/86483 for reference.)
"""
if isinstance(typevar_values, tuple):
args_data = []
for value in typevar_values:
args_data.append(_union_orderings_key(value))
return tuple(args_data)
elif typing_objects.is_union(typing_extensions.get_origin(typevar_values)):
return get_args(typevar_values)
else:
return ()
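# Illustrative sketch of the key shapes produced (outputs shown approximately):
#
#     _union_orderings_key(Union[int, float])
#     #> (int, float)
#     _union_orderings_key(Union[float, int])
#     #> (float, int)
#     _union_orderings_key((int, str))
#     #> ((), ())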
def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for minimal computational overhead during lookups of cached types.
Note that this is overly simplistic, and it's possible that two different cls/typevar_values
inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
would result in the same type.
"""
return cls, typevar_values, _union_orderings_key(typevar_values)
def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for use later in the process of creating a new type, when we have more information
about the exact args that will be passed. If it turns out that a different set of inputs to
__class_getitem__ resulted in the same inputs to the generic type creation process, we can still
return the cached type, and update the cache with the _early_cache_key as well.
"""
# The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
# _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
# whereas this function will always produce a tuple as the first item in the key.
return _union_orderings_key(typevar_values), origin, args

View File

@ -0,0 +1,27 @@
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
from __future__ import annotations
import os
import subprocess
def is_git_repo(dir: str) -> bool:
"""Is the given directory version-controlled with git?"""
return os.path.exists(os.path.join(dir, '.git'))
def have_git() -> bool: # pragma: no cover
"""Can we run the git executable?"""
try:
subprocess.check_output(['git', '--help'])
return True
except subprocess.CalledProcessError:
return False
except OSError:
return False
def git_revision(dir: str) -> str:
"""Get the SHA-1 of the HEAD of a git repository."""
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
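# Hypothetical usage sketch (the path is illustrative):
#
#     if have_git() and is_git_repo('.'):
#         print(git_revision('.'))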

View File

@ -0,0 +1,20 @@
from functools import cache
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pydantic import BaseModel
from pydantic.fields import FieldInfo
@cache
def import_cached_base_model() -> type['BaseModel']:
from pydantic import BaseModel
return BaseModel
@cache
def import_cached_field_info() -> type['FieldInfo']:
from pydantic.fields import FieldInfo
return FieldInfo

View File

@ -0,0 +1,7 @@
import sys
# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):
slots_true = {'slots': True}
else:
slots_true = {}

View File

@ -0,0 +1,397 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable
from copy import copy
from decimal import Decimal
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any
from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python
from pydantic_core import core_schema as cs
from ._fields import PydanticMetadata
from ._import_utils import import_cached_field_info
if TYPE_CHECKING:
pass
STRICT = {'strict'}
FAIL_FAST = {'fail_fast'}
LENGTH_CONSTRAINTS = {'min_length', 'max_length'}
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY}
ALLOW_INF_NAN = {'allow_inf_nan'}
STR_CONSTRAINTS = {
*LENGTH_CONSTRAINTS,
*STRICT,
'strip_whitespace',
'to_lower',
'to_upper',
'pattern',
'coerce_numbers_to_str',
}
BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST}
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
BOOL_CONSTRAINTS = STRICT
UUID_CONSTRAINTS = STRICT
DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
LAX_OR_STRICT_CONSTRAINTS = STRICT
ENUM_CONSTRAINTS = STRICT
COMPLEX_CONSTRAINTS = STRICT
UNION_CONSTRAINTS = {'union_mode'}
URL_CONSTRAINTS = {
'max_length',
'allowed_schemes',
'host_required',
'default_host',
'default_port',
'default_path',
}
TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url')
SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES)
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [
(STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
(BYTES_CONSTRAINTS, ('bytes',)),
(LIST_CONSTRAINTS, ('list',)),
(TUPLE_CONSTRAINTS, ('tuple',)),
(SET_CONSTRAINTS, ('set', 'frozenset')),
(DICT_CONSTRAINTS, ('dict',)),
(GENERATOR_CONSTRAINTS, ('generator',)),
(FLOAT_CONSTRAINTS, ('float',)),
(INT_CONSTRAINTS, ('int',)),
(DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')),
# TODO: this is a bit redundant, we could probably avoid some of these
(STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')),
(UNION_CONSTRAINTS, ('union',)),
(URL_CONSTRAINTS, ('url', 'multi-host-url')),
(BOOL_CONSTRAINTS, ('bool',)),
(UUID_CONSTRAINTS, ('uuid',)),
(LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)),
(ENUM_CONSTRAINTS, ('enum',)),
(DECIMAL_CONSTRAINTS, ('decimal',)),
(COMPLEX_CONSTRAINTS, ('complex',)),
]
for constraints, schemas in constraint_schema_pairings:
for c in constraints:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas)
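# Illustrative sketch of the resulting lookup (only a subset of members is shown, order is arbitrary):
#
#     CONSTRAINTS_TO_ALLOWED_SCHEMAS['gt']
#     #> {'int', 'float', 'decimal', 'date', 'time', 'datetime', 'timedelta', ...}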
def as_jsonable_value(v: Any) -> Any:
if type(v) not in (int, str, float, bytes, bool, type(None)):
return to_jsonable_python(v)
return v
def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
"""Expand the annotations.
Args:
annotations: An iterable of annotations.
Returns:
An iterable of expanded annotations.
Example:
```python
from annotated_types import Ge, Len
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
print(list(expand_grouped_metadata([Ge(4), Len(5)])))
#> [Ge(ge=4), MinLen(min_length=5)]
```
"""
import annotated_types as at
FieldInfo = import_cached_field_info()
for annotation in annotations:
if isinstance(annotation, at.GroupedMetadata):
yield from annotation
elif isinstance(annotation, FieldInfo):
yield from annotation.metadata
# this is a bit problematic in that it results in duplicate metadata
# all of our "consumers" can handle it, but it is not ideal
# we probably should split up FieldInfo into:
# - annotated types metadata
# - individual metadata known only to Pydantic
annotation = copy(annotation)
annotation.metadata = []
yield annotation
else:
yield annotation
@lru_cache
def _get_at_to_constraint_map() -> dict[type, str]:
"""Return a mapping of annotated types to constraints.
Normally, we would define a mapping like this in the module scope, but we can't do that
because we don't permit module level imports of `annotated_types`, in an attempt to speed up
the import time of `pydantic`. We still only want to have this dictionary defined in one place,
so we use this function to cache the result.
"""
import annotated_types as at
return {
at.Gt: 'gt',
at.Ge: 'ge',
at.Lt: 'lt',
at.Le: 'le',
at.MultipleOf: 'multiple_of',
at.MinLen: 'min_length',
at.MaxLen: 'max_length',
}
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
Otherwise return `None`.
This does not handle all known annotations. If / when it does, it can always
return a CoreSchema and return the unmodified schema if the annotation should be ignored.
Assumes that GroupedMetadata has already been expanded via `expand_grouped_metadata`.
Args:
annotation: The annotation.
schema: The schema.
Returns:
An updated schema with annotation if it is an annotation we know about, `None` otherwise.
Raises:
PydanticCustomError: If `Predicate` fails.
"""
import annotated_types as at
from ._validators import NUMERIC_VALIDATOR_LOOKUP, forbid_inf_nan_check
schema = schema.copy()
schema_update, other_metadata = collect_known_metadata([annotation])
schema_type = schema['type']
chain_schema_constraints: set[str] = {
'pattern',
'strip_whitespace',
'to_lower',
'to_upper',
'coerce_numbers_to_str',
}
chain_schema_steps: list[CoreSchema] = []
for constraint, value in schema_update.items():
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
raise ValueError(f'Unknown constraint {constraint}')
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]
# if it becomes necessary to handle more than one constraint
# in this recursive case with function-after or function-wrap, we should refactor
# this is a bit challenging because we sometimes want to apply constraints to the inner schema,
# whereas other times we want to wrap the existing schema with a new one that enforces a new constraint.
if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
schema['schema'] = apply_known_metadata(annotation, schema['schema']) # type: ignore # schema is function schema
return schema
# if we're allowed to apply constraint directly to the schema, like le to int, do that
if schema_type in allowed_schemas:
if constraint == 'union_mode' and schema_type == 'union':
schema['mode'] = value # type: ignore # schema is UnionSchema
else:
if schema_type == 'decimal' and constraint in {'multiple_of', 'le', 'ge', 'lt', 'gt'}:
schema[constraint] = Decimal(value)
else:
schema[constraint] = value
continue
# else, apply a function after validator to the schema to enforce the corresponding constraint
if constraint in chain_schema_constraints:
def _apply_constraint_with_incompatibility_info(
value: Any, handler: cs.ValidatorFunctionWrapHandler
) -> Any:
try:
x = handler(value)
except ValidationError as ve:
# If the error is about the type, the constraint is likely incompatible with the type of the field.
# For example, the following invalid schema wouldn't be caught during schema build, but rather at this
# point, with a cryptic 'string_type' error coming from the string validator, which we'd rather
# express as a constraint incompatibility error (TypeError):
# Annotated[list[int], Field(pattern='abc')]
if 'type' in ve.errors()[0]['type']:
raise TypeError(
f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'" # noqa: B023
)
raise ve
return x
chain_schema_steps.append(
cs.no_info_wrap_validator_function(
_apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value})
)
)
elif constraint in NUMERIC_VALIDATOR_LOOKUP:
if constraint in LENGTH_CONSTRAINTS:
inner_schema = schema
while inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}:
inner_schema = inner_schema['schema'] # type: ignore
inner_schema_type = inner_schema['type']
if inner_schema_type == 'list' or (
inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list' # type: ignore
):
js_constraint_key = 'minItems' if constraint == 'min_length' else 'maxItems'
else:
js_constraint_key = 'minLength' if constraint == 'min_length' else 'maxLength'
else:
js_constraint_key = constraint
schema = cs.no_info_after_validator_function(
partial(NUMERIC_VALIDATOR_LOOKUP[constraint], **{constraint: value}), schema
)
metadata = schema.get('metadata', {})
if (existing_json_schema_updates := metadata.get('pydantic_js_updates')) is not None:
metadata['pydantic_js_updates'] = {
**existing_json_schema_updates,
**{js_constraint_key: as_jsonable_value(value)},
}
else:
metadata['pydantic_js_updates'] = {js_constraint_key: as_jsonable_value(value)}
schema['metadata'] = metadata
elif constraint == 'allow_inf_nan' and value is False:
schema = cs.no_info_after_validator_function(
forbid_inf_nan_check,
schema,
)
else:
# It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it
# Most constraint errors are caught at runtime during attempted application
raise RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'")
for annotation in other_metadata:
if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
constraint = at_to_constraint_map[annotation_type]
validator = NUMERIC_VALIDATOR_LOOKUP.get(constraint)
if validator is None:
raise ValueError(f'Unknown constraint {constraint}')
schema = cs.no_info_after_validator_function(
partial(validator, **{constraint: getattr(annotation, constraint)}), schema
)
continue
elif isinstance(annotation, (at.Predicate, at.Not)):
predicate_name = annotation.func.__qualname__ if hasattr(annotation.func, '__qualname__') else ''
def val_func(v: Any) -> Any:
predicate_satisfied = annotation.func(v) # noqa: B023
# annotation.func may also raise an exception, let it pass through
if isinstance(annotation, at.Predicate): # noqa: B023
if not predicate_satisfied:
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name} failed', # type: ignore # noqa: B023
)
else:
if predicate_satisfied:
raise PydanticCustomError(
'not_operation_failed',
f'Not of {predicate_name} failed', # type: ignore # noqa: B023
)
return v
schema = cs.no_info_after_validator_function(val_func, schema)
else:
# ignore any other unknown metadata
return None
if chain_schema_steps:
chain_schema_steps = [schema] + chain_schema_steps
return cs.chain_schema(chain_schema_steps)
return schema
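# Minimal sketch of the intended behavior (assumes `annotated_types` is importable):
#
#     import annotated_types as at
#     from pydantic_core import core_schema as cs
#
#     updated = apply_known_metadata(at.Gt(5), cs.int_schema())
#     updated['gt']  # the constraint is applied directly to the int schema
#     #> 5
#     apply_known_metadata(object(), cs.int_schema())  # unknown metadata is ignored
#     #> None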
def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], list[Any]]:
"""Split `annotations` into known metadata and unknown annotations.
Args:
annotations: An iterable of annotations.
Returns:
A tuple containing a dict of known metadata and a list of unknown annotations.
Example:
```python
from annotated_types import Gt, Len
from pydantic._internal._known_annotated_metadata import collect_known_metadata
print(collect_known_metadata([Gt(1), Len(42), ...]))
#> ({'gt': 1, 'min_length': 42}, [Ellipsis])
```
"""
annotations = expand_grouped_metadata(annotations)
res: dict[str, Any] = {}
remaining: list[Any] = []
for annotation in annotations:
# isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
if isinstance(annotation, PydanticMetadata):
res.update(annotation.__dict__)
# we don't use dataclasses.asdict because that recursively calls asdict on the field values
elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
constraint = at_to_constraint_map[annotation_type]
res[constraint] = getattr(annotation, constraint)
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
# also support PydanticMetadata classes being used without initialisation,
# e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`
res.update({k: v for k, v in vars(annotation).items() if not k.startswith('_')})
else:
remaining.append(annotation)
# Nones can sneak in but pydantic-core will reject them
# it'd be nice to clean things up so we don't put in None (we probably don't _need_ to, it was just easier)
# but this is simple enough to kick that can down the road
res = {k: v for k, v in res.items() if v is not None}
return res, remaining
def check_metadata(metadata: dict[str, Any], allowed: Iterable[str], source_type: Any) -> None:
"""A small utility function to validate that the given metadata can be applied to the target.
More than saving lines of code, this gives us a consistent error message for all of our internal implementations.
Args:
metadata: A dict of metadata.
allowed: An iterable of allowed metadata.
source_type: The source type.
Raises:
TypeError: If there is metadata that can't be applied to the source type.
"""
unknown = metadata.keys() - set(allowed)
if unknown:
raise TypeError(
f'The following constraints cannot be applied to {source_type!r}: {", ".join([f"{k!r}" for k in unknown])}'
)

View File

@ -0,0 +1,228 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
from ..errors import PydanticErrorCodes, PydanticUserError
from ..plugin._schema_validator import PluggableSchemaValidator
if TYPE_CHECKING:
from ..dataclasses import PydanticDataclass
from ..main import BaseModel
from ..type_adapter import TypeAdapter
ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
T = TypeVar('T')
class MockCoreSchema(Mapping[str, Any]):
"""Mocker for `pydantic_core.CoreSchema` which optionally attempts to
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
"""
__slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'
def __init__(
self,
error_message: str,
*,
code: PydanticErrorCodes,
attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
) -> None:
self._error_message = error_message
self._code: PydanticErrorCodes = code
self._attempt_rebuild = attempt_rebuild
self._built_memo: CoreSchema | None = None
def __getitem__(self, key: str) -> Any:
return self._get_built().__getitem__(key)
def __len__(self) -> int:
return self._get_built().__len__()
def __iter__(self) -> Iterator[str]:
return self._get_built().__iter__()
def _get_built(self) -> CoreSchema:
if self._built_memo is not None:
return self._built_memo
if self._attempt_rebuild:
schema = self._attempt_rebuild()
if schema is not None:
self._built_memo = schema
return schema
raise PydanticUserError(self._error_message, code=self._code)
def rebuild(self) -> CoreSchema | None:
self._built_memo = None
if self._attempt_rebuild:
schema = self._attempt_rebuild()
if schema is not None:
return schema
else:
raise PydanticUserError(self._error_message, code=self._code)
return None
class MockValSer(Generic[ValSer]):
"""Mocker for `pydantic_core.SchemaValidator` or `pydantic_core.SchemaSerializer` which optionally attempts to
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
"""
__slots__ = '_error_message', '_code', '_val_or_ser', '_attempt_rebuild'
def __init__(
self,
error_message: str,
*,
code: PydanticErrorCodes,
val_or_ser: Literal['validator', 'serializer'],
attempt_rebuild: Callable[[], ValSer | None] | None = None,
) -> None:
self._error_message = error_message
self._val_or_ser = SchemaValidator if val_or_ser == 'validator' else SchemaSerializer
self._code: PydanticErrorCodes = code
self._attempt_rebuild = attempt_rebuild
def __getattr__(self, item: str) -> None:
__tracebackhide__ = True
if self._attempt_rebuild:
val_ser = self._attempt_rebuild()
if val_ser is not None:
return getattr(val_ser, item)
# raise an AttributeError if `item` doesn't exist
getattr(self._val_or_ser, item)
raise PydanticUserError(self._error_message, code=self._code)
def rebuild(self) -> ValSer | None:
if self._attempt_rebuild:
val_ser = self._attempt_rebuild()
if val_ser is not None:
return val_ser
else:
raise PydanticUserError(self._error_message, code=self._code)
return None
def set_type_adapter_mocks(adapter: TypeAdapter) -> None:
"""Set `core_schema`, `validator` and `serializer` to mock core types on a type adapter instance.
Args:
adapter: The type adapter instance to set the mocks on
"""
type_repr = str(adapter._type)
undefined_type_error_message = (
f'`TypeAdapter[{type_repr}]` is not fully defined; you should define `{type_repr}` and all referenced types,'
f' then call `.rebuild()` on the instance.'
)
def attempt_rebuild_fn(attr_fn: Callable[[TypeAdapter], T]) -> Callable[[], T | None]:
def handler() -> T | None:
if adapter.rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return attr_fn(adapter)
return None
return handler
adapter.core_schema = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.core_schema),
)
adapter.validator = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.validator),
)
adapter.serializer = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.serializer),
)
def set_model_mocks(cls: type[BaseModel], undefined_name: str = 'all referenced types') -> None:
"""Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mock core types on a model.
Args:
cls: The model class to set the mocks on
undefined_name: Name of the undefined thing, used in error messages
"""
undefined_type_error_message = (
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
f' then call `{cls.__name__}.model_rebuild()`.'
)
def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]:
def handler() -> T | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return attr_fn(cls)
return None
return handler
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
)
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
)
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
)
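# Rough sketch of the behavior these mocks enable (class names are illustrative):
#
#     class Outer(BaseModel):
#         inner: 'Inner'  # not defined yet, so mocks are installed on `Outer`
#
#     # Using `Outer` now attempts a rebuild; while `Inner` is still missing this
#     # raises a `class-not-fully-defined` PydanticUserError.
#
#     class Inner(BaseModel):
#         x: int
#
#     Outer.model_rebuild()   # rebuild succeeds, real validator/serializer replace the mocks
#     Outer(inner={'x': 1})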
def set_dataclass_mocks(cls: type[PydanticDataclass], undefined_name: str = 'all referenced types') -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.
Args:
cls: The model class to set the mocks on
undefined_name: Name of the undefined thing, used in error messages
"""
from ..dataclasses import rebuild_dataclass
undefined_type_error_message = (
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
f' then call `pydantic.dataclasses.rebuild_dataclass({cls.__name__})`.'
)
def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
def handler() -> T | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return attr_fn(cls)
return None
return handler
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
)
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
)
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
)

View File

@ -0,0 +1,792 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations
import builtins
import operator
import sys
import typing
import warnings
import weakref
from abc import ABCMeta
from functools import cache, partial, wraps
from types import FunctionType
from typing import Any, Callable, Generic, Literal, NoReturn, cast
from pydantic_core import PydanticUndefined, SchemaSerializer
from typing_extensions import TypeAliasType, dataclass_transform, deprecated, get_args, get_origin
from typing_inspection import typing_objects
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema, InvalidSchemaError
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._mock_val_ser import set_model_mocks
from ._namespace_utils import NsResolver
from ._signature import generate_pydantic_signature
from ._typing_extra import (
_make_forward_ref,
eval_type_backport,
is_classvar_annotation,
parent_frame_namespace,
)
from ._utils import LazyClassAttribute, SafeGetItemProxy
if typing.TYPE_CHECKING:
from ..fields import Field as PydanticModelField
from ..fields import FieldInfo, ModelPrivateAttr
from ..fields import PrivateAttr as PydanticModelPrivateAttr
from ..main import BaseModel
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
PydanticModelField = object()
PydanticModelPrivateAttr = object()
object_setattr = object.__setattr__
class _ModelNamespaceDict(dict):
"""A dictionary subclass that intercepts attribute setting on model classes and
warns about overriding of decorators.
"""
def __setitem__(self, k: str, v: object) -> None:
existing: Any = self.get(k, None)
if existing and v is not existing and isinstance(existing, PydanticDescriptorProxy):
warnings.warn(f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator')
return super().__setitem__(k, v)
def NoInitField(
*,
init: Literal[False] = False,
) -> Any:
"""Only for typing purposes. Used as default value of `__pydantic_fields_set__`,
`__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when
synthesizing the `__init__` signature.
"""
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField))
class ModelMetaclass(ABCMeta):
def __new__(
mcs,
cls_name: str,
bases: tuple[type[Any], ...],
namespace: dict[str, Any],
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_reset_parent_namespace__: bool = True,
_create_model_module: str | None = None,
**kwargs: Any,
) -> type:
"""Metaclass for creating Pydantic models.
Args:
cls_name: The name of the class to be created.
bases: The base classes of the class to be created.
namespace: The attribute dictionary of the class to be created.
__pydantic_generic_metadata__: Metadata for generic models.
__pydantic_reset_parent_namespace__: Reset parent namespace.
_create_model_module: The module of the class to be created, if created by `create_model`.
**kwargs: Catch-all for any other keyword arguments.
Returns:
The new class created by the metaclass.
"""
# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact
# that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
# call we're in the middle of is for the `BaseModel` class.
if bases:
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)
config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs)
namespace['model_config'] = config_wrapper.config_dict
private_attributes = inspect_namespace(
namespace, config_wrapper.ignored_types, class_vars, base_field_names
)
if private_attributes or base_private_attributes:
original_model_post_init = get_model_post_init(namespace, bases)
if original_model_post_init is not None:
# if there are private_attributes and a model_post_init function, we handle both
@wraps(original_model_post_init)
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
"""We need to both initialize private attributes and call the user-defined model_post_init
method.
"""
init_private_attributes(self, context)
original_model_post_init(self, context)
namespace['model_post_init'] = wrapped_model_post_init
else:
namespace['model_post_init'] = init_private_attributes
namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
cls = cast('type[BaseModel]', super().__new__(mcs, cls_name, bases, namespace, **kwargs))
BaseModel_ = import_cached_base_model()
mro = cls.__mro__
if Generic in mro and mro.index(Generic) < mro.index(BaseModel_):
warnings.warn(
GenericBeforeBaseModelWarning(
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
'for pydantic generics to work properly.'
),
stacklevel=2,
)
cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
cls.__pydantic_post_init__ = (
None if cls.model_post_init is BaseModel_.model_post_init else 'model_post_init'
)
cls.__pydantic_setattr_handlers__ = {}
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
# Use the getattr below to grab the __parameters__ from the `typing.Generic` parent class
if __pydantic_generic_metadata__:
cls.__pydantic_generic_metadata__ = __pydantic_generic_metadata__
else:
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
parameters = getattr(cls, '__parameters__', None) or parent_parameters
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
from ..root_model import RootModelRootType
missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
# This is a special case where the user has subclassed `RootModel`, but has not parametrized
# RootModel with the generic type identifiers being used. Ex:
# class MyModel(RootModel, Generic[T]):
# root: T
# Should instead just be:
# class MyModel(RootModel[T]):
# root: T
parameters_str = ', '.join([x.__name__ for x in missing_parameters])
error_message = (
f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
f'{parameters_str} in its parameters. '
f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
)
else:
combined_parameters = parent_parameters + missing_parameters
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
)
raise TypeError(error_message)
cls.__pydantic_generic_metadata__ = {
'origin': None,
'args': (),
'parameters': parameters,
}
cls.__pydantic_complete__ = False # Ensure this specific class gets completed
# preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
# for attributes not in `new_namespace` (e.g. private attributes)
for name, obj in private_attributes.items():
obj.__set_name__(cls, name)
if __pydantic_reset_parent_namespace__:
cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
parent_namespace: dict[str, Any] | None = getattr(cls, '__pydantic_parent_namespace__', None)
if isinstance(parent_namespace, dict):
parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)
ns_resolver = NsResolver(parent_namespace=parent_namespace)
set_model_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver)
# This is also set in `complete_model_class()`, after schema generation, because the
# `ComputedFieldInfo` instances are recreated there. We set them here as well for backwards compatibility:
cls.__pydantic_computed_fields__ = {
k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()
}
if config_wrapper.defer_build:
# TODO we can also stop there if `__pydantic_fields_complete__` is False.
# However, `set_model_fields()` is currently lenient and we don't have access to the `NameError`.
# (which is useful as we can provide the name in the error message: `set_model_mock(cls, e.name)`)
set_model_mocks(cls)
else:
# Any operation that requires accessing the field infos instances should be put inside
# `complete_model_class()`:
complete_model_class(
cls,
config_wrapper,
raise_errors=False,
ns_resolver=ns_resolver,
create_model_module=_create_model_module,
)
if config_wrapper.frozen and '__hash__' not in namespace:
set_default_hash_func(cls, bases)
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
# only hit for _proper_ subclasses of BaseModel
super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc]
return cls
else:
# These are instance variables, but have been assigned to `NoInitField` to trick the type checker.
for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__':
namespace.pop(
instance_slot,
None, # In case the metaclass is used with a class other than `BaseModel`.
)
namespace.get('__annotations__', {}).clear()
return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
if not typing.TYPE_CHECKING: # pragma: no branch
# We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
def __getattr__(self, item: str) -> Any:
"""This is necessary to keep attribute access working for class attribute access."""
private_attributes = self.__dict__.get('__private_attributes__')
if private_attributes and item in private_attributes:
return private_attributes[item]
raise AttributeError(item)
@classmethod
def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
return _ModelNamespaceDict()
def __instancecheck__(self, instance: Any) -> bool:
"""Avoid calling ABC _abc_instancecheck unless we're pretty sure.
See #3829 and python/cpython#92810
"""
return hasattr(instance, '__pydantic_decorators__') and super().__instancecheck__(instance)
def __subclasscheck__(self, subclass: type[Any]) -> bool:
"""Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
See #3829 and python/cpython#92810
"""
return hasattr(subclass, '__pydantic_decorators__') and super().__subclasscheck__(subclass)
@staticmethod
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
BaseModel = import_cached_base_model()
field_names: set[str] = set()
class_vars: set[str] = set()
private_attributes: dict[str, ModelPrivateAttr] = {}
for base in bases:
if issubclass(base, BaseModel) and base is not BaseModel:
# model_fields might not be defined yet in the case of generics, so we use getattr here:
field_names.update(getattr(base, '__pydantic_fields__', {}).keys())
class_vars.update(base.__class_vars__)
private_attributes.update(base.__private_attributes__)
return field_names, class_vars, private_attributes
@property
@deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
def __fields__(self) -> dict[str, FieldInfo]:
warnings.warn(
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
PydanticDeprecatedSince20,
stacklevel=2,
)
return getattr(self, '__pydantic_fields__', {})
@property
def __pydantic_fields_complete__(self) -> bool:
"""Whether the fields where successfully collected (i.e. type hints were successfully resolves).
This is a private attribute, not meant to be used outside Pydantic.
"""
if not hasattr(self, '__pydantic_fields__'):
return False
field_infos = cast('dict[str, FieldInfo]', self.__pydantic_fields__) # pyright: ignore[reportAttributeAccessIssue]
return all(field_info._complete for field_info in field_infos.values())
def __dir__(self) -> list[str]:
attributes = list(super().__dir__())
if '__fields__' in attributes:
attributes.remove('__fields__')
return attributes
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
"""This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args:
self: The BaseModel instance.
context: The context.
"""
if getattr(self, '__pydantic_private__', None) is None:
pydantic_private = {}
for name, private_attr in self.__private_attributes__.items():
default = private_attr.get_default()
if default is not PydanticUndefined:
pydantic_private[name] = default
object_setattr(self, '__pydantic_private__', pydantic_private)
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
"""Get the `model_post_init` method from the namespace or the class bases, or `None` if not defined."""
if 'model_post_init' in namespace:
return namespace['model_post_init']
BaseModel = import_cached_base_model()
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
if model_post_init is not BaseModel.model_post_init:
return model_post_init
def inspect_namespace( # noqa C901
namespace: dict[str, Any],
ignored_types: tuple[type[Any], ...],
base_class_vars: set[str],
base_class_fields: set[str],
) -> dict[str, ModelPrivateAttr]:
"""Iterate over the namespace and:
* gather private attributes
* check for items which look like fields but are not (e.g. have no annotation) and warn.
Args:
namespace: The attribute dictionary of the class to be created.
ignored_types: A tuple of types to ignore.
base_class_vars: A set of base class class variables.
base_class_fields: A set of base class fields.
Returns:
A dict containing private attribute info.
Raises:
TypeError: If there is a `__root__` field in model.
NameError: If private attribute name is invalid.
PydanticUserError:
- If a field does not have a type annotation.
- If a field on base class was overridden by a non-annotated attribute.
"""
from ..fields import ModelPrivateAttr, PrivateAttr
FieldInfo = import_cached_field_info()
all_ignored_types = ignored_types + default_ignored_types()
private_attributes: dict[str, ModelPrivateAttr] = {}
raw_annotations = namespace.get('__annotations__', {})
if '__root__' in raw_annotations or '__root__' in namespace:
raise TypeError("To define root models, use `pydantic.RootModel` rather than a field called '__root__'")
ignored_names: set[str] = set()
for var_name, value in list(namespace.items()):
if var_name == 'model_config' or var_name == '__pydantic_extra__':
continue
elif (
isinstance(value, type)
and value.__module__ == namespace['__module__']
and '__qualname__' in namespace
and value.__qualname__.startswith(namespace['__qualname__'])
):
# `value` is a nested type defined in this namespace; don't error
continue
elif isinstance(value, all_ignored_types) or value.__class__.__module__ == 'functools':
ignored_names.add(var_name)
continue
elif isinstance(value, ModelPrivateAttr):
if var_name.startswith('__'):
raise NameError(
'Private attributes must not use dunder names;'
f' use a single underscore prefix instead of {var_name!r}.'
)
elif is_valid_field_name(var_name):
raise NameError(
'Private attributes must not use valid field names;'
f' use sunder names, e.g. {"_" + var_name!r} instead of {var_name!r}.'
)
private_attributes[var_name] = value
del namespace[var_name]
elif isinstance(value, FieldInfo) and not is_valid_field_name(var_name):
suggested_name = var_name.lstrip('_') or 'my_field' # don't suggest '' for all-underscore name
raise NameError(
f'Fields must not use names with leading underscores;'
f' e.g., use {suggested_name!r} instead of {var_name!r}.'
)
elif var_name.startswith('__'):
continue
elif is_valid_privateattr_name(var_name):
if var_name not in raw_annotations or not is_classvar_annotation(raw_annotations[var_name]):
private_attributes[var_name] = cast(ModelPrivateAttr, PrivateAttr(default=value))
del namespace[var_name]
elif var_name in base_class_vars:
continue
elif var_name not in raw_annotations:
if var_name in base_class_fields:
raise PydanticUserError(
f'Field {var_name!r} defined on a base class was overridden by a non-annotated attribute. '
f'All field definitions, including overrides, require a type annotation.',
code='model-field-overridden',
)
elif isinstance(value, FieldInfo):
raise PydanticUserError(
f'Field {var_name!r} requires a type annotation', code='model-field-missing-annotation'
)
else:
raise PydanticUserError(
f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
code='model-field-missing-annotation',
)
for ann_name, ann_type in raw_annotations.items():
if (
is_valid_privateattr_name(ann_name)
and ann_name not in private_attributes
and ann_name not in ignored_names
# This condition can be a false negative when `ann_type` is stringified,
# but it is handled in most cases in `set_model_fields`:
and not is_classvar_annotation(ann_type)
and ann_type not in all_ignored_types
and getattr(ann_type, '__module__', None) != 'functools'
):
if isinstance(ann_type, str):
# Walking up the frames to get the module namespace where the model is defined
# (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`):
frame = sys._getframe(2)
if frame is not None:
try:
ann_type = eval_type_backport(
_make_forward_ref(ann_type, is_argument=False, is_class=True),
globalns=frame.f_globals,
localns=frame.f_locals,
)
except (NameError, TypeError):
pass
if typing_objects.is_annotated(get_origin(ann_type)):
_, *metadata = get_args(ann_type)
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
if private_attr is not None:
private_attributes[ann_name] = private_attr
continue
private_attributes[ann_name] = PrivateAttr()
return private_attributes
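# Small sketch of how a namespace is split up (hypothetical model):
#
#     class Example(BaseModel):
#         _cache: dict = {}  # collected as a private attribute (PrivateAttr(default={}))
#         y: int = 1         # a regular annotated field, left untouched here
#
#     # whereas a bare, non-ignored `x = 1` (no annotation) raises
#     # PydanticUserError('A non-annotated attribute was detected: ...').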
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
base_hash_func = get_attribute_from_bases(bases, '__hash__')
new_hash_func = make_hash_func(cls)
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
# If `__hash__` is some default, we generate a hash function.
# It will be `None` if not overridden from BaseModel.
# It may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
# In the last case we still need a new hash function to account for new `model_fields`.
cls.__hash__ = new_hash_func
def make_hash_func(cls: type[BaseModel]) -> Any:
getter = operator.itemgetter(*cls.__pydantic_fields__.keys()) if cls.__pydantic_fields__ else lambda _: 0
def hash_func(self: Any) -> int:
try:
return hash(getter(self.__dict__))
except KeyError:
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(SafeGetItemProxy(self.__dict__)))
return hash_func
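# Sketch: for a frozen model with fields ('a', 'b'), the generated hash is roughly
#
#     hash(operator.itemgetter('a', 'b')(instance.__dict__))
#
# falling back to the `SafeGetItemProxy` wrapper when a key is missing from `__dict__`.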
def set_model_fields(
cls: type[BaseModel],
config_wrapper: ConfigWrapper,
ns_resolver: NsResolver | None,
) -> None:
"""Collect and set `cls.__pydantic_fields__` and `cls.__class_vars__`.
Args:
cls: BaseModel or dataclass.
config_wrapper: The config wrapper instance.
ns_resolver: Namespace resolver to use when getting model annotations.
"""
typevars_map = get_model_typevars_map(cls)
fields, class_vars = collect_model_fields(cls, config_wrapper, ns_resolver, typevars_map=typevars_map)
cls.__pydantic_fields__ = fields
cls.__class_vars__.update(class_vars)
for k in class_vars:
# Class vars should not be private attributes
# We remove them _here_ and not earlier because we rely on inspecting the class to determine its classvars,
# but private attributes are determined by inspecting the namespace _prior_ to class creation.
# In the case that a classvar with a leading-'_' is defined via a ForwardRef (e.g., when using
# `__future__.annotations`), we want to remove the private attribute which was detected _before_ we knew it
# evaluated to a classvar
value = cls.__private_attributes__.pop(k, None)
if value is not None and value.default is not PydanticUndefined:
setattr(cls, k, value.default)
def complete_model_class(
cls: type[BaseModel],
config_wrapper: ConfigWrapper,
*,
raise_errors: bool = True,
ns_resolver: NsResolver | None = None,
create_model_module: str | None = None,
) -> bool:
"""Finish building a model class.
This logic must be called after the class has been created since validation functions must be bound
and `get_type_hints` requires a class object.
Args:
cls: BaseModel or dataclass.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
ns_resolver: The namespace resolver instance to use during schema building.
create_model_module: The module of the class to be created, if created by `create_model`.
Returns:
`True` if the model is successfully completed, else `False`.
Raises:
PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in `__get_pydantic_core_schema__`
and `raise_errors=True`.
"""
typevars_map = get_model_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
ns_resolver,
typevars_map,
)
try:
schema = gen_schema.generate_schema(cls)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_model_mocks(cls, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(title=cls.__name__)
try:
schema = gen_schema.clean_schema(schema)
except InvalidSchemaError:
set_model_mocks(cls)
return False
# This needs to happen *after* model schema generation, as the return types
# of the properties are evaluated and the `ComputedFieldInfo` instances are recreated:
cls.__pydantic_computed_fields__ = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
set_deprecated_descriptors(cls)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = create_schema_validator(
schema,
cls,
create_model_module or cls.__module__,
cls.__qualname__,
'create_model' if create_model_module else 'BaseModel',
core_config,
config_wrapper.plugin_settings,
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True
# set __signature__ attr only for model class, but not for its instances
# (because instances can define `__call__`, and `inspect.signature` shouldn't
# use the `__signature__` attribute and instead generate from `__call__`).
cls.__signature__ = LazyClassAttribute(
'__signature__',
partial(
generate_pydantic_signature,
init=cls.__init__,
fields=cls.__pydantic_fields__,
validate_by_name=config_wrapper.validate_by_name,
extra=config_wrapper.extra,
),
)
return True
def set_deprecated_descriptors(cls: type[BaseModel]) -> None:
"""Set data descriptors on the class for deprecated fields."""
for field, field_info in cls.__pydantic_fields__.items():
if (msg := field_info.deprecation_message) is not None:
desc = _DeprecatedFieldDescriptor(msg)
desc.__set_name__(cls, field)
setattr(cls, field, desc)
for field, computed_field_info in cls.__pydantic_computed_fields__.items():
if (
(msg := computed_field_info.deprecation_message) is not None
# Avoid having two warnings emitted:
and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__')
):
desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property)
desc.__set_name__(cls, field)
setattr(cls, field, desc)
class _DeprecatedFieldDescriptor:
"""Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
Attributes:
msg: The deprecation message to be emitted.
wrapped_property: The property instance if the deprecated field is a computed field, or `None`.
field_name: The name of the field being deprecated.
"""
field_name: str
def __init__(self, msg: str, wrapped_property: property | None = None) -> None:
self.msg = msg
self.wrapped_property = wrapped_property
def __set_name__(self, cls: type[BaseModel], name: str) -> None:
self.field_name = name
def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any:
if obj is None:
if self.wrapped_property is not None:
return self.wrapped_property.__get__(None, obj_type)
raise AttributeError(self.field_name)
warnings.warn(self.msg, builtins.DeprecationWarning, stacklevel=2)
if self.wrapped_property is not None:
return self.wrapped_property.__get__(obj, obj_type)
return obj.__dict__[self.field_name]
# Defined to make it a data descriptor and take precedence over the instance's dictionary.
# Note that it will not be called when setting a value on a model instance
# as `BaseModel.__setattr__` is defined and takes priority.
def __set__(self, obj: Any, value: Any) -> NoReturn:
raise AttributeError(self.field_name)
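# A minimal illustrative sketch (hypothetical class and field names): the data
# descriptor above warns on instance access and raises `AttributeError` on class
# access; values are read straight from the instance `__dict__`, mirroring how
# `BaseModel.__setattr__` stores field values.
def _demo_deprecated_descriptor_sketch() -> None:
    import builtins
    import warnings as _warnings

    class _DemoOwner:
        pass

    desc = _DeprecatedFieldDescriptor('`old_field` is deprecated')
    desc.__set_name__(_DemoOwner, 'old_field')  # type: ignore[arg-type]
    _DemoOwner.old_field = desc  # type: ignore[attr-defined]

    instance = _DemoOwner()
    instance.__dict__['old_field'] = 1  # bypasses the descriptor's `__set__`
    with _warnings.catch_warnings(record=True) as caught:
        _warnings.simplefilter('always')
        assert instance.old_field == 1  # type: ignore[attr-defined]
    assert any(issubclass(w.category, builtins.DeprecationWarning) for w in caught)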
class _PydanticWeakRef:
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
`weakref.ref` instead of subclassing it.
See https://github.com/pydantic/pydantic/issues/6763 for context.
Semantics:
- If not pickled, behaves the same as a `weakref.ref`.
- If pickled along with the referenced object, the same `weakref.ref` behavior
will be maintained between them after unpickling.
- If pickled without the referenced object, after unpickling the underlying
reference will be cleared (`__call__` will always return `None`).
"""
def __init__(self, obj: Any):
if obj is None:
# The object will be `None` upon deserialization if the serialized weakref
# had lost its underlying object.
self._wr = None
else:
self._wr = weakref.ref(obj)
def __call__(self) -> Any:
if self._wr is None:
return None
else:
return self._wr()
def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
return _PydanticWeakRef, (self(),)
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
"""Takes an input dictionary and produces a new dictionary that (invertibly) replaces the values with weakrefs.
We can't just use a WeakValueDictionary because many types (including int, str, etc.) can't be stored as values
in a WeakValueDictionary.
The `unpack_lenient_weakvaluedict` function can be used to reverse this operation.
"""
if d is None:
return None
result = {}
for k, v in d.items():
try:
proxy = _PydanticWeakRef(v)
except TypeError:
proxy = v
result[k] = proxy
return result
def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
"""Inverts the transform performed by `build_lenient_weakvaluedict`."""
if d is None:
return None
result = {}
for k, v in d.items():
if isinstance(v, _PydanticWeakRef):
v = v()
if v is not None:
result[k] = v
else:
result[k] = v
return result
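# A minimal illustrative sketch (hypothetical names): round-tripping a namespace
# through the two helpers above. Weakref-able objects are wrapped in
# `_PydanticWeakRef`, values that can't be weakly referenced (e.g. `int`) are
# stored as-is, and `unpack_lenient_weakvaluedict` restores the originals.
def _demo_lenient_weakvaluedict_sketch() -> None:
    class _Target:
        pass

    target = _Target()
    packed = build_lenient_weakvaluedict({'obj': target, 'answer': 42})
    assert isinstance(packed['obj'], _PydanticWeakRef)
    assert packed['answer'] == 42  # `int` can't be weakly referenced

    unpacked = unpack_lenient_weakvaluedict(packed)
    assert unpacked == {'obj': target, 'answer': 42}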
@cache
def default_ignored_types() -> tuple[type[Any], ...]:
from ..fields import ComputedFieldInfo
ignored_types = [
FunctionType,
property,
classmethod,
staticmethod,
PydanticDescriptorProxy,
ComputedFieldInfo,
TypeAliasType, # from `typing_extensions`
]
if sys.version_info >= (3, 12):
ignored_types.append(typing.TypeAliasType)
return tuple(ignored_types)

View File

@ -0,0 +1,293 @@
from __future__ import annotations
import sys
from collections.abc import Generator, Iterator, Mapping
from contextlib import contextmanager
from functools import cached_property
from typing import Any, Callable, NamedTuple, TypeVar
from typing_extensions import ParamSpec, TypeAlias, TypeAliasType, TypeVarTuple
GlobalsNamespace: TypeAlias = 'dict[str, Any]'
"""A global namespace.
In most cases, this is a reference to the `__dict__` attribute of a module.
This namespace type is expected as the `globals` argument during annotations evaluation.
"""
MappingNamespace: TypeAlias = Mapping[str, Any]
"""Any kind of namespace.
In most cases, this is a local namespace (e.g. the `__dict__` attribute of a class,
the [`f_locals`][frame.f_locals] attribute of a frame object, when dealing with types
defined inside functions).
This namespace type is expected as the `locals` argument during annotations evaluation.
"""
_TypeVarLike: TypeAlias = 'TypeVar | ParamSpec | TypeVarTuple'
class NamespacesTuple(NamedTuple):
"""A tuple of globals and locals to be used during annotations evaluation.
This datastructure is defined as a named tuple so that it can easily be unpacked:
```python {lint="skip" test="skip"}
def eval_type(typ: type[Any], ns: NamespacesTuple) -> None:
return eval(typ, *ns)
```
"""
globals: GlobalsNamespace
"""The namespace to be used as the `globals` argument during annotations evaluation."""
locals: MappingNamespace
"""The namespace to be used as the `locals` argument during annotations evaluation."""
def get_module_ns_of(obj: Any) -> dict[str, Any]:
"""Get the namespace of the module where the object is defined.
Caution: this function does not return a copy of the module namespace, so the result
should not be mutated. The burden of enforcing this is on the caller.
"""
module_name = getattr(obj, '__module__', None)
if module_name:
try:
return sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
return {}
return {}
# Note that this class is almost identical to `collections.ChainMap`, but we need to enforce
# immutable mappings here:
class LazyLocalNamespace(Mapping[str, Any]):
"""A lazily evaluated mapping, to be used as the `locals` argument during annotations evaluation.
While the [`eval`][eval] function expects a mapping as the `locals` argument, it only
performs `__getitem__` calls. The [`Mapping`][collections.abc.Mapping] abstract base class
is fully implemented only for type checking purposes.
Args:
*namespaces: The namespaces to consider, in ascending order of priority.
Example:
```python {lint="skip" test="skip"}
ns = LazyLocalNamespace({'a': 1, 'b': 2}, {'a': 3})
ns['a']
#> 3
ns['b']
#> 2
```
"""
def __init__(self, *namespaces: MappingNamespace) -> None:
self._namespaces = namespaces
@cached_property
def data(self) -> dict[str, Any]:
return {k: v for ns in self._namespaces for k, v in ns.items()}
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, key: str) -> Any:
return self.data[key]
def __contains__(self, key: object) -> bool:
return key in self.data
def __iter__(self) -> Iterator[str]:
return iter(self.data)
def ns_for_function(obj: Callable[..., Any], parent_namespace: MappingNamespace | None = None) -> NamespacesTuple:
"""Return the global and local namespaces to be used when evaluating annotations for the provided function.
The global namespace will be the `__dict__` attribute of the module the function was defined in.
The local namespace will contain the `__type_params__` introduced by PEP 695.
Args:
obj: The object to use when building namespaces.
parent_namespace: Optional namespace to be added with the lowest priority in the local namespace.
If the passed function is a method, the `parent_namespace` will be the namespace of the class
the method is defined in. Thus, we also fetch type `__type_params__` from there (i.e. the
class-scoped type variables).
"""
locals_list: list[MappingNamespace] = []
if parent_namespace is not None:
locals_list.append(parent_namespace)
# Get the `__type_params__` attribute introduced by PEP 695.
# Note that the `typing._eval_type` function expects type params to be
# passed as a separate argument. However, internally, `_eval_type` calls
# `ForwardRef._evaluate` which will merge type params with the localns,
# essentially mimicking what we do here.
type_params: tuple[_TypeVarLike, ...] = getattr(obj, '__type_params__', ())
if parent_namespace is not None:
# We also fetch type params from the parent namespace. If present, it probably
# means the function was defined in a class. This is to support the following:
# https://github.com/python/cpython/issues/124089.
type_params += parent_namespace.get('__type_params__', ())
locals_list.append({t.__name__: t for t in type_params})
# What about short-circuiting to `obj.__globals__`?
globalns = get_module_ns_of(obj)
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
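# A minimal illustrative sketch (hypothetical function and alias names): the parent
# namespace ends up in the returned locals (with the lowest priority), while the
# globals are the `__dict__` of the module the function was defined in.
def _demo_ns_for_function_sketch() -> None:
    def _demo_func(x: 'Alias') -> None:
        pass

    globalns, localns = ns_for_function(_demo_func, parent_namespace={'Alias': int})
    assert localns['Alias'] is int
    assert globalns is sys.modules[_demo_func.__module__].__dict__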
class NsResolver:
"""A class responsible for the namespaces resolving logic for annotations evaluation.
This class handles the namespace logic when evaluating annotations mainly for class objects.
It holds a stack of classes that are being inspected during the core schema building,
and the `types_namespace` property exposes the globals and locals to be used for
type annotation evaluation. Additionally -- if no class is present in the stack -- a
fallback globals and locals can be provided using the `namespaces_tuple` argument
(this is useful when generating a schema for a simple annotation, e.g. when using
`TypeAdapter`).
The namespace creation logic is unfortunately flawed in some cases, for backwards
compatibility reasons and to better support valid edge cases. See the description
for the `parent_namespace` argument and the example for more details.
Args:
namespaces_tuple: The default globals and locals to use if no class is present
on the stack. This can be useful when using the `GenerateSchema` class
with `TypeAdapter`, where the "type" being analyzed is a simple annotation.
parent_namespace: An optional parent namespace that will be added to the locals
with the lowest priority. For a given class defined in a function, the locals
of this function are usually used as the parent namespace:
```python {lint="skip" test="skip"}
from pydantic import BaseModel
def func() -> None:
SomeType = int
class Model(BaseModel):
f: 'SomeType'
# when collecting fields, a namespace resolver instance will be created
# this way:
# ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType})
```
For backwards compatibility reasons and to support valid edge cases, this parent
namespace will be used for *every* type being pushed to the stack. In the future,
we might want to be smarter by only doing so when the type being pushed is defined
in the same module as the parent namespace.
Example:
```python {lint="skip" test="skip"}
ns_resolver = NsResolver(
parent_namespace={'fallback': 1},
)
class Sub:
m: 'Model'
class Model:
some_local = 1
sub: Sub
ns_resolver = NsResolver()
# This is roughly what happens when we build a core schema for `Model`:
with ns_resolver.push(Model):
ns_resolver.types_namespace
#> NamespacesTuple({'Sub': Sub}, {'Model': Model, 'some_local': 1})
# First thing to notice here, the model being pushed is added to the locals.
# Because `NsResolver` is being used during the model definition, it is not
# yet added to the globals. This is useful when resolving self-referencing annotations.
with ns_resolver.push(Sub):
ns_resolver.types_namespace
#> NamespacesTuple({'Sub': Sub}, {'Sub': Sub, 'Model': Model})
# Second thing to notice: `Sub` is present in both the globals and locals.
# This is not an issue, just that as described above, the model being pushed
# is added to the locals, but it happens to be present in the globals as well
# because it is already defined.
# Third thing to notice: `Model` is also added in locals. This is a backwards
# compatibility workaround that allows for `Sub` to be able to resolve `'Model'`
# correctly (as otherwise models would have to be rebuilt even though this
# doesn't look necessary).
```
"""
def __init__(
self,
namespaces_tuple: NamespacesTuple | None = None,
parent_namespace: MappingNamespace | None = None,
) -> None:
self._base_ns_tuple = namespaces_tuple or NamespacesTuple({}, {})
self._parent_ns = parent_namespace
self._types_stack: list[type[Any] | TypeAliasType] = []
@cached_property
def types_namespace(self) -> NamespacesTuple:
"""The current global and local namespaces to be used for annotations evaluation."""
if not self._types_stack:
# TODO: should we merge the parent namespace here?
# This is relevant for TypeAdapter, where there are no types on the stack, and we might
# need access to the parent_ns. Right now, we sidestep this in `type_adapter.py` by passing
# locals to both parent_ns and the base_ns_tuple, but this is a bit hacky.
# we might consider something like:
# if self._parent_ns is not None:
# # Hacky workarounds, see class docstring:
# # An optional parent namespace that will be added to the locals with the lowest priority
# locals_list: list[MappingNamespace] = [self._parent_ns, self._base_ns_tuple.locals]
# return NamespacesTuple(self._base_ns_tuple.globals, LazyLocalNamespace(*locals_list))
return self._base_ns_tuple
typ = self._types_stack[-1]
globalns = get_module_ns_of(typ)
locals_list: list[MappingNamespace] = []
# Hacky workarounds, see class docstring:
# An optional parent namespace that will be added to the locals with the lowest priority
if self._parent_ns is not None:
locals_list.append(self._parent_ns)
if len(self._types_stack) > 1:
first_type = self._types_stack[0]
locals_list.append({first_type.__name__: first_type})
# Adding `__type_params__` *before* `vars(typ)`, as the latter takes priority
# (see https://github.com/python/cpython/pull/120272).
# TODO `typ.__type_params__` when we drop support for Python 3.11:
type_params: tuple[_TypeVarLike, ...] = getattr(typ, '__type_params__', ())
if type_params:
# Adding `__type_params__` is mostly useful for generic classes defined using
# PEP 695 syntax *and* using forward annotations (see the example in
# https://github.com/python/cpython/issues/114053). For TypeAliasType instances,
# it is way less common, but still required if using a string annotation in the alias
# value, e.g. `type A[T] = 'T'` (which is not necessary in most cases).
locals_list.append({t.__name__: t for t in type_params})
# TypeAliasType instances don't have a `__dict__` attribute, so the check
# is necessary:
if hasattr(typ, '__dict__'):
locals_list.append(vars(typ))
# The `len(self._types_stack) > 1` check above prevents this from being added twice:
locals_list.append({typ.__name__: typ})
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
@contextmanager
def push(self, typ: type[Any] | TypeAliasType, /) -> Generator[None]:
"""Push a type to the stack."""
self._types_stack.append(typ)
# Reset the cached property:
self.__dict__.pop('types_namespace', None)
try:
yield
finally:
self._types_stack.pop()
self.__dict__.pop('types_namespace', None)

View File

@ -0,0 +1,125 @@
"""Tools to provide pretty/human-readable display of objects."""
from __future__ import annotations as _annotations
import types
import typing
from typing import Any
import typing_extensions
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin
from . import _typing_extra
if typing.TYPE_CHECKING:
ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]'
RichReprResult: typing_extensions.TypeAlias = (
'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]'
)
class PlainRepr(str):
"""String class where repr doesn't include quotes. Useful with Representation when you want to return a string
representation of something that is valid (or pseudo-valid) python.
"""
def __repr__(self) -> str:
return str(self)
class Representation:
# Mixin to provide `__str__`, `__repr__`, `__pretty__`, and `__rich_repr__` methods.
# `__pretty__` is used by [devtools](https://python-devtools.helpmanual.io/).
# `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html).
# (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
# we don't want to use a type annotation here as it can break get_type_hints
__slots__ = () # type: typing.Collection[str]
def __repr_args__(self) -> ReprArgs:
"""Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.
Can either return:
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
"""
attrs_names = self.__slots__
if not attrs_names and hasattr(self, '__dict__'):
attrs_names = self.__dict__.keys()
attrs = ((s, getattr(self, s)) for s in attrs_names)
return [(a, v if v is not self else self.__repr_recursion__(v)) for a, v in attrs if v is not None]
def __repr_name__(self) -> str:
"""Name of the instance's class, used in __repr__."""
return self.__class__.__name__
def __repr_recursion__(self, object: Any) -> str:
"""Returns the string representation of a recursive object."""
# This is copied over from the stdlib `pprint` module:
return f'<Recursion on {type(object).__name__} with id={id(object)}>'
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]:
"""Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects."""
yield self.__repr_name__() + '('
yield 1
for name, value in self.__repr_args__():
if name is not None:
yield name + '='
yield fmt(value)
yield ','
yield 0
yield -1
yield ')'
def __rich_repr__(self) -> RichReprResult:
"""Used by Rich (https://rich.readthedocs.io/en/stable/pretty.html) to pretty print objects."""
for name, field_repr in self.__repr_args__():
if name is None:
yield field_repr
else:
yield name, field_repr
def __str__(self) -> str:
return self.__repr_str__(' ')
def __repr__(self) -> str:
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
def display_as_type(obj: Any) -> str:
"""Pretty representation of a type, should be as close as possible to the original type definition string.
Takes some logic from `typing._type_repr`.
"""
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
return obj.__name__
elif obj is ...:
return '...'
elif isinstance(obj, Representation):
return repr(obj)
elif isinstance(obj, typing.ForwardRef) or typing_objects.is_typealiastype(obj):
return str(obj)
if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
obj = obj.__class__
if is_union_origin(typing_extensions.get_origin(obj)):
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
return f'Union[{args}]'
elif isinstance(obj, _typing_extra.WithArgsTypes):
if typing_objects.is_literal(typing_extensions.get_origin(obj)):
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
else:
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
try:
return f'{obj.__qualname__}[{args}]'
except AttributeError:
return str(obj).replace('typing.', '').replace('typing_extensions.', '') # handles TypeAliasType in 3.12
elif isinstance(obj, type):
return obj.__qualname__
else:
return repr(obj).replace('typing.', '').replace('typing_extensions.', '')

View File

@ -0,0 +1,204 @@
# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TypedDict
from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema
from typing_extensions import TypeAlias
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
class GatherResult(TypedDict):
"""Schema traversing result."""
collected_references: dict[str, DefinitionReferenceSchema | None]
"""The collected definition references.
If a definition reference schema can be inlined, it means that there is
only one in the whole core schema. As such, it is stored as the value.
Otherwise, the value is set to `None`.
"""
deferred_discriminator_schemas: list[CoreSchema]
"""The list of core schemas having the discriminator application deferred."""
class MissingDefinitionError(LookupError):
"""A reference was pointing to a non-existing core schema."""
def __init__(self, schema_reference: str, /) -> None:
self.schema_reference = schema_reference
@dataclass
class GatherContext:
"""The current context used during core schema traversing.
Context instances should only be used during schema traversing.
"""
definitions: dict[str, CoreSchema]
"""The available definitions."""
deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list)
"""The list of core schemas having the discriminator application deferred.
Internally, these core schemas have a specific key set in the core metadata dict.
"""
collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict)
"""The collected definition references.
If a definition reference schema can be inlined, it means that there is
only one in the whole core schema. As such, it is stored as the value.
Otherwise, the value is set to `None`.
During schema traversing, definition reference schemas can be added as candidates, or removed
(by setting the value to `None`).
"""
def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
meta = schema.get('metadata')
if meta is not None and 'pydantic_internal_union_discriminator' in meta:
ctx.deferred_discriminator_schemas.append(schema) # pyright: ignore[reportArgumentType]
def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
schema_ref = def_ref_schema['schema_ref']
if schema_ref not in ctx.collected_references:
definition = ctx.definitions.get(schema_ref)
if definition is None:
raise MissingDefinitionError(schema_ref)
# The `'definition-ref'` schema was only encountered once, make it
# a candidate to be inlined:
ctx.collected_references[schema_ref] = def_ref_schema
traverse_schema(definition, ctx)
if 'serialization' in def_ref_schema:
traverse_schema(def_ref_schema['serialization'], ctx)
traverse_metadata(def_ref_schema, ctx)
else:
# The `'definition-ref'` schema was already encountered, meaning
# the previously encountered schema (and this one) can't be inlined:
ctx.collected_references[schema_ref] = None
def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
# TODO When we drop 3.9, use a match statement to get better type checking and remove
# file-level type ignore.
# (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
schema_type = schema['type']
if schema_type == 'definition-ref':
traverse_definition_ref(schema, context)
# `traverse_definition_ref` handles the possible serialization and metadata schemas:
return
elif schema_type == 'definitions':
traverse_schema(schema['schema'], context)
for definition in schema['definitions']:
traverse_schema(definition, context)
elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
if 'items_schema' in schema:
traverse_schema(schema['items_schema'], context)
elif schema_type == 'tuple':
if 'items_schema' in schema:
for s in schema['items_schema']:
traverse_schema(s, context)
elif schema_type == 'dict':
if 'keys_schema' in schema:
traverse_schema(schema['keys_schema'], context)
if 'values_schema' in schema:
traverse_schema(schema['values_schema'], context)
elif schema_type == 'union':
for choice in schema['choices']:
if isinstance(choice, tuple):
traverse_schema(choice[0], context)
else:
traverse_schema(choice, context)
elif schema_type == 'tagged-union':
for v in schema['choices'].values():
traverse_schema(v, context)
elif schema_type == 'chain':
for step in schema['steps']:
traverse_schema(step, context)
elif schema_type == 'lax-or-strict':
traverse_schema(schema['lax_schema'], context)
traverse_schema(schema['strict_schema'], context)
elif schema_type == 'json-or-python':
traverse_schema(schema['json_schema'], context)
traverse_schema(schema['python_schema'], context)
elif schema_type in {'model-fields', 'typed-dict'}:
if 'extras_schema' in schema:
traverse_schema(schema['extras_schema'], context)
if 'computed_fields' in schema:
for s in schema['computed_fields']:
traverse_schema(s, context)
for s in schema['fields'].values():
traverse_schema(s, context)
elif schema_type == 'dataclass-args':
if 'computed_fields' in schema:
for s in schema['computed_fields']:
traverse_schema(s, context)
for s in schema['fields']:
traverse_schema(s, context)
elif schema_type == 'arguments':
for s in schema['arguments_schema']:
traverse_schema(s['schema'], context)
if 'var_args_schema' in schema:
traverse_schema(schema['var_args_schema'], context)
if 'var_kwargs_schema' in schema:
traverse_schema(schema['var_kwargs_schema'], context)
elif schema_type == 'arguments-v3':
for s in schema['arguments_schema']:
traverse_schema(s['schema'], context)
elif schema_type == 'call':
traverse_schema(schema['arguments_schema'], context)
if 'return_schema' in schema:
traverse_schema(schema['return_schema'], context)
elif schema_type == 'computed-field':
traverse_schema(schema['return_schema'], context)
elif schema_type == 'function-plain':
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
if 'return_schema' in schema:
traverse_schema(schema['return_schema'], context)
if 'json_schema_input_schema' in schema:
traverse_schema(schema['json_schema_input_schema'], context)
elif schema_type == 'function-wrap':
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
if 'return_schema' in schema:
traverse_schema(schema['return_schema'], context)
if 'schema' in schema:
traverse_schema(schema['schema'], context)
if 'json_schema_input_schema' in schema:
traverse_schema(schema['json_schema_input_schema'], context)
else:
if 'schema' in schema:
traverse_schema(schema['schema'], context)
if 'serialization' in schema:
traverse_schema(schema['serialization'], context)
traverse_metadata(schema, context)
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
"""Traverse the core schema and definitions and return the necessary information for schema cleaning.
During the core schema traversing, any `'definition-ref'` schema is:
- Validated: the reference must point to an existing definition. If this is not the case, a
`MissingDefinitionError` exception is raised.
- Stored in the context: the actual reference is stored in the context. Depending on whether
the `'definition-ref'` schema is encountered more than once, the schema itself is also
saved in the context to be inlined (i.e. replaced by the definition it points to).
"""
context = GatherContext(definitions)
traverse_schema(schema, context)
return {
'collected_references': context.collected_references,
'deferred_discriminator_schemas': context.deferred_discriminator_schemas,
}
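# A minimal illustrative sketch (hypothetical refs and schemas, written as plain
# dicts): a reference encountered exactly once is kept as an inlining candidate,
# while a reference encountered more than once is mapped to `None`.
def _demo_gather_sketch() -> None:
    definitions = {
        'a': {'type': 'int'},
        'b': {'type': 'str'},
    }
    schema = {
        'type': 'tuple',
        'items_schema': [
            {'type': 'definition-ref', 'schema_ref': 'a'},
            {'type': 'definition-ref', 'schema_ref': 'b'},
            {'type': 'definition-ref', 'schema_ref': 'b'},
        ],
    }
    result = gather_schemas_for_cleaning(schema, definitions)  # type: ignore[arg-type]
    assert result['collected_references']['a'] is not None  # single use: can be inlined
    assert result['collected_references']['b'] is None  # used twice: can't be inlined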

View File

@ -0,0 +1,125 @@
"""Types and utility functions used by various other internal tools."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Literal
from pydantic_core import core_schema
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
if TYPE_CHECKING:
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
from ._core_utils import CoreSchemaOrField
from ._generate_schema import GenerateSchema
from ._namespace_utils import NamespacesTuple
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
"""JsonSchemaHandler implementation that doesn't do ref unwrapping by default.
This is used for any Annotated metadata so that we don't end up with conflicting
modifications to the definition schema.
Used internally by Pydantic, please do not rely on this implementation.
See `GetJsonSchemaHandler` for the handler API.
"""
def __init__(self, generate_json_schema: GenerateJsonSchema, handler_override: HandlerOverride | None) -> None:
self.generate_json_schema = generate_json_schema
self.handler = handler_override or generate_json_schema.generate_inner
self.mode = generate_json_schema.mode
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
return self.handler(core_schema)
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""Resolves `$ref` in the json schema.
This returns the input json schema if there is no `$ref` in json schema.
Args:
maybe_ref_json_schema: The input json schema that may contain `$ref`.
Returns:
Resolved json schema.
Raises:
LookupError: If it can't find the definition for `$ref`.
"""
if '$ref' not in maybe_ref_json_schema:
return maybe_ref_json_schema
ref = maybe_ref_json_schema['$ref']
json_schema = self.generate_json_schema.get_schema_from_definitions(ref)
if json_schema is None:
raise LookupError(
f'Could not find a ref for {ref}.'
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
)
return json_schema
class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
"""Wrapper to use an arbitrary function as a `GetCoreSchemaHandler`.
Used internally by Pydantic, please do not rely on this implementation.
See `GetCoreSchemaHandler` for the handler API.
"""
def __init__(
self,
handler: Callable[[Any], core_schema.CoreSchema],
generate_schema: GenerateSchema,
ref_mode: Literal['to-def', 'unpack'] = 'to-def',
) -> None:
self._handler = handler
self._generate_schema = generate_schema
self._ref_mode = ref_mode
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
schema = self._handler(source_type)
if self._ref_mode == 'to-def':
ref = schema.get('ref')
if ref is not None:
return self._generate_schema.defs.create_definition_reference_schema(schema)
return schema
else: # ref_mode = 'unpack'
return self.resolve_ref_schema(schema)
def _get_types_namespace(self) -> NamespacesTuple:
return self._generate_schema._types_namespace
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
return self._generate_schema.generate_schema(source_type)
@property
def field_name(self) -> str | None:
return self._generate_schema.field_name_stack.get()
def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Resolves a reference in the core schema.
Args:
maybe_ref_schema: The input core schema that may contain a reference.
Returns:
Resolved core schema.
Raises:
LookupError: If it can't find the definition for the reference.
"""
if maybe_ref_schema['type'] == 'definition-ref':
ref = maybe_ref_schema['schema_ref']
definition = self._generate_schema.defs.get_schema_from_ref(ref)
if definition is None:
raise LookupError(
f'Could not find a ref for {ref}.'
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
)
return definition
elif maybe_ref_schema['type'] == 'definitions':
return self.resolve_ref_schema(maybe_ref_schema['schema'])
return maybe_ref_schema

View File

@ -0,0 +1,53 @@
from __future__ import annotations
import collections
import collections.abc
import typing
from typing import Any
from pydantic_core import PydanticOmit, core_schema
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
typing.Deque: collections.deque, # noqa: UP006
collections.deque: collections.deque,
list: list,
typing.List: list, # noqa: UP006
tuple: tuple,
typing.Tuple: tuple, # noqa: UP006
set: set,
typing.AbstractSet: set,
typing.Set: set, # noqa: UP006
frozenset: frozenset,
typing.FrozenSet: frozenset, # noqa: UP006
typing.Sequence: list,
typing.MutableSequence: list,
typing.MutableSet: set,
# this doesn't handle subclasses of these
# parametrized typing.Set creates one of these
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
}
def serialize_sequence_via_list(
v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
items: list[Any] = []
mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
if mapped_origin is None:
# we shouldn't hit this branch, should probably add a serialization error or something
return v
for index, item in enumerate(v):
try:
v = handler(item, index)
except PydanticOmit:
pass
else:
items.append(v)
if info.mode_is_json():
return items
else:
return mapped_origin(items)
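# A minimal illustrative sketch (hypothetical handler and `info` stand-ins):
# items for which the handler raises `PydanticOmit` are dropped, and outside of
# JSON mode the mapped origin (here `deque`) is used to rebuild the container.
def _demo_serialize_sequence_sketch() -> None:
    class _FakeInfo:
        def mode_is_json(self) -> bool:
            return False

    def _handler(item: Any, index: int) -> Any:
        if item is None:
            raise PydanticOmit
        return item * 2

    out = serialize_sequence_via_list(collections.deque([1, None, 3]), _handler, _FakeInfo())  # type: ignore[arg-type]
    assert out == collections.deque([2, 6])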

View File

@ -0,0 +1,188 @@
from __future__ import annotations
import dataclasses
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable
from pydantic_core import PydanticUndefined
from ._utils import is_valid_identifier
if TYPE_CHECKING:
from ..config import ExtraValues
from ..fields import FieldInfo
# Copied over from stdlib dataclasses
class _HAS_DEFAULT_FACTORY_CLASS:
def __repr__(self):
return '<factory>'
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
"""Extract the correct name to use for the field when generating a signature.
Priority is given to the `alias` if it is a valid identifier, then to the `validation_alias`
(again, only if it is a string and a valid identifier), and finally to the field name.
Args:
field_name: The name of the field
field_info: The corresponding FieldInfo object.
Returns:
The correct name to use when generating a signature.
"""
if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
return field_info.alias
if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
return field_info.validation_alias
return field_name
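# A minimal illustrative sketch (hypothetical field and alias names, assuming
# `pydantic.Field` returns a `FieldInfo` with only `alias` set): a valid
# identifier alias wins, otherwise the field name is used.
def _demo_field_name_for_signature_sketch() -> None:
    from pydantic import Field

    assert _field_name_for_signature('my_field', Field(alias='myAlias')) == 'myAlias'
    assert _field_name_for_signature('my_field', Field(alias='my-alias')) == 'my_field'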
def _process_param_defaults(param: Parameter) -> Parameter:
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
Args:
param (Parameter): The parameter
Returns:
Parameter: The custom processed parameter
"""
from ..fields import FieldInfo
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main; we don't want that
if annotation == 'Any':
annotation = Any
# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
)
return param
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
init: Callable[..., None],
fields: dict[str, FieldInfo],
validate_by_name: bool,
extra: ExtraValues | None,
) -> dict[str, Parameter]:
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
from itertools import islice
present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False
for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main; we don't want that
if fields.get(param.name):
# exclude params with init=False
if getattr(fields[param.name], 'init', True) is False:
continue
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = validate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
param_name = _field_name_for_signature(field_name, field)
if field_name in merged_params or param_name in merged_params:
continue
if not is_valid_identifier(param_name):
if allow_names:
param_name = field_name
else:
use_var_kw = True
continue
if field.is_required():
default = Parameter.empty
elif field.default_factory is not None:
# Mimics stdlib dataclasses:
default = _HAS_DEFAULT_FACTORY
else:
default = field.default
merged_params[param_name] = Parameter(
param_name,
Parameter.KEYWORD_ONLY,
annotation=field.rebuild_annotation(),
default=default,
)
if extra == 'allow':
use_var_kw = True
if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
return merged_params
def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
validate_by_name: bool,
extra: ExtraValues | None,
is_dataclass: bool = False,
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.
Args:
init: The class init.
fields: The model fields.
validate_by_name: The `validate_by_name` value of the config.
extra: The `extra` value of the config.
is_dataclass: Whether the model is a dataclass.
Returns:
The dataclass/BaseModel subclass signature.
"""
merged_params = _generate_signature_parameters(init, fields, validate_by_name, extra)
if is_dataclass:
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
return Signature(parameters=list(merged_params.values()), return_annotation=None)
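# A minimal illustrative sketch (hypothetical model and field names): the generated
# signature is what `inspect.signature` reports for the model class, with aliases
# that are valid identifiers used as parameter names.
def _demo_generated_signature_sketch() -> None:
    import inspect

    from pydantic import BaseModel, Field

    class _DemoModel(BaseModel):
        count: int = 3
        name: str = Field(default='x', alias='displayName')

    params = inspect.signature(_DemoModel).parameters
    assert 'count' in params and 'displayName' in params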

View File

@ -0,0 +1,714 @@
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap Python's typing module."""
from __future__ import annotations
import collections.abc
import re
import sys
import types
import typing
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, cast
import typing_extensions
from typing_extensions import deprecated, get_args, get_origin
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin
from pydantic.version import version_short
from ._namespace_utils import GlobalsNamespace, MappingNamespace, NsResolver, get_module_ns_of
if sys.version_info < (3, 10):
NoneType = type(None)
EllipsisType = type(Ellipsis)
else:
from types import EllipsisType as EllipsisType
from types import NoneType as NoneType
if TYPE_CHECKING:
from pydantic import BaseModel
# As per https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types,
# always check for both `typing` and `typing_extensions` variants of a typing construct.
# (this is implemented differently than the suggested approach in the `typing_extensions`
# docs for performance).
_t_annotated = typing.Annotated
_te_annotated = typing_extensions.Annotated
def is_annotated(tp: Any, /) -> bool:
"""Return whether the provided argument is an `Annotated` special form.
```python {test="skip" lint="skip"}
is_annotated(Annotated[int, ...])
#> True
```
"""
origin = get_origin(tp)
return origin is _t_annotated or origin is _te_annotated
def annotated_type(tp: Any, /) -> Any | None:
"""Return the type of the `Annotated` special form, or `None`."""
return tp.__origin__ if typing_objects.is_annotated(get_origin(tp)) else None
def unpack_type(tp: Any, /) -> Any | None:
"""Return the type wrapped by the `Unpack` special form, or `None`."""
return get_args(tp)[0] if typing_objects.is_unpack(get_origin(tp)) else None
def is_hashable(tp: Any, /) -> bool:
"""Return whether the provided argument is the `Hashable` class.
```python {test="skip" lint="skip"}
is_hashable(Hashable)
#> True
```
"""
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
# hence the second check:
return tp is collections.abc.Hashable or get_origin(tp) is collections.abc.Hashable
def is_callable(tp: Any, /) -> bool:
"""Return whether the provided argument is a `Callable`, parametrized or not.
```python {test="skip" lint="skip"}
is_callable(Callable[[int], str])
#> True
is_callable(typing.Callable)
#> True
is_callable(collections.abc.Callable)
#> True
```
"""
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
# hence the second check:
return tp is collections.abc.Callable or get_origin(tp) is collections.abc.Callable
_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[')
def is_classvar_annotation(tp: Any, /) -> bool:
"""Return whether the provided argument represents a class variable annotation.
Although not explicitly stated by the typing specification, `ClassVar` can be used
inside `Annotated` and as such, this function checks for this specific scenario.
Because this function is used to detect class variables before evaluating forward references
(or when their evaluation failed), we also fall back to a naive regex match. This is required
because class variables are inspected before fields are collected, so we try to be
as accurate as possible.
"""
if typing_objects.is_classvar(tp):
return True
origin = get_origin(tp)
if typing_objects.is_classvar(origin):
return True
if typing_objects.is_annotated(origin):
annotated_type = tp.__origin__
if typing_objects.is_classvar(annotated_type) or typing_objects.is_classvar(get_origin(annotated_type)):
return True
str_ann: str | None = None
if isinstance(tp, typing.ForwardRef):
str_ann = tp.__forward_arg__
if isinstance(tp, str):
str_ann = tp
if str_ann is not None and _classvar_re.match(str_ann):
# stdlib dataclasses do something similar, although a bit more advanced
# (see `dataclass._is_type`).
return True
return False
_t_final = typing.Final
_te_final = typing_extensions.Final
# TODO implement `is_finalvar_annotation` as Final can be wrapped with other special forms:
def is_finalvar(tp: Any, /) -> bool:
"""Return whether the provided argument is a `Final` special form, parametrized or not.
```python {test="skip" lint="skip"}
is_finalvar(Final[int])
#> True
is_finalvar(Final)
#> True
```
"""
# Final is not necessarily parametrized:
if tp is _t_final or tp is _te_final:
return True
origin = get_origin(tp)
return origin is _t_final or origin is _te_final
_NONE_TYPES: tuple[Any, ...] = (None, NoneType, typing.Literal[None], typing_extensions.Literal[None])
def is_none_type(tp: Any, /) -> bool:
"""Return whether the argument represents the `None` type as part of an annotation.
```python {test="skip" lint="skip"}
is_none_type(None)
#> True
is_none_type(NoneType)
#> True
is_none_type(Literal[None])
#> True
is_none_type(type[None])
#> False
```
"""
return tp in _NONE_TYPES
def is_namedtuple(tp: Any, /) -> bool:
"""Return whether the provided argument is a named tuple class.
The class can be created using `typing.NamedTuple` or `collections.namedtuple`.
Parametrized generic classes are *not* assumed to be named tuples.
"""
from ._utils import lenient_issubclass # circ. import
return lenient_issubclass(tp, tuple) and hasattr(tp, '_fields')
# TODO In 2.12, delete this export. It is currently defined only to not break
# pydantic-settings which relies on it:
origin_is_union = is_union_origin
def is_generic_alias(tp: Any, /) -> bool:
return isinstance(tp, (types.GenericAlias, typing._GenericAlias)) # pyright: ignore[reportAttributeAccessIssue]
# TODO: Ideally, we should avoid relying on the private `typing` constructs:
if sys.version_info < (3, 10):
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # pyright: ignore[reportAttributeAccessIssue]
else:
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias, types.UnionType) # pyright: ignore[reportAttributeAccessIssue]
# Similarly, we shouldn't rely on this `_Final` class, which is even more private than `_GenericAlias`:
typing_base: Any = typing._Final # pyright: ignore[reportAttributeAccessIssue]
### Annotation evaluations functions:
def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
"""Fetch the local namespace of the parent frame where this function is called.
Using this function is mostly useful to resolve forward annotations pointing to members defined in a local namespace,
such as assignments inside a function. Using the standard library tools, it is currently not possible to resolve
such annotations:
```python {lint="skip" test="skip"}
from typing import get_type_hints
def func() -> None:
Alias = int
class C:
a: 'Alias'
# Raises a `NameError: 'Alias' is not defined`
get_type_hints(C)
```
Pydantic uses this function when a Pydantic model is being defined to fetch the parent frame locals. However,
this only allows us to fetch the parent frame namespace and not other parents (e.g. a model defined in a function,
itself defined in another function). Inspecting the next outer frames (using `f_back`) is not reliable enough
(see https://discuss.python.org/t/20659).
Because this function is mostly used to better resolve forward annotations, nothing is returned if the parent frame's
code object is defined at the module level. In this case, the locals of the frame will be the same as the module
globals where the class is defined (see `_namespace_utils.get_module_ns_of`). However, if you still want to fetch
the module globals (e.g. when rebuilding a model, where the frame where the rebuild call is performed might contain
members that you want to use for forward annotations evaluation), you can use the `force` parameter.
Args:
parent_depth: The depth at which to get the frame. Defaults to 2, meaning the parent frame where this function
is called will be used.
force: Whether to always return the frame locals, even if the frame's code object is defined at the module level.
Returns:
The locals of the namespace, or `None` if it was skipped as per the described logic.
"""
frame = sys._getframe(parent_depth)
if frame.f_code.co_name.startswith('<generic parameters of'):
# As `parent_frame_namespace` is mostly called in `ModelMetaclass.__new__`,
# the parent frame can be the annotation scope if the PEP 695 generic syntax is used.
# (see https://docs.python.org/3/reference/executionmodel.html#annotation-scopes,
# https://docs.python.org/3/reference/compound_stmts.html#generic-classes).
# In this case, the code name is set to `<generic parameters of MyClass>`,
# and we need to skip this frame as it is irrelevant.
frame = cast(types.FrameType, frame.f_back) # guaranteed to not be `None`
# Note: we don't copy frame.f_locals here (or during the last return call), because we don't expect the namespace to be
# modified down the line. If this becomes a problem, we could implement some sort of frozen mapping structure to enforce this.
if force:
return frame.f_locals
# If either of the following conditions are true, the class is defined at the top module level.
# To better understand why we need both of these checks, see
# https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531.
if frame.f_back is None or frame.f_code.co_name == '<module>':
return None
return frame.f_locals
def _type_convert(arg: Any) -> Any:
"""Convert `None` to `NoneType` and strings to `ForwardRef` instances.
This is a backport of the private `typing._type_convert` function. When
evaluating a type, `ForwardRef._evaluate` ends up being called, and is
responsible for making this conversion. However, we still have to apply
it for the first argument passed to our type evaluation functions, similarly
to the `typing.get_type_hints` function.
"""
if arg is None:
return NoneType
if isinstance(arg, str):
# Like `typing.get_type_hints`, assume the arg can be in any context,
# hence the proper `is_argument` and `is_class` args:
return _make_forward_ref(arg, is_argument=False, is_class=True)
return arg
def get_model_type_hints(
obj: type[BaseModel],
*,
ns_resolver: NsResolver | None = None,
) -> dict[str, tuple[Any, bool]]:
"""Collect annotations from a Pydantic model class, including those from parent classes.
Args:
obj: The Pydantic model to inspect.
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
Returns:
A dictionary mapping annotation names to a two-tuple: the first element is the evaluated
type or the original annotation if a `NameError` occurred, the second element is a boolean
indicating whether the evaluation succeeded.
"""
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
ns_resolver = ns_resolver or NsResolver()
for base in reversed(obj.__mro__):
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
if not ann or isinstance(ann, types.GetSetDescriptorType):
continue
with ns_resolver.push(base):
globalns, localns = ns_resolver.types_namespace
for name, value in ann.items():
if name.startswith('_'):
# For private attributes, we only need the annotation to detect the `ClassVar` special form.
# For this reason, we still try to evaluate it, but we also catch any possible exception (on
# top of the `NameError`s caught in `try_eval_type`) that could happen so that users are free
# to use any kind of forward annotation for private fields (e.g. circular imports, new typing
# syntax, etc).
try:
hints[name] = try_eval_type(value, globalns, localns)
except Exception:
hints[name] = (value, False)
else:
hints[name] = try_eval_type(value, globalns, localns)
return hints
def get_cls_type_hints(
obj: type[Any],
*,
ns_resolver: NsResolver | None = None,
) -> dict[str, Any]:
"""Collect annotations from a class, including those from parent classes.
Args:
obj: The class to inspect.
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
"""
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
ns_resolver = ns_resolver or NsResolver()
for base in reversed(obj.__mro__):
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
if not ann or isinstance(ann, types.GetSetDescriptorType):
continue
with ns_resolver.push(base):
globalns, localns = ns_resolver.types_namespace
for name, value in ann.items():
hints[name] = eval_type(value, globalns, localns)
return hints
def try_eval_type(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> tuple[Any, bool]:
"""Try evaluating the annotation using the provided namespaces.
Args:
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
of `str`, it will be converted to a `ForwardRef`.
globalns: The global namespace to use during annotation evaluation.
localns: The local namespace to use during annotation evaluation.
Returns:
A two-tuple containing the possibly evaluated type and a boolean indicating
whether the evaluation succeeded or not.
"""
value = _type_convert(value)
try:
return eval_type_backport(value, globalns, localns), True
except NameError:
return value, False
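# A minimal illustrative sketch (hypothetical names): `try_eval_type` returns a
# success flag instead of raising on unresolvable names, leaving the converted
# `ForwardRef` untouched in that case.
def _demo_try_eval_type_sketch() -> None:
    evaluated, ok = try_eval_type('int', globals(), {})
    assert evaluated is int and ok

    unresolved, ok = try_eval_type('NotDefinedAnywhere', globals(), {})
    assert not ok and isinstance(unresolved, typing.ForwardRef)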
def eval_type(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any:
"""Evaluate the annotation using the provided namespaces.
Args:
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
of `str`, it will be converted to a `ForwardRef`.
globalns: The global namespace to use during annotation evaluation.
localns: The local namespace to use during annotation evaluation.
"""
value = _type_convert(value)
return eval_type_backport(value, globalns, localns)
@deprecated(
'`eval_type_lenient` is deprecated, use `try_eval_type` instead.',
category=None,
)
def eval_type_lenient(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any:
ev, _ = try_eval_type(value, globalns, localns)
return ev
def eval_type_backport(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
type_params: tuple[Any, ...] | None = None,
) -> Any:
"""An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport`
package if it's installed to let older Python versions use newer typing constructs.
Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]`
(as well as all the types made generic in PEP 585) if the original syntax is not supported in the
current Python version.
This function will also display a helpful error if the value passed fails to evaluate.
"""
try:
return _eval_type_backport(value, globalns, localns, type_params)
except TypeError as e:
if 'Unable to evaluate type annotation' in str(e):
raise
# If it is a `TypeError` and value isn't a `ForwardRef`, it would have failed during annotation definition.
# Thus we assert here for type checking purposes:
assert isinstance(value, typing.ForwardRef)
message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.'
if sys.version_info >= (3, 11):
e.add_note(message)
raise
else:
raise TypeError(message) from e
except RecursionError as e:
# TODO ideally recursion errors should be checked in `eval_type` above, but `eval_type_backport`
# is used directly in some places.
message = (
"If you made use of an implicit recursive type alias (e.g. `MyType = list['MyType']`), "
'consider using PEP 695 type aliases instead. For more details, refer to the documentation: '
f'https://docs.pydantic.dev/{version_short()}/concepts/types/#named-recursive-types'
)
if sys.version_info >= (3, 11):
e.add_note(message)
raise
else:
raise RecursionError(f'{e.args[0]}\n{message}')
def _eval_type_backport(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
type_params: tuple[Any, ...] | None = None,
) -> Any:
try:
return _eval_type(value, globalns, localns, type_params)
except TypeError as e:
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
raise
try:
from eval_type_backport import eval_type_backport
except ImportError:
raise TypeError(
f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use '
'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting '
'since Python 3.9), you should either replace the use of new syntax with the existing '
'`typing` constructs or install the `eval_type_backport` package.'
) from e
return eval_type_backport(
value,
globalns,
localns, # pyright: ignore[reportArgumentType], waiting on a new `eval_type_backport` release.
try_default=False,
)
def _eval_type(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
type_params: tuple[Any, ...] | None = None,
) -> Any:
if sys.version_info >= (3, 13):
return typing._eval_type( # type: ignore
value, globalns, localns, type_params=type_params
)
else:
return typing._eval_type( # type: ignore
value, globalns, localns
)
def is_backport_fixable_error(e: TypeError) -> bool:
msg = str(e)
return sys.version_info < (3, 10) and msg.startswith('unsupported operand type(s) for |: ')
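# Illustrative sketch of the backport behaviour (comment-only; assumes the
# `eval_type_backport` package is installed when running on Python < 3.10):
#
#     eval_type_backport(typing.ForwardRef('int | None'))
#     # -> int | None on Python 3.10+; typing.Optional[int] via the backport on 3.9.
#     # Without the backport installed, a TypeError with installation
#     # instructions is raised instead.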
def get_function_type_hints(
function: Callable[..., Any],
*,
include_keys: set[str] | None = None,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> dict[str, Any]:
"""Return type hints for a function.
This is similar to the `typing.get_type_hints` function, with a few differences:
- Support `functools.partial` by using the underlying `func` attribute.
- Do not wrap type annotation of a parameter with `Optional` if it has a default value of `None`
(related bug: https://github.com/python/cpython/issues/90353, only fixed in 3.11+).
"""
try:
if isinstance(function, partial):
annotations = function.func.__annotations__
else:
annotations = function.__annotations__
except AttributeError:
# Some functions (e.g. builtins) don't have annotations:
return {}
if globalns is None:
globalns = get_module_ns_of(function)
type_params: tuple[Any, ...] | None = None
if localns is None:
# If localns was specified, it is assumed to already contain type params. This is because
# Pydantic has more advanced logic to do so (see `_namespace_utils.ns_for_function`).
type_params = getattr(function, '__type_params__', ())
type_hints = {}
for name, value in annotations.items():
if include_keys is not None and name not in include_keys:
continue
if value is None:
value = NoneType
elif isinstance(value, str):
value = _make_forward_ref(value)
type_hints[name] = eval_type_backport(value, globalns, localns, type_params)
return type_hints
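# Illustrative sketch (comment-only; `f` is a hypothetical function):
#
#     def f(x: 'int', y: int = None) -> 'str': ...
#
#     get_function_type_hints(f)
#     # -> {'x': int, 'y': int, 'return': str}
#     # Note that, unlike `typing.get_type_hints`, `y` is *not* wrapped in
#     # Optional even though its default is None.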
if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1):
def _make_forward_ref(
arg: Any,
is_argument: bool = True,
*,
is_class: bool = False,
) -> typing.ForwardRef:
"""Wrapper for ForwardRef that accounts for the `is_class` argument missing in older versions.
The `module` argument is omitted as it breaks <3.9.8, =3.10.0 and isn't used in the calls below.
See https://github.com/python/cpython/pull/28560 for some background.
The backport happened on 3.9.8, see:
https://github.com/pydantic/pydantic/discussions/6244#discussioncomment-6275458,
and on 3.10.1 for the 3.10 branch, see:
https://github.com/pydantic/pydantic/issues/6912
Implemented as EAFP with memory.
"""
return typing.ForwardRef(arg, is_argument)
else:
_make_forward_ref = typing.ForwardRef
if sys.version_info >= (3, 10):
get_type_hints = typing.get_type_hints
else:
"""
    For older versions of Python, we have a custom implementation of `get_type_hints` which is as close as possible to
the implementation in CPython 3.10.8.
"""
@typing.no_type_check
def get_type_hints( # noqa: C901
obj: Any,
globalns: dict[str, Any] | None = None,
localns: dict[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]: # pragma: no cover
"""Taken verbatim from python 3.10.8 unchanged, except:
* type annotations of the function definition above.
* prefixing `typing.` where appropriate
* Use `_make_forward_ref` instead of `typing.ForwardRef` to handle the `is_class` argument.
https://github.com/python/cpython/blob/aaaf5174241496afca7ce4d4584570190ff972fe/Lib/typing.py#L1773-L1875
DO NOT CHANGE THIS METHOD UNLESS ABSOLUTELY NECESSARY.
======================================================
Return type hints for an object.
This is often the same as obj.__annotations__, but it handles
forward references encoded as string literals, adds Optional[t] if a
default value equal to None is set and recursively replaces all
'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
The argument may be a module, class, method, or function. The annotations
are returned as a dictionary. For classes, annotations include also
inherited members.
TypeError is raised if the argument is not of a type that can contain
annotations, and an empty dictionary is returned if no annotations are
present.
BEWARE -- the behavior of globalns and localns is counterintuitive
(unless you are familiar with how eval() and exec() work). The
search order is locals first, then globals.
- If no dict arguments are passed, an attempt is made to use the
globals from obj (or the respective module's globals for classes),
and these are also used as the locals. If the object does not appear
to have globals, an empty dictionary is used. For classes, the search
order is globals first then locals.
- If one dict argument is passed, it is used for both globals and
locals.
- If two dict arguments are passed, they specify globals and
locals, respectively.
"""
if getattr(obj, '__no_type_check__', None):
return {}
# Classes require a special treatment.
if isinstance(obj, type):
hints = {}
for base in reversed(obj.__mro__):
if globalns is None:
base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {})
else:
base_globals = globalns
ann = base.__dict__.get('__annotations__', {})
if isinstance(ann, types.GetSetDescriptorType):
ann = {}
base_locals = dict(vars(base)) if localns is None else localns
if localns is None and globalns is None:
# This is surprising, but required. Before Python 3.10,
# get_type_hints only evaluated the globalns of
# a class. To maintain backwards compatibility, we reverse
# the globalns and localns order so that eval() looks into
# *base_globals* first rather than *base_locals*.
# This only affects ForwardRefs.
base_globals, base_locals = base_locals, base_globals
for name, value in ann.items():
if value is None:
value = type(None)
if isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
value = eval_type_backport(value, base_globals, base_locals)
hints[name] = value
if not include_extras and hasattr(typing, '_strip_annotations'):
return {
k: typing._strip_annotations(t) # type: ignore
for k, t in hints.items()
}
else:
return hints
if globalns is None:
if isinstance(obj, types.ModuleType):
globalns = obj.__dict__
else:
nsobj = obj
# Find globalns for the unwrapped object.
while hasattr(nsobj, '__wrapped__'):
nsobj = nsobj.__wrapped__
globalns = getattr(nsobj, '__globals__', {})
if localns is None:
localns = globalns
elif localns is None:
localns = globalns
hints = getattr(obj, '__annotations__', None)
if hints is None:
# Return empty annotations for something that _could_ have them.
if isinstance(obj, typing._allowed_types): # type: ignore
return {}
else:
raise TypeError(f'{obj!r} is not a module, class, method, or function.')
defaults = typing._get_defaults(obj) # type: ignore
hints = dict(hints)
for name, value in hints.items():
if value is None:
value = type(None)
if isinstance(value, str):
# class-level forward refs were handled above, this must be either
# a module-level annotation or a function argument annotation
value = _make_forward_ref(
value,
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
value = eval_type_backport(value, globalns, localns)
if name in defaults and defaults[name] is None:
value = typing.Optional[value]
hints[name] = value
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore

View File

@ -0,0 +1,429 @@
"""Bucket of reusable internal utilities.
This should be reduced as much as possible, with functions used in only one place moved to that place.
"""
from __future__ import annotations as _annotations
import dataclasses
import keyword
import typing
import warnings
import weakref
from collections import OrderedDict, defaultdict, deque
from collections.abc import Mapping
from copy import deepcopy
from functools import cached_property
from inspect import Parameter
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, Callable, Generic, TypeVar, overload
from typing_extensions import TypeAlias, TypeGuard, deprecated
from pydantic import PydanticDeprecatedSince211
from . import _repr, _typing_extra
from ._import_utils import import_cached_base_model
if typing.TYPE_CHECKING:
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]'
from ..main import BaseModel
# these are types that are returned unchanged by deepcopy
IMMUTABLE_NON_COLLECTIONS_TYPES: set[type[Any]] = {
int,
float,
complex,
str,
bool,
bytes,
type,
_typing_extra.NoneType,
FunctionType,
BuiltinFunctionType,
LambdaType,
weakref.ref,
CodeType,
    # note: including ModuleType will differ from the behaviour of deepcopy by not producing an error.
    # It might not be a good idea in general, but considering that this function is used only internally
    # against default values of fields, this allows a field to actually have a module as its default value
ModuleType,
NotImplemented.__class__,
Ellipsis.__class__,
}
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
BUILTIN_COLLECTIONS: set[type[Any]] = {
list,
set,
tuple,
frozenset,
dict,
OrderedDict,
defaultdict,
deque,
}
def can_be_positional(param: Parameter) -> bool:
"""Return whether the parameter accepts a positional argument.
```python {test="skip" lint="skip"}
    import inspect

    def func(a, /, b, *, c):
pass
params = inspect.signature(func).parameters
can_be_positional(params['a'])
#> True
can_be_positional(params['b'])
#> True
can_be_positional(params['c'])
#> False
```
"""
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
def sequence_like(v: Any) -> bool:
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
def lenient_isinstance(o: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] | None) -> bool: # pragma: no cover
try:
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
except TypeError:
return False
def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
except TypeError:
if isinstance(cls, _typing_extra.WithArgsTypes):
return False
raise # pragma: no cover
def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
unlike raw calls to lenient_issubclass.
"""
BaseModel = import_cached_base_model()
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
def is_valid_identifier(identifier: str) -> bool:
"""Checks that a string is a valid identifier and not a Python keyword.
:param identifier: The identifier to test.
:return: True if the identifier is valid.
"""
return identifier.isidentifier() and not keyword.iskeyword(identifier)
KeyType = TypeVar('KeyType')
def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping
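# Illustrative sketch (comment-only): nested mappings are merged recursively,
# with later mappings taking precedence for non-dict values:
#
#     deep_update({'a': {'x': 1}, 'b': 2}, {'a': {'y': 3}, 'b': 4})
#     # -> {'a': {'x': 1, 'y': 3}, 'b': 4}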
def update_not_none(mapping: dict[Any, Any], **update: Any) -> None:
mapping.update({k: v for k, v in update.items() if v is not None})
T = TypeVar('T')
def unique_list(
input_list: list[T] | tuple[T, ...],
*,
name_factory: typing.Callable[[T], str] = str,
) -> list[T]:
"""Make a list unique while maintaining order.
    We replace an existing entry if another item with the same name is encountered
    (e.g. a model validator overridden in a subclass).
"""
result: list[T] = []
result_names: list[str] = []
for v in input_list:
v_name = name_factory(v)
if v_name not in result_names:
result_names.append(v_name)
result.append(v)
else:
result[result_names.index(v_name)] = v
return result
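# Illustrative sketch (comment-only): a later duplicate (by name) replaces the
# earlier entry in place, so ordering is preserved:
#
#     unique_list([1, 2, 1, 3], name_factory=str)
#     # -> [1, 2, 3]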
class ValueItems(_repr.Representation):
"""Class for more convenient calculation of excluded or included fields on values."""
__slots__ = ('_items', '_type')
def __init__(self, value: Any, items: AbstractSetIntStr | MappingIntStrAny) -> None:
items = self._coerce_items(items)
if isinstance(value, (list, tuple)):
items = self._normalize_indexes(items, len(value)) # type: ignore
self._items: MappingIntStrAny = items # type: ignore
def is_excluded(self, item: Any) -> bool:
"""Check if item is fully excluded.
:param item: key or index of a value
"""
return self.is_true(self._items.get(item))
def is_included(self, item: Any) -> bool:
"""Check if value is contained in self._items.
:param item: key or index of value
"""
return item in self._items
def for_element(self, e: int | str) -> AbstractSetIntStr | MappingIntStrAny | None:
""":param e: key or index of element on value
        :return: raw values for the element if self._items is a dict and contains the needed element
"""
item = self._items.get(e) # type: ignore
return item if not self.is_true(item) else None
def _normalize_indexes(self, items: MappingIntStrAny, v_length: int) -> dict[int | str, Any]:
""":param items: dict or set of indexes which will be normalized
        :param v_length: length of the sequence whose indexes will be normalized
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
{0: True, 2: True, 3: True}
>>> self._normalize_indexes({'__all__': True}, 4)
{0: True, 1: True, 2: True, 3: True}
"""
normalized_items: dict[int | str, Any] = {}
all_items = None
for i, v in items.items():
if not (isinstance(v, typing.Mapping) or isinstance(v, typing.AbstractSet) or self.is_true(v)):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
if i == '__all__':
all_items = self._coerce_value(v)
continue
if not isinstance(i, int):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
normalized_i = v_length + i if i < 0 else i
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
if not all_items:
return normalized_items
if self.is_true(all_items):
for i in range(v_length):
normalized_items.setdefault(i, ...)
return normalized_items
for i in range(v_length):
normalized_item = normalized_items.setdefault(i, {})
if not self.is_true(normalized_item):
normalized_items[i] = self.merge(all_items, normalized_item)
return normalized_items
@classmethod
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
"""Merge a `base` item with an `override` item.
Both `base` and `override` are converted to dictionaries if possible.
Sets are converted to dictionaries with the sets entries as keys and
Ellipsis as values.
Each key-value pair existing in `base` is merged with `override`,
while the rest of the key-value pairs are updated recursively with this function.
Merging takes place based on the "union" of keys if `intersect` is
set to `False` (default) and on the intersection of keys if
`intersect` is set to `True`.
"""
override = cls._coerce_value(override)
base = cls._coerce_value(base)
if override is None:
return base
if cls.is_true(base) or base is None:
return override
if cls.is_true(override):
return base if intersect else override
# intersection or union of keys while preserving ordering:
if intersect:
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
else:
merge_keys = list(base) + [k for k in override if k not in base]
merged: dict[int | str, Any] = {}
for k in merge_keys:
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
if merged_item is not None:
merged[k] = merged_item
return merged
@staticmethod
def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny:
if isinstance(items, typing.Mapping):
pass
elif isinstance(items, typing.AbstractSet):
items = dict.fromkeys(items, ...) # type: ignore
else:
class_name = getattr(items, '__class__', '???')
raise TypeError(f'Unexpected type of exclude value {class_name}')
return items # type: ignore
@classmethod
def _coerce_value(cls, value: Any) -> Any:
if value is None or cls.is_true(value):
return value
return cls._coerce_items(value)
@staticmethod
def is_true(v: Any) -> bool:
return v is True or v is ...
def __repr_args__(self) -> _repr.ReprArgs:
return [(None, self._items)]
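# Illustrative sketch (comment-only, assumed inputs): negative indexes are
# normalized against the sequence length and sets are coerced to dicts:
#
#     vi = ValueItems(['a', 'b', 'c'], {0: True, -1: {'x'}})
#     vi.is_excluded(0)   # -> True (index 0 is fully excluded)
#     vi.for_element(2)   # -> {'x': Ellipsis} (the raw sub-exclusion given for index -1)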
if typing.TYPE_CHECKING:
def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ...
else:
class LazyClassAttribute:
"""A descriptor exposing an attribute only accessible on a class (hidden from instances).
The attribute is lazily computed and cached during the first access.
"""
def __init__(self, name: str, get_value: Callable[[], Any]) -> None:
self.name = name
self.get_value = get_value
@cached_property
def value(self) -> Any:
return self.get_value()
        def __get__(self, instance: Any, owner: type[Any]) -> Any:
if instance is None:
return self.value
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
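# Illustrative sketch (comment-only; `expensive_default` is a hypothetical callable):
#
#     class Foo:
#         defaults = LazyClassAttribute('defaults', lambda: expensive_default())
#
#     Foo.defaults    # computed once on first access, then cached via `cached_property`
#     Foo().defaults  # raises AttributeError: 'defaults' is class-only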
Obj = TypeVar('Obj')
def smart_deepcopy(obj: Obj) -> Obj:
"""Return type as is for immutable built-in types
Use obj.copy() for built-in empty collections
Use copy.deepcopy() for non-empty collections and unknown objects.
"""
obj_type = obj.__class__
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway
try:
if not obj and obj_type in BUILTIN_COLLECTIONS:
# faster way for empty collections, no need to copy its members
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore
except (TypeError, ValueError, RuntimeError):
# do we really dare to catch ALL errors? Seems a bit risky
pass
return deepcopy(obj) # slowest way when we actually might need a deepcopy
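# Illustrative sketch (comment-only):
#
#     smart_deepcopy(42)          # immutable -> returned as-is
#     smart_deepcopy([])          # empty builtin collection -> cheap .copy()
#     smart_deepcopy([{'a': 1}])  # non-empty -> full copy.deepcopy()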
_SENTINEL = object()
def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
"""Check that the items of `left` are the same objects as those in `right`.
>>> a, b = object(), object()
>>> all_identical([a, b, a], [a, b, a])
True
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
if left_item is not right_item:
return False
return True
@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
    This makes it safe to use in `operator.itemgetter` when some keys may be missing.
"""
    # Define __slots__ manually for performance.
    # @dataclasses.dataclass() only supports slots=True in Python >= 3.10
__slots__ = ('wrapped',)
wrapped: Mapping[str, Any]
def __getitem__(self, key: str, /) -> Any:
return self.wrapped.get(key, _SENTINEL)
# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
# https://github.com/python/mypy/issues/13713
# https://github.com/python/typeshed/pull/8785
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
if typing.TYPE_CHECKING:
def __contains__(self, key: str, /) -> bool:
return self.wrapped.__contains__(key)
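# Illustrative sketch (comment-only): missing keys yield the sentinel instead of
# raising KeyError, which makes the proxy safe with operator.itemgetter:
#
#     import operator
#     getter = operator.itemgetter('a', 'missing')
#     getter(SafeGetItemProxy({'a': 1}))
#     # -> (1, <sentinel object>)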
_ModelT = TypeVar('_ModelT', bound='BaseModel')
_RT = TypeVar('_RT')
class deprecated_instance_property(Generic[_ModelT, _RT]):
"""A decorator exposing the decorated class method as a property, with a warning on instance access.
This decorator takes a class method defined on the `BaseModel` class and transforms it into
an attribute. The attribute can be accessed on both the class and instances of the class. If accessed
via an instance, a deprecation warning is emitted stating that instance access will be removed in V3.
"""
def __init__(self, fget: Callable[[type[_ModelT]], _RT], /) -> None:
# Note: fget should be a classmethod:
self.fget = fget
@overload
def __get__(self, instance: None, objtype: type[_ModelT]) -> _RT: ...
@overload
@deprecated(
'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. '
'Instead, you should access this attribute from the model class.',
category=None,
)
def __get__(self, instance: _ModelT, objtype: type[_ModelT]) -> _RT: ...
def __get__(self, instance: _ModelT | None, objtype: type[_ModelT]) -> _RT:
if instance is not None:
warnings.warn(
'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. '
'Instead, you should access this attribute from the model class.',
category=PydanticDeprecatedSince211,
stacklevel=2,
)
return self.fget.__get__(instance, objtype)()

View File

@ -0,0 +1,140 @@
from __future__ import annotations as _annotations
import functools
import inspect
from collections.abc import Awaitable
from functools import partial
from typing import Any, Callable
import pydantic_core
from ..config import ConfigDict
from ..plugin._schema_validator import create_schema_validator
from ._config import ConfigWrapper
from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
def extract_function_name(func: ValidateCallSupportedTypes) -> str:
"""Extract the name of a `ValidateCallSupportedTypes` object."""
return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__
def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
"""Extract the qualname of a `ValidateCallSupportedTypes` object."""
return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__
def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
"""Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
if inspect.iscoroutinefunction(wrapped):
@functools.wraps(wrapped)
async def wrapper_function(*args, **kwargs): # type: ignore
return await wrapper(*args, **kwargs)
else:
@functools.wraps(wrapped)
def wrapper_function(*args, **kwargs):
return wrapper(*args, **kwargs)
# We need to manually update this because `partial` object has no `__name__` and `__qualname__`.
wrapper_function.__name__ = extract_function_name(wrapped)
wrapper_function.__qualname__ = extract_function_qualname(wrapped)
wrapper_function.raw_function = wrapped # type: ignore
return wrapper_function
class ValidateCallWrapper:
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
__slots__ = (
'function',
'validate_return',
'schema_type',
'module',
'qualname',
'ns_resolver',
'config_wrapper',
'__pydantic_complete__',
'__pydantic_validator__',
'__return_pydantic_validator__',
)
def __init__(
self,
function: ValidateCallSupportedTypes,
config: ConfigDict | None,
validate_return: bool,
parent_namespace: MappingNamespace | None,
) -> None:
self.function = function
self.validate_return = validate_return
if isinstance(function, partial):
self.schema_type = function.func
self.module = function.func.__module__
else:
self.schema_type = function
self.module = function.__module__
self.qualname = extract_function_qualname(function)
self.ns_resolver = NsResolver(
namespaces_tuple=ns_for_function(self.schema_type, parent_namespace=parent_namespace)
)
self.config_wrapper = ConfigWrapper(config)
if not self.config_wrapper.defer_build:
self._create_validators()
else:
self.__pydantic_complete__ = False
def _create_validators(self) -> None:
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
schema = gen_schema.clean_schema(gen_schema.generate_schema(self.function))
core_config = self.config_wrapper.core_config(title=self.qualname)
self.__pydantic_validator__ = create_schema_validator(
schema,
self.schema_type,
self.module,
self.qualname,
'validate_call',
core_config,
self.config_wrapper.plugin_settings,
)
if self.validate_return:
signature = inspect.signature(self.function)
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
validator = create_schema_validator(
schema,
self.schema_type,
self.module,
self.qualname,
'validate_call',
core_config,
self.config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(self.function):
                async def return_val_wrapper(aw: Awaitable[Any]) -> Any:
return validator.validate_python(await aw)
self.__return_pydantic_validator__ = return_val_wrapper
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_validator__ = None
self.__pydantic_complete__ = True
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if not self.__pydantic_complete__:
self._create_validators()
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
if self.__return_pydantic_validator__:
return self.__return_pydantic_validator__(res)
else:
return res
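# Illustrative usage sketch (comment-only; in practice these wrappers are built by
# the public `validate_call` decorator rather than constructed directly):
#
#     def add(a: int, b: int) -> int:
#         return a + b
#
#     wrapper = ValidateCallWrapper(add, config=None, validate_return=True, parent_namespace=None)
#     wrapper(1, '2')  # -> 3: arguments are coerced and the return value is validated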

View File

@ -0,0 +1,532 @@
"""Validator functions for standard library types.
Import of this module is deferred since it contains imports of many standard library modules.
"""
from __future__ import annotations as _annotations
import collections.abc
import math
import re
import typing
from decimal import Decimal
from fractions import Fraction
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any, Callable, Union, cast, get_origin
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
import typing_extensions
from pydantic_core import PydanticCustomError, core_schema
from pydantic_core._pydantic_core import PydanticKnownError
from typing_inspection import typing_objects
from pydantic._internal._import_utils import import_cached_field_info
from pydantic.errors import PydanticSchemaGenerationError
def sequence_validator(
input_value: typing.Sequence[Any],
/,
validator: core_schema.ValidatorFunctionWrapHandler,
) -> typing.Sequence[Any]:
"""Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
value_type = type(input_value)
# We don't accept any plain string as a sequence
# Relevant issue: https://github.com/pydantic/pydantic/issues/5595
if issubclass(value_type, (str, bytes)):
raise PydanticCustomError(
'sequence_str',
"'{type_name}' instances are not allowed as a Sequence value",
{'type_name': value_type.__name__},
)
# TODO: refactor sequence validation to validate with either a list or a tuple
# schema, depending on the type of the value.
# Additionally, we should be able to remove one of either this validator or the
# SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
# Effectively, a refactor for sequence validation is needed.
if value_type is tuple:
input_value = list(input_value)
v_list = validator(input_value)
# the rest of the logic is just re-creating the original type from `v_list`
if value_type is list:
return v_list
elif issubclass(value_type, range):
# return the list as we probably can't re-create the range
return v_list
elif value_type is tuple:
return tuple(v_list)
else:
# best guess at how to re-create the original type, more custom construction logic might be required
return value_type(v_list) # type: ignore[call-arg]
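# Illustrative sketch (comment-only; the wrap handler normally comes from
# pydantic-core, an identity-style lambda stands in for it here):
#
#     sequence_validator((1, 2, 3), lambda items: [i * 2 for i in items])
#     # -> (2, 4, 6): the tuple is validated as a list and the original
#     # sequence type is re-created afterwards.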
def import_string(value: Any) -> Any:
if isinstance(value, str):
try:
return _import_string_logic(value)
except ImportError as e:
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
else:
# otherwise we just return the value and let the next validator do the rest of the work
return value
def _import_string_logic(dotted_path: str) -> Any:
"""Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.
    (This is necessary to distinguish between a submodule and an attribute when there is a conflict.)
If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
rather than a submodule will be attempted automatically.
So, for example, the following values of `dotted_path` result in the following returned values:
* 'collections': <module 'collections'>
* 'collections.abc': <module 'collections.abc'>
* 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
* `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)
An error will be raised under any of the following scenarios:
* `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
* the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
* the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
"""
from importlib import import_module
components = dotted_path.strip().split(':')
if len(components) > 2:
raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")
module_path = components[0]
if not module_path:
raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')
try:
module = import_module(module_path)
except ModuleNotFoundError as e:
if '.' in module_path:
# Check if it would be valid if the final item was separated from its module with a `:`
maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1)
try:
return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
except ImportError:
pass
raise ImportError(f'No module named {module_path!r}') from e
raise e
if len(components) > 1:
attribute = components[1]
try:
return getattr(module, attribute)
except AttributeError as e:
raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
else:
return module
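# Illustrative sketch (comment-only):
#
#     import_string('math:pi')          # -> 3.141592653589793
#     import_string('collections.abc')  # -> <module 'collections.abc'>
#     import_string('math.pi')          # retried as 'math:pi' -> 3.141592653589793
#     import_string('collections:abc:Mapping')
#     # raises PydanticCustomError: at most one ':' is allowed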
def pattern_either_validator(input_value: Any, /) -> typing.Pattern[Any]:
if isinstance(input_value, typing.Pattern):
return input_value
elif isinstance(input_value, (str, bytes)):
# todo strict mode
return compile_pattern(input_value) # type: ignore
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_str_validator(input_value: Any, /) -> typing.Pattern[str]:
if isinstance(input_value, typing.Pattern):
if isinstance(input_value.pattern, str):
return input_value
else:
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
elif isinstance(input_value, str):
return compile_pattern(input_value)
elif isinstance(input_value, bytes):
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_bytes_validator(input_value: Any, /) -> typing.Pattern[bytes]:
if isinstance(input_value, typing.Pattern):
if isinstance(input_value.pattern, bytes):
return input_value
else:
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
elif isinstance(input_value, bytes):
return compile_pattern(input_value)
elif isinstance(input_value, str):
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
PatternType = typing.TypeVar('PatternType', str, bytes)
def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
try:
return re.compile(pattern)
except re.error:
raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
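# Illustrative sketch (comment-only):
#
#     pattern_str_validator(r'^\d+$')     # -> re.compile('^\\d+$')
#     pattern_bytes_validator(rb'^\d+$')  # -> re.compile(b'^\\d+$')
#     pattern_str_validator(rb'^\d+$')    # raises PydanticCustomError('pattern_str_type', ...)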
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
if isinstance(input_value, IPv4Address):
return input_value
try:
return IPv4Address(input_value)
except ValueError:
raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
if isinstance(input_value, IPv6Address):
return input_value
try:
return IPv6Address(input_value)
except ValueError:
raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
"""Assume IPv4Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
"""
if isinstance(input_value, IPv4Network):
return input_value
try:
return IPv4Network(input_value)
except ValueError:
raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
"""Assume IPv6Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
"""
if isinstance(input_value, IPv6Network):
return input_value
try:
return IPv6Network(input_value)
except ValueError:
raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
if isinstance(input_value, IPv4Interface):
return input_value
try:
return IPv4Interface(input_value)
except ValueError:
raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
if isinstance(input_value, IPv6Interface):
return input_value
try:
return IPv6Interface(input_value)
except ValueError:
raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
def fraction_validator(input_value: Any, /) -> Fraction:
if isinstance(input_value, Fraction):
return input_value
try:
return Fraction(input_value)
except ValueError:
raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')
def forbid_inf_nan_check(x: Any) -> Any:
if not math.isfinite(x):
raise PydanticKnownError('finite_number')
return x
def _safe_repr(v: Any) -> int | float | str:
"""The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.
See tests/test_types.py::test_annotated_metadata_any_order for some context.
"""
if isinstance(v, (int, float, str)):
return v
return repr(v)
def greater_than_validator(x: Any, gt: Any) -> Any:
try:
if not (x > gt):
raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
try:
if not (x >= ge):
raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")
def less_than_validator(x: Any, lt: Any) -> Any:
try:
if not (x < lt):
raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
try:
if not (x <= le):
raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
try:
if x % multiple_of:
raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")
def min_length_validator(x: Any, min_length: Any) -> Any:
try:
if not (len(x) >= min_length):
raise PydanticKnownError(
'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")
def max_length_validator(x: Any, max_length: Any) -> Any:
try:
if len(x) > max_length:
raise PydanticKnownError(
'too_long',
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")
def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
"""Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.
This function handles both normalized and non-normalized Decimal instances.
Example: Decimal('1.230') -> 4 digits, 3 decimal places
Args:
decimal (Decimal): The decimal number to analyze.
Returns:
tuple[int, int]: A tuple containing the number of decimal places and total digits.
Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
of the number of decimals and digits together.
"""
try:
decimal_tuple = decimal.as_tuple()
assert isinstance(decimal_tuple.exponent, int)
exponent = decimal_tuple.exponent
num_digits = len(decimal_tuple.digits)
if exponent >= 0:
# A positive exponent adds that many trailing zeros
# Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
num_digits += exponent
decimal_places = 0
else:
# If the absolute value of the negative exponent is larger than the
# number of digits, then it's the same as the number of digits,
# because it'll consume all the digits in digit_tuple and then
# add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
# Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
# Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
decimal_places = abs(exponent)
num_digits = max(num_digits, decimal_places)
return decimal_places, num_digits
except (AssertionError, AttributeError):
raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')
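# Worked examples (comment-only), matching the docstring above:
#
#     _extract_decimal_digits_info(Decimal('1.230'))   # -> (3, 4): 3 decimal places, 4 digits
#     _extract_decimal_digits_info(Decimal('123E+2'))  # -> (0, 5): 12300 has 5 digits
#     _extract_decimal_digits_info(Decimal('0.0123'))  # -> (4, 4)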
def max_digits_validator(x: Any, max_digits: Any) -> Any:
try:
_, num_digits = _extract_decimal_digits_info(x)
_, normalized_num_digits = _extract_decimal_digits_info(x.normalize())
if (num_digits > max_digits) and (normalized_num_digits > max_digits):
raise PydanticKnownError(
'decimal_max_digits',
{'max_digits': max_digits},
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")
def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
try:
decimal_places_, _ = _extract_decimal_digits_info(x)
if decimal_places_ > decimal_places:
normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())
if normalized_decimal_places > decimal_places:
raise PydanticKnownError(
'decimal_max_places',
{'decimal_places': decimal_places},
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")
def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None))
def defaultdict_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
if isinstance(input_value, collections.defaultdict):
default_factory = input_value.default_factory
return collections.defaultdict(default_factory, handler(input_value))
else:
return collections.defaultdict(default_default_factory, handler(input_value))
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
FieldInfo = import_cached_field_info()
values_type_origin = get_origin(values_source_type)
def infer_default() -> Callable[[], Any]:
allowed_default_types: dict[Any, Any] = {
tuple: tuple,
collections.abc.Sequence: tuple,
collections.abc.MutableSequence: list,
list: list,
typing.Sequence: list,
set: set,
typing.MutableSet: set,
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
typing.MutableMapping: dict,
typing.Mapping: dict,
collections.abc.Mapping: dict,
collections.abc.MutableMapping: dict,
float: float,
int: int,
str: str,
bool: bool,
}
values_type = values_type_origin or values_source_type
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
if typing_objects.is_typevar(values_type):
def type_var_default_factory() -> None:
raise RuntimeError(
'Generic defaultdict cannot be used without a concrete value type or an'
' explicit default factory, ' + instructions
)
return type_var_default_factory
elif values_type not in allowed_default_types:
# a somewhat subjective set of types that have reasonable default values
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
raise PydanticSchemaGenerationError(
                f'Unable to infer a default factory for values of type {values_source_type}.'
f' Only {allowed_msg} are supported, other types require an explicit default factory'
' ' + instructions
)
return allowed_default_types[values_type]
# Assume Annotated[..., Field(...)]
if typing_objects.is_annotated(values_type_origin):
field_info = next((v for v in typing_extensions.get_args(values_source_type) if isinstance(v, FieldInfo)), None)
else:
field_info = None
if field_info and field_info.default_factory:
# Assume the default factory does not take any argument:
default_default_factory = cast(Callable[[], Any], field_info.default_factory)
else:
default_default_factory = infer_default()
return default_default_factory
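# Illustrative sketch (comment-only):
#
#     get_defaultdict_default_default_factory(list)                 # -> list
#     get_defaultdict_default_default_factory(typing.Sequence[int]) # -> tuple
#     # With Annotated[..., Field(default_factory=...)] as the value type, the
#     # explicitly supplied factory is returned instead of an inferred one.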
def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
if isinstance(value, ZoneInfo):
return value
try:
return ZoneInfo(value)
except (ZoneInfoNotFoundError, ValueError, TypeError):
raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})
NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
'gt': greater_than_validator,
'ge': greater_than_or_equal_validator,
'lt': less_than_validator,
'le': less_than_or_equal_validator,
'multiple_of': multiple_of_validator,
'min_length': min_length_validator,
'max_length': max_length_validator,
'max_digits': max_digits_validator,
'decimal_places': decimal_places_validator,
}
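# Illustrative sketch (comment-only): constraint validators are looked up by the
# metadata attribute name and applied to the value:
#
#     NUMERIC_VALIDATOR_LOOKUP['gt'](5, 3)              # -> 5
#     NUMERIC_VALIDATOR_LOOKUP['max_length']('abc', 2)  # raises PydanticKnownError('too_long', ...)
#     NUMERIC_VALIDATOR_LOOKUP['gt']('a', 3)            # raises TypeError (constraint not applicable)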
IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]
IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
IPv4Address: ip_v4_address_validator,
IPv6Address: ip_v6_address_validator,
IPv4Network: ip_v4_network_validator,
IPv6Network: ip_v6_network_validator,
IPv4Interface: ip_v4_interface_validator,
IPv6Interface: ip_v6_interface_validator,
}
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
typing.DefaultDict: collections.defaultdict, # noqa: UP006
collections.defaultdict: collections.defaultdict,
typing.OrderedDict: collections.OrderedDict, # noqa: UP006
collections.OrderedDict: collections.OrderedDict,
typing_extensions.OrderedDict: collections.OrderedDict,
typing.Counter: collections.Counter,
collections.Counter: collections.Counter,
# this doesn't handle subclasses of these
typing.Mapping: dict,
typing.MutableMapping: dict,
# parametrized typing.{Mutable}Mapping creates one of these
collections.abc.Mapping: dict,
collections.abc.MutableMapping: dict,
}