Spaces:
Running
Running
from collections import deque | |
from copy import copy | |
from dataclasses import dataclass, is_dataclass | |
from enum import Enum | |
from functools import lru_cache | |
from typing import ( | |
Any, | |
Callable, | |
Deque, | |
Dict, | |
FrozenSet, | |
List, | |
Mapping, | |
Sequence, | |
Set, | |
Tuple, | |
Type, | |
Union, | |
) | |
from fastapi.exceptions import RequestErrorModel | |
from fastapi.types import IncEx, ModelNameMap, UnionType | |
from pydantic import BaseModel, create_model | |
from pydantic.version import VERSION as P_VERSION | |
from starlette.datastructures import UploadFile | |
from typing_extensions import Annotated, Literal, get_args, get_origin | |
# Re-export pydantic's version string under a name owned by this module so
# that mypy treats it as part of this module's public surface.
PYDANTIC_VERSION = P_VERSION
# True only for pydantic 2.x releases; selects which compatibility branch
# below gets defined.
PYDANTIC_V2 = PYDANTIC_VERSION[:2] == "2."
# Every supported sequence annotation (typing alias or builtin) mapped to the
# concrete builtin container used to rebuild values of that annotation.
# Insertion order is kept stable because ``sequence_types`` below is derived
# from the keys.
sequence_annotation_to_type = {
    annotation: container
    for container, annotations in (
        (list, (Sequence, List, list)),
        (tuple, (Tuple, tuple)),
        (set, (Set, set)),
        (frozenset, (FrozenSet, frozenset)),
        (deque, (Deque, deque)),
    )
    for annotation in annotations
}

# Tuple form of the keys, usable directly with isinstance()/issubclass().
sequence_types = tuple(sequence_annotation_to_type)
if PYDANTIC_V2: | |
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError | |
from pydantic import TypeAdapter | |
from pydantic import ValidationError as ValidationError | |
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] | |
GetJsonSchemaHandler as GetJsonSchemaHandler, | |
) | |
from pydantic._internal._typing_extra import eval_type_lenient | |
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass | |
from pydantic.fields import FieldInfo | |
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema | |
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue | |
from pydantic_core import CoreSchema as CoreSchema | |
from pydantic_core import PydanticUndefined, PydanticUndefinedType | |
from pydantic_core import Url as Url | |
try: | |
from pydantic_core.core_schema import ( | |
with_info_plain_validator_function as with_info_plain_validator_function, | |
) | |
except ImportError: # pragma: no cover | |
from pydantic_core.core_schema import ( | |
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 | |
) | |
# pydantic-core's single "undefined" sentinel backs both the ``Required``
# marker and the ``Undefined`` value names.
Required = Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
# In this branch ``Validator`` is only a loose typing placeholder.
Validator = Any
class BaseConfig:
    """Empty placeholder that keeps the pydantic v1 ``BaseConfig`` name importable."""
class ErrorWrapper(Exception):
    """Placeholder exception that keeps the pydantic v1 ``ErrorWrapper`` name importable."""
@dataclass
class ModelField:
    """pydantic-v2 stand-in for pydantic v1's ``ModelField``.

    Pairs a ``FieldInfo`` with the field's ``name`` and the JSON Schema
    ``mode`` it is generated for. Values are validated and serialized through
    a ``TypeAdapter`` built over ``Annotated[annotation, field_info]``.

    ``@dataclass`` is required: it provides the keyword constructor
    (``ModelField(field_info=..., name=...)``) and runs ``__post_init__``.
    ``alias``/``required``/``default``/``type_`` are read-only properties so
    callers can read them as attributes, as with the v1 class. Instances hash
    by identity (see ``__hash__``).
    """

    field_info: FieldInfo
    name: str
    # JSON Schema flavour used when generating this field's schema.
    mode: Literal["validation", "serialization"] = "validation"

    @property
    def alias(self) -> str:
        """The field's alias, or its ``name`` when no alias is configured."""
        a = self.field_info.alias
        return a if a is not None else self.name

    @property
    def required(self) -> bool:
        """Whether ``field_info`` reports the field as required."""
        return self.field_info.is_required()

    @property
    def default(self) -> Any:
        """The field's default value (``Undefined`` for required fields)."""
        return self.get_default()

    @property
    def type_(self) -> Any:
        """The field's raw annotation."""
        return self.field_info.annotation

    def __post_init__(self) -> None:
        # Built once per field and reused by validate(), serialize() and by
        # get_definitions() (via ``_type_adapter.core_schema``).
        self._type_adapter: TypeAdapter[Any] = TypeAdapter(
            Annotated[self.field_info.annotation, self.field_info]
        )

    def get_default(self) -> Any:
        """Return ``Undefined`` for a required field, otherwise its default.

        A ``default_factory`` configured on ``field_info`` is called to
        produce the value.
        """
        if self.field_info.is_required():
            return Undefined
        return self.field_info.get_default(call_default_factory=True)

    def validate(
        self,
        value: Any,
        values: Dict[str, Any] = {},  # noqa: B006
        *,
        loc: Tuple[Union[int, str], ...] = (),
    ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
        """Validate *value* against this field's type.

        Returns ``(validated_value, None)`` on success. On failure returns
        ``(None, errors)``, with every error's ``loc`` prefixed by *loc*.

        ``values`` is never read or mutated, so the shared ``{}`` default is
        harmless. NOTE(review): presumably it is kept for parity with the
        pydantic v1 ``ModelField.validate`` signature.
        """
        try:
            return (
                # from_attributes=True also accepts objects exposing the
                # expected attributes, not only dicts.
                self._type_adapter.validate_python(value, from_attributes=True),
                None,
            )
        except ValidationError as exc:
            return None, _regenerate_error_with_loc(
                errors=exc.errors(include_url=False), loc_prefix=loc
            )

    def serialize(
        self,
        value: Any,
        *,
        mode: Literal["json", "python"] = "json",
        include: Union[IncEx, None] = None,
        exclude: Union[IncEx, None] = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """Dump *value* with the field's ``TypeAdapter.dump_python``.

        All keyword options are forwarded unchanged.
        """
        # What calls this code passes a value that already called
        # self._type_adapter.validate_python(value)
        return self._type_adapter.dump_python(
            value,
            mode=mode,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )

    def __hash__(self) -> int:
        # Each ModelField is unique for our purposes, to allow making a dict from
        # ModelField to its JSON Schema. Because it is defined explicitly,
        # @dataclass (eq=True) keeps this method instead of setting __hash__ to None.
        return id(self)
def get_annotation_from_field_info(
    annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
    """Return *annotation* unchanged.

    In the pydantic v2 branch this is a pass-through. ``field_info`` and
    ``field_name`` are accepted only so the signature matches the v1
    ``get_annotation_from_field_info`` imported in the other branch.
    """
    unchanged = annotation
    return unchanged
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: | |
return errors # type: ignore[return-value] | |
def _model_rebuild(model: Type[BaseModel]) -> None:
    """Rebuild *model*'s schema with pydantic v2's ``model_rebuild``."""
    rebuild = model.model_rebuild
    rebuild()
def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Dump *model* with ``model_dump``, forwarding *mode* and any extra options."""
    dump_options = {"mode": mode, **kwargs}
    return model.model_dump(**dump_options)
def _get_model_config(model: BaseModel) -> Any:
    """Return the pydantic v2 ``model_config`` of *model*."""
    config = model.model_config
    return config
def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Return the JSON Schema previously generated for *field*.

    ``field_mapping`` is expected to be the output of ``get_definitions`` for
    the same fields. The lookup key uses the field's own ``mode`` when
    ``separate_input_output_schemas`` is true, and ``"validation"`` otherwise.

    Inline schemas (those without ``$ref``) get a ``title``: either
    ``field_info.title`` or one derived from the alias. That dict inside
    ``field_mapping`` is updated in place and returned. ``schema_generator``
    and ``model_name_map`` are not used in this branch.
    """
    lookup_mode: Literal["validation", "serialization"]
    if separate_input_output_schemas:
        lookup_mode = field.mode
    else:
        lookup_mode = "validation"

    json_schema = field_mapping[(field, lookup_mode)]
    if "$ref" in json_schema:
        return json_schema

    # TODO remove when deprecating Pydantic v1
    # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
    title = field.field_info.title
    if not title:
        # "user_name" -> "User Name"
        title = field.alias.title().replace("_", " ")
    json_schema["title"] = title
    return json_schema
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
    """Return an empty model-name map; *fields* is not inspected in the v2 branch."""
    return dict()
def get_definitions(
    *,
    fields: List[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    Dict[str, Dict[str, Any]],
]:
    """Generate JSON Schemas for all *fields* in one generator pass.

    Each field is submitted to ``schema_generator.generate_definitions`` as
    ``(field, mode, core_schema)``. ``mode`` is ``"validation"`` when
    ``separate_input_output_schemas`` is false, and the field's own ``mode``
    otherwise. Returns the generator's ``(field_mapping, definitions)`` pair.
    ``model_name_map`` is not used in this branch.
    """

    def _mode_of(field: ModelField) -> Literal["validation", "serialization"]:
        # A shared input/output schema is always generated in validation mode.
        return field.mode if separate_input_output_schemas else "validation"

    generator_inputs = [
        (field, _mode_of(field), field._type_adapter.core_schema) for field in fields
    ]
    field_mapping, definitions = schema_generator.generate_definitions(
        inputs=generator_inputs
    )
    return field_mapping, definitions  # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
    """Return True when *field* has a scalar annotation and is not a ``Body`` parameter."""
    from fastapi import params

    has_scalar_annotation = field_annotation_is_scalar(field.field_info.annotation)
    return has_scalar_annotation and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
    """Return True when *field*'s annotation (or a member of its ``Union``) is a sequence."""
    annotation = field.field_info.annotation
    return field_annotation_is_sequence(annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Return True when *field* is annotated as a sequence of scalar items."""
    annotation = field.field_info.annotation
    return field_annotation_is_scalar_sequence(annotation)
def is_bytes_field(field: ModelField) -> bool:
    """Return True when *field* is annotated as ``bytes`` or a ``Union`` containing it.

    Reads ``field.field_info.annotation`` directly, like the other
    ``is_*_field`` helpers in this branch, rather than going through
    ``field.type_``. That way the check does not depend on ``type_`` being
    exposed as a property.
    """
    return is_bytes_or_nonable_bytes_annotation(field.field_info.annotation)
def is_bytes_sequence_field(field: ModelField) -> bool:
    """Return True when *field* is annotated as a sequence of ``bytes`` items.

    ``Union`` annotations qualify when at least one member is such a
    sequence. Reads ``field.field_info.annotation`` directly, like the other
    ``is_*_field`` helpers, instead of relying on ``field.type_``.
    """
    return is_bytes_sequence_annotation(field.field_info.annotation)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of *field_info* re-targeted at *annotation*.

    *annotation* is parsed with ``from_annotation`` on the same ``FieldInfo``
    subclass. Only the parsed ``metadata`` and ``annotation`` are transferred
    onto the clone; every other attribute of *field_info* is kept.
    """
    parsed = type(field_info).from_annotation(annotation)
    clone = copy(field_info)
    clone.metadata = parsed.metadata
    clone.annotation = parsed.annotation
    return clone
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Rebuild *value* as the concrete container type of *field*'s annotation.

    The container comes from ``sequence_annotation_to_type``: for example,
    ``List[bytes]`` yields a ``list`` and ``Set[int]`` yields a ``set``.
    For ``Optional``/``Union`` annotations such as ``Optional[List[bytes]]``,
    the first non-``None`` member decides the container type. An unsupported
    resolved type fails the internal ``assert``.
    """
    annotation = field.field_info.annotation
    origin_type = get_origin(annotation) or annotation
    if origin_type is Union or origin_type is UnionType:
        # Optional sequences report ``Union`` as their origin, which is not a
        # class and would make issubclass() below raise TypeError.
        for union_arg in get_args(annotation):
            if union_arg is type(None):
                continue
            origin_type = get_origin(union_arg) or union_arg
            break
    assert issubclass(origin_type, sequence_types)  # type: ignore[arg-type]
    return sequence_annotation_to_type[origin_type](value)  # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
    """Return a pydantic v2 "missing" error dict located at *loc*.

    The dict is produced with ``ValidationError.from_exception_data`` so it
    has pydantic's own error layout. Its ``input`` is then reset to ``None``.
    """
    validation_error = ValidationError.from_exception_data(
        "Field required", [{"type": "missing", "loc": loc, "input": {}}]
    )
    first_error = validation_error.errors(include_url=False)[0]
    first_error["input"] = None
    return first_error  # type: ignore[return-value]
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Create a pydantic model named *model_name* with one attribute per field.

    Each field is passed to ``create_model`` as
    ``name=(annotation, field_info)``.
    """
    field_definitions = {
        model_field.name: (model_field.field_info.annotation, model_field.field_info)
        for model_field in fields
    }
    body_model: Type[BaseModel] = create_model(model_name, **field_definitions)  # type: ignore[call-overload]
    return body_model
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Wrap each entry of ``model.model_fields`` in a ``ModelField``.

    Order follows ``model_fields``; each dict key becomes the wrapper's ``name``.
    """
    model_fields = model.model_fields
    return [
        ModelField(field_info=info, name=field_name)
        for field_name, info in model_fields.items()
    ]
else: | |
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX | |
from pydantic import AnyUrl as Url # noqa: F401 | |
from pydantic import ( # type: ignore[assignment] | |
BaseConfig as BaseConfig, # noqa: F401 | |
) | |
from pydantic import ValidationError as ValidationError # noqa: F401 | |
from pydantic.class_validators import ( # type: ignore[no-redef] | |
Validator as Validator, # noqa: F401 | |
) | |
from pydantic.error_wrappers import ( # type: ignore[no-redef] | |
ErrorWrapper as ErrorWrapper, # noqa: F401 | |
) | |
from pydantic.errors import MissingError | |
from pydantic.fields import ( # type: ignore[attr-defined] | |
SHAPE_FROZENSET, | |
SHAPE_LIST, | |
SHAPE_SEQUENCE, | |
SHAPE_SET, | |
SHAPE_SINGLETON, | |
SHAPE_TUPLE, | |
SHAPE_TUPLE_ELLIPSIS, | |
) | |
from pydantic.fields import FieldInfo as FieldInfo | |
from pydantic.fields import ( # type: ignore[no-redef,attr-defined] | |
ModelField as ModelField, # noqa: F401 | |
) | |
from pydantic.fields import ( # type: ignore[no-redef,attr-defined] | |
Required as Required, # noqa: F401 | |
) | |
from pydantic.fields import ( # type: ignore[no-redef,attr-defined] | |
Undefined as Undefined, | |
) | |
from pydantic.fields import ( # type: ignore[no-redef, attr-defined] | |
UndefinedType as UndefinedType, # noqa: F401 | |
) | |
from pydantic.schema import ( | |
field_schema, | |
get_flat_models_from_fields, | |
get_model_name_map, | |
model_process_schema, | |
) | |
from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401 | |
get_annotation_from_field_info as get_annotation_from_field_info, | |
) | |
from pydantic.typing import ( # type: ignore[no-redef] | |
evaluate_forwardref as evaluate_forwardref, # noqa: F401 | |
) | |
from pydantic.utils import ( # type: ignore[no-redef] | |
lenient_issubclass as lenient_issubclass, # noqa: F401 | |
) | |
# Under pydantic v1 these names (imported from pydantic / pydantic_core in the
# v2 branch) are plain typing placeholders.
CoreSchema = Any  # type: ignore[assignment,misc]
GetJsonSchemaHandler = Any  # type: ignore[assignment,misc]
JsonSchemaValue = Dict[str, Any]  # type: ignore[misc]
# pydantic v1 ``ModelField.shape`` values that denote a sequence-like field.
sequence_shapes = {
    SHAPE_LIST,
    SHAPE_SET,
    SHAPE_FROZENSET,
    SHAPE_TUPLE,
    SHAPE_SEQUENCE,
    SHAPE_TUPLE_ELLIPSIS,
}
# Concrete container used by serialize_sequence_value() to rebuild a value of
# each shape. Every shape in ``sequence_shapes`` needs an entry here;
# otherwise that lookup raises KeyError. SHAPE_FROZENSET was previously
# missing, which broke e.g. ``FrozenSet[bytes]`` fields.
sequence_shape_to_type = {
    SHAPE_LIST: list,
    SHAPE_SET: set,
    SHAPE_FROZENSET: frozenset,
    SHAPE_TUPLE: tuple,
    SHAPE_SEQUENCE: list,
    SHAPE_TUPLE_ELLIPSIS: list,
}
class GenerateJsonSchema:  # type: ignore[no-redef]
    """pydantic v1 stand-in for pydantic v2's ``GenerateJsonSchema``.

    Only declares the ``ref_template`` attribute annotation; it has no value.
    """

    ref_template: str
class PydanticSchemaGenerationError(Exception):  # type: ignore[no-redef]
    """pydantic v1 placeholder for pydantic v2's ``PydanticSchemaGenerationError``."""
def with_info_plain_validator_function(  # type: ignore[misc]
    function: Callable[..., Any],
    *,
    ref: Union[str, None] = None,
    metadata: Any = None,
    serialization: Any = None,
) -> Any:
    """pydantic v1 placeholder for the pydantic-core schema helper.

    Accepts the same arguments, ignores all of them, and returns a new empty
    dict on every call.
    """
    return dict()
def get_model_definitions(
    *,
    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
    """Build the pydantic v1 JSON Schema ``definitions`` for *flat_models*.

    For each model:
    - the nested definitions reported by ``model_process_schema`` are merged in;
    - the model's own schema is stored under its name from *model_name_map*.

    A schema ``description`` is truncated at its first form-feed character.
    """
    definitions: Dict[str, Dict[str, Any]] = {}
    for flat_model in flat_models:
        schema, nested_definitions, _nested_models = model_process_schema(
            flat_model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        definitions.update(nested_definitions)
        schema_name = model_name_map[flat_model]
        if "description" in schema:
            # Keep only the text before the first form feed.
            schema["description"] = schema["description"].partition("\f")[0]
        definitions[schema_name] = schema
    return definitions
def is_pv1_scalar_field(field: ModelField) -> bool:
    """Return True if the pydantic v1 *field* is a plain scalar.

    A scalar field must meet all of the following:
    - it has the singleton shape;
    - its type is not a ``BaseModel`` or ``dict`` subclass, a sequence, or a dataclass;
    - its ``field_info`` is not ``params.Body``;
    - all of its sub-fields (if any) are scalar too.
    """
    from fastapi import params

    field_info = field.field_info
    if field.shape != SHAPE_SINGLETON:  # type: ignore[attr-defined]
        return False
    field_type = field.type_
    if lenient_issubclass(field_type, BaseModel) or lenient_issubclass(field_type, dict):
        return False
    if field_annotation_is_sequence(field_type) or is_dataclass(field_type):
        return False
    if isinstance(field_info, params.Body):
        return False
    sub_fields = field.sub_fields  # type: ignore[attr-defined]
    return not sub_fields or all(is_pv1_scalar_field(sub) for sub in sub_fields)
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
    """Return True if the pydantic v1 *field* is a sequence of scalars.

    A field with a sequence shape whose type is not a ``BaseModel``
    qualifies when all of its sub-fields (if any) are scalar. Otherwise the
    field's type itself must be a sequence class.
    """
    has_sequence_shape = field.shape in sequence_shapes  # type: ignore[attr-defined]
    if has_sequence_shape and not lenient_issubclass(field.type_, BaseModel):
        sub_fields = field.sub_fields  # type: ignore[attr-defined]
        if sub_fields is None:
            return True
        return all(is_pv1_scalar_field(sub_field) for sub_field in sub_fields)
    return bool(_annotation_is_sequence(field.type_))
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
    """Flatten pydantic v1 errors into a single flat list.

    - ``ErrorWrapper`` items are expanded through a ``ValidationError`` on
      ``RequestErrorModel``.
    - Nested lists are flattened recursively.
    - Any other item is kept as-is.
    """
    flattened: List[Any] = []
    for item in errors:
        if isinstance(item, ErrorWrapper):
            wrapped = ValidationError(  # type: ignore[call-arg]
                errors=[item], model=RequestErrorModel
            )
            flattened.extend(wrapped.errors())
        elif isinstance(item, list):
            flattened.extend(_normalize_errors(item))
        else:
            flattened.append(item)
    return flattened
def _model_rebuild(model: Type[BaseModel]) -> None:
    """Update forward references on a pydantic v1 *model* (``update_forward_refs``)."""
    update_refs = model.update_forward_refs
    update_refs()
def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Dump a pydantic v1 *model* with ``.dict(**kwargs)``.

    ``mode`` is accepted for signature parity with the v2 helper but ignored.
    """
    dump = model.dict
    return dump(**kwargs)
def _get_model_config(model: BaseModel) -> Any:
    """Return a pydantic v1 model's ``__config__``."""
    config = model.__config__  # type: ignore[attr-defined]
    return config
def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Build *field*'s JSON Schema with pydantic v1's ``field_schema``.

    References use ``REF_PREFIX`` and model names from *model_name_map*.
    ``schema_generator``, ``field_mapping`` and
    ``separate_input_output_schemas`` match the v2 signature but are unused.
    """
    schema_result = field_schema(
        field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    # field_schema returns a tuple; only its first element is returned.
    return schema_result[0]  # type: ignore[no-any-return]
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
    """Return pydantic v1's ``get_model_name_map`` for all models reachable from *fields*."""
    reachable_models = get_flat_models_from_fields(fields, known_models=set())
    name_map = get_model_name_map(reachable_models)
    return name_map  # type: ignore[no-any-return]
def get_definitions(
    *,
    fields: List[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    Dict[str, Dict[str, Any]],
]:
    """Return ``({}, definitions)`` for *fields* under pydantic v1.

    No per-field schema mapping is produced in this branch, so the first
    element is always empty. The definitions come from
    ``get_model_definitions`` over every model reachable from *fields*.
    ``schema_generator`` and ``separate_input_output_schemas`` are unused.
    """
    reachable_models = get_flat_models_from_fields(fields, known_models=set())
    definitions = get_model_definitions(
        flat_models=reachable_models, model_name_map=model_name_map
    )
    return {}, definitions
def is_scalar_field(field: ModelField) -> bool:
    """Return True if the pydantic v1 *field* is scalar (see ``is_pv1_scalar_field``)."""
    result = is_pv1_scalar_field(field)
    return result
def is_sequence_field(field: ModelField) -> bool:
    """Return True if the pydantic v1 *field* has a sequence shape or a sequence type."""
    if field.shape in sequence_shapes:  # type: ignore[attr-defined]
        return True
    return _annotation_is_sequence(field.type_)
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Return True if the pydantic v1 *field* is a sequence of scalars."""
    result = is_pv1_scalar_sequence_field(field)
    return result
def is_bytes_field(field: ModelField) -> bool:
    """Return True if the pydantic v1 *field*'s type is ``bytes`` or a subclass."""
    field_type = field.type_
    return lenient_issubclass(field_type, bytes)
def is_bytes_sequence_field(field: ModelField) -> bool:
    """Return True if the pydantic v1 *field* has a sequence shape and a ``bytes`` type."""
    if field.shape not in sequence_shapes:  # type: ignore[attr-defined]
        return False
    return lenient_issubclass(field.type_, bytes)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of *field_info*.

    *annotation* is accepted for signature parity with the v2 helper and ignored.
    """
    duplicate = copy(field_info)
    return duplicate
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Convert *value* to the container type mapped from *field*'s v1 ``shape``.

    The shape is looked up in ``sequence_shape_to_type``; a shape missing
    from that table raises ``KeyError``.
    """
    container_type = sequence_shape_to_type[field.shape]  # type: ignore[attr-defined]
    return container_type(value)  # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
    """Return the pydantic v1 error dict for a ``MissingError`` at *loc*.

    The error is wrapped and passed through a ``ValidationError`` on
    ``RequestErrorModel`` so it comes out in pydantic's dict form.
    """
    wrapper = ErrorWrapper(MissingError(), loc=loc)  # type: ignore[call-arg]
    error_list = ValidationError([wrapper], RequestErrorModel).errors()
    return error_list[0]  # type: ignore[return-value]
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Create an empty pydantic v1 model named *model_name* and attach *fields*.

    The ``ModelField`` objects are stored directly in the class's
    ``__fields__`` mapping, keyed by field name.
    """
    body_model = create_model(model_name)
    for model_field in fields:
        body_model.__fields__[model_field.name] = model_field  # type: ignore[index]
    return body_model
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Return the pydantic v1 ``ModelField`` objects of *model*, in ``__fields__`` order."""
    return [*model.__fields__.values()]  # type: ignore[attr-defined]
def _regenerate_error_with_loc(
    *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
    """Normalize *errors* and prepend *loc_prefix* to each error's ``loc``.

    An error without a ``loc`` key is treated as located at ``()``. Each
    result is a shallow copy; the input dicts are not mutated. Existing
    ``loc`` values are assumed to be tuples, since they are tuple-concatenated.
    """

    def _prefixed(err: Dict[str, Any]) -> Dict[str, Any]:
        return {**err, "loc": loc_prefix + err.get("loc", ())}

    return [_prefixed(err) for err in _normalize_errors(errors)]
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    """Return True if *annotation* is a class subclassing one of ``sequence_types``.

    ``str`` and ``bytes`` (and their subclasses) are excluded.
    """
    is_text_like = lenient_issubclass(annotation, (str, bytes))
    return not is_text_like and lenient_issubclass(annotation, sequence_types)
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    """Return True if *annotation* is a sequence.

    It qualifies if the annotation itself, its generic origin, or (for a
    ``Union``) any of its members is a sequence.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(field_annotation_is_sequence(arg) for arg in get_args(annotation))
    return _annotation_is_sequence(annotation) or _annotation_is_sequence(origin)
def value_is_sequence(value: Any) -> bool:
    """Return True if *value* is an instance of ``sequence_types`` other than ``str``/``bytes``."""
    is_sequence = isinstance(value, sequence_types)  # type: ignore[arg-type]
    return is_sequence and not isinstance(value, (str, bytes))
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    """Return True if *annotation* is complex rather than scalar.

    Complex means a ``BaseModel``/``Mapping``/``UploadFile`` subclass, a
    sequence class, or a dataclass.
    """
    if lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile)):
        return True
    if _annotation_is_sequence(annotation):
        return True
    return is_dataclass(annotation)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    """Return True if *annotation* is complex.

    For a ``Union``, True if any member is complex. Otherwise True if the
    annotation or its origin is complex per ``_annotation_is_complex``, or if
    the origin defines pydantic-core schema hooks.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
    if _annotation_is_complex(annotation) or _annotation_is_complex(origin):
        return True
    return hasattr(origin, "__pydantic_core_schema__") or hasattr(
        origin, "__get_pydantic_core_schema__"
    )
def field_annotation_is_scalar(annotation: Any) -> bool:
    """Return True if *annotation* is not complex; ``Ellipsis`` always counts as scalar."""
    # handle Ellipsis here to make tuple[int, ...] work nicely
    if annotation is Ellipsis:
        return True
    return not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
    """Return True if *annotation* is a sequence whose item annotations are all scalar.

    For a ``Union``, at least one member must be a scalar sequence and every
    other member must be scalar.
    """
    origin = get_origin(annotation)
    if origin is not Union and origin is not UnionType:
        return field_annotation_is_sequence(annotation) and all(
            field_annotation_is_scalar(item) for item in get_args(annotation)
        )
    found_scalar_sequence = False
    for member in get_args(annotation):
        if field_annotation_is_scalar_sequence(member):
            found_scalar_sequence = True
        elif not field_annotation_is_scalar(member):
            return False
    return found_scalar_sequence
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
    """Return True for ``bytes`` (or a subclass), or a ``Union`` with such a member."""
    if lenient_issubclass(annotation, bytes):
        return True
    origin = get_origin(annotation)
    if origin is not Union and origin is not UnionType:
        return False
    return any(lenient_issubclass(arg, bytes) for arg in get_args(annotation))
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
    """Return True for ``UploadFile`` (or a subclass), or a ``Union`` with such a member."""
    if lenient_issubclass(annotation, UploadFile):
        return True
    origin = get_origin(annotation)
    if origin is not Union and origin is not UnionType:
        return False
    return any(lenient_issubclass(arg, UploadFile) for arg in get_args(annotation))
def is_bytes_sequence_annotation(annotation: Any) -> bool:
    """Return True if *annotation* is a sequence whose item annotations are all bytes-like.

    For a ``Union``, True when at least one member is such a sequence.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(is_bytes_sequence_annotation(arg) for arg in get_args(annotation))
    if not field_annotation_is_sequence(annotation):
        return False
    return all(
        is_bytes_or_nonable_bytes_annotation(item) for item in get_args(annotation)
    )
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
    """Return True if *annotation* is a sequence whose item annotations are all ``UploadFile``-like.

    For a ``Union``, True when at least one member is such a sequence.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(
            is_uploadfile_sequence_annotation(arg) for arg in get_args(annotation)
        )
    if not field_annotation_is_sequence(annotation):
        return False
    return all(
        is_uploadfile_or_nonable_uploadfile_annotation(item)
        for item in get_args(annotation)
    )
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Return ``get_model_fields(model)``, memoized per model class.

    ``lru_cache`` (default size) avoids rebuilding the ``ModelField`` wrappers
    on every call. The cache holds strong references to the model classes it
    has seen. The same list object is returned to every caller for a given
    model, so callers must not mutate it.
    """
    return get_model_fields(model)