"""Utility helpers.

Covers response-body rules per status code, path-parameter extraction,
Pydantic ``ModelField`` creation/cloning (Pydantic v1 and v2), operation-ID
generation, recursive dict merging, and ``DefaultPlaceholder`` resolution.
"""
import re
import warnings
from dataclasses import is_dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    MutableMapping,
    Optional,
    Set,
    Type,
    Union,
    cast,
)
from weakref import WeakKeyDictionary

import fastapi
from fastapi._compat import (
    PYDANTIC_V2,
    BaseConfig,
    ModelField,
    PydanticSchemaGenerationError,
    Undefined,
    UndefinedType,
    Validator,
    lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing_extensions import Literal

if TYPE_CHECKING:  # pragma: nocover
    # Imported only for the type annotation of `generate_unique_id`.
    from .routing import APIRoute

# Cache for `create_cloned_field`
# Maps an original model class to its cloned subclass. Weak keys let an entry
# disappear once the original model class is no longer referenced elsewhere.
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
    WeakKeyDictionary()
)


def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
    """Return whether a response with ``status_code`` may include a body.

    ``None`` and the OpenAPI wildcard response keys (``"default"``,
    ``"1XX"`` .. ``"5XX"``) are always allowed. Any other value is passed
    through ``int()``; codes below 200 and the codes 204, 205 and 304 are
    reported as not allowing a body.

    Raises:
        ValueError: if ``status_code`` is a string that ``int()`` cannot parse.
    """
    if status_code is None:
        return True
    # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
    if status_code in {
        "default",
        "1XX",
        "2XX",
        "3XX",
        "4XX",
        "5XX",
    }:
        return True
    current_status_code = int(status_code)
    # 1xx informational codes plus 204/205/304 are the "no body" set here.
    return not (current_status_code < 200 or current_status_code in {204, 205, 304})


def get_path_param_names(path: str) -> Set[str]:
    """Return the set of names written inside ``{...}`` in ``path``.

    Example: ``"/items/{item_id}"`` -> ``{"item_id"}``.

    NOTE: the text between the braces is returned verbatim, so a placeholder
    such as ``"{file_path:path}"`` yields ``"file_path:path"``.
    """
    # Non-greedy match so "{a}/{b}" gives two names rather than one long one.
    return set(re.findall("{(.*?)}", path))


def create_model_field(
    name: str,
    type_: Any,
    class_validators: Optional[Dict[str, Validator]] = None,
    default: Optional[Any] = Undefined,
    required: Union[bool, UndefinedType] = Undefined,
    model_config: Type[BaseConfig] = BaseConfig,
    field_info: Optional[FieldInfo] = None,
    alias: Optional[str] = None,
    mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
    """Build a ``ModelField`` for the installed Pydantic major version.

    Under Pydantic v2 only ``name``, ``field_info`` and ``mode`` are passed to
    ``ModelField``. When no ``field_info`` is given, one is built from
    ``type_``, ``default`` and ``alias``. ``class_validators``, ``required``
    and ``model_config`` are not forwarded.

    Under Pydantic v1 a bare ``FieldInfo()`` is the fallback ``field_info``.
    ``type_``, ``class_validators``, ``default``, ``required``,
    ``model_config`` and ``alias`` are all forwarded.

    Raises:
        fastapi.exceptions.FastAPIError: if ``ModelField`` construction raises
            ``RuntimeError`` or ``PydanticSchemaGenerationError`` (typically an
            unsupported ``type_``). The original exception is suppressed
            (``from None``).
    """
    class_validators = class_validators or {}
    if PYDANTIC_V2:
        field_info = field_info or FieldInfo(
            annotation=type_, default=default, alias=alias
        )
    else:
        field_info = field_info or FieldInfo()
    kwargs = {"name": name, "field_info": field_info}
    if PYDANTIC_V2:
        kwargs.update({"mode": mode})
    else:
        kwargs.update(
            {
                "type_": type_,
                "class_validators": class_validators,
                "default": default,
                "required": required,
                "model_config": model_config,
                "alias": alias,
            }
        )
    try:
        return ModelField(**kwargs)  # type: ignore[arg-type]
    except (RuntimeError, PydanticSchemaGenerationError):
        raise fastapi.exceptions.FastAPIError(
            "Invalid args for response field! Hint: "
            f"check that {type_} is a valid Pydantic field type. "
            "If you are using a return type annotation that is not a valid Pydantic "
            "field (e.g. Union[Response, dict, None]) you can disable generating the "
            "response model from the type annotation with the path operation decorator "
            "parameter response_model=None. Read more: "
            "https://fastapi.tiangolo.com/tutorial/response-model/"
        ) from None


def create_cloned_field(
    field: ModelField,
    *,
    cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
    """Return a copy of ``field`` whose model type is a fresh subclass.

    Under Pydantic v2 the field is returned unchanged. Under Pydantic v1:

    * If the field's type is a dataclass with ``__pydantic_model__``, that
      underlying model is used.
    * If the type is a ``BaseModel`` subclass, it is replaced by a subclass
      created with ``create_model(..., __base__=original)``. That subclass's
      fields are themselves cloned recursively.
    * A new ``ModelField`` is built, the original's alias, default,
      validator, config and shape attributes are copied onto it,
      ``sub_fields`` and ``key_field`` are cloned recursively, and
      ``populate_validators()`` is called last.

    Args:
        field: the field to clone.
        cloned_types: mapping of original model -> cloned model. Defaults to
            the module-level weak-keyed ``_CLONED_TYPES_CACHE``.

    NOTE(review): this reads Pydantic v1 ``ModelField`` internals (``shape``,
    ``sub_fields``, ``pre_validators`` ...). The copy order matters and should
    be kept when editing.
    """
    if PYDANTIC_V2:
        return field
    # cloned_types caches already cloned types to support recursive models and improve
    # performance by avoiding unnecessary cloning
    if cloned_types is None:
        cloned_types = _CLONED_TYPES_CACHE

    original_type = field.type_
    if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
        original_type = original_type.__pydantic_model__
    use_type = original_type
    if lenient_issubclass(original_type, BaseModel):
        original_type = cast(Type[BaseModel], original_type)
        use_type = cloned_types.get(original_type)
        if use_type is None:
            use_type = create_model(original_type.__name__, __base__=original_type)
            # Register the clone BEFORE cloning its fields: a model that refers
            # back to itself then finds this entry instead of recursing forever.
            cloned_types[original_type] = use_type
            for f in original_type.__fields__.values():
                use_type.__fields__[f.name] = create_cloned_field(
                    f, cloned_types=cloned_types
                )
    new_field = create_model_field(name=field.name, type_=use_type)
    # Copy per-field settings from the original onto the freshly built field.
    new_field.has_alias = field.has_alias  # type: ignore[attr-defined]
    new_field.alias = field.alias  # type: ignore[misc]
    new_field.class_validators = field.class_validators  # type: ignore[attr-defined]
    new_field.default = field.default  # type: ignore[misc]
    new_field.required = field.required  # type: ignore[misc]
    new_field.model_config = field.model_config  # type: ignore[attr-defined]
    new_field.field_info = field.field_info
    new_field.allow_none = field.allow_none  # type: ignore[attr-defined]
    new_field.validate_always = field.validate_always  # type: ignore[attr-defined]
    # Container element fields (and mapping key fields) are cloned too, so
    # nested models inside them get cloned types as well.
    if field.sub_fields:  # type: ignore[attr-defined]
        new_field.sub_fields = [  # type: ignore[attr-defined]
            create_cloned_field(sub_field, cloned_types=cloned_types)
            for sub_field in field.sub_fields  # type: ignore[attr-defined]
        ]
    if field.key_field:  # type: ignore[attr-defined]
        new_field.key_field = create_cloned_field(  # type: ignore[attr-defined]
            field.key_field,  # type: ignore[attr-defined]
            cloned_types=cloned_types,
        )
    new_field.validators = field.validators  # type: ignore[attr-defined]
    new_field.pre_validators = field.pre_validators  # type: ignore[attr-defined]
    new_field.post_validators = field.post_validators  # type: ignore[attr-defined]
    new_field.parse_json = field.parse_json  # type: ignore[attr-defined]
    new_field.shape = field.shape  # type: ignore[attr-defined]
    # Called only after every attribute above is copied; presumably it
    # rebuilds validators from that state (Pydantic v1 API) — keep it last.
    new_field.populate_validators()  # type: ignore[attr-defined]
    return new_field


def generate_operation_id_for_path(
    *, name: str, path: str, method: str
) -> str:  # pragma: nocover
    """Deprecated: build an operation ID from ``name``, ``path`` and ``method``.

    Emits a ``DeprecationWarning`` and then returns
    ``"<name><path>"`` with every non-word character replaced by ``_``,
    followed by ``"_<method lowercased>"``.
    Example: ``name="read", path="/items/{id}", method="GET"`` ->
    ``"read_items__id__get"``.
    """
    warnings.warn(
        "fastapi.utils.generate_operation_id_for_path() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    operation_id = f"{name}{path}"
    operation_id = re.sub(r"\W", "_", operation_id)
    operation_id = f"{operation_id}_{method.lower()}"
    return operation_id


def generate_unique_id(route: "APIRoute") -> str:
    """Build an operation ID for ``route``.

    Same scheme as ``generate_operation_id_for_path``: ``route.name`` plus
    ``route.path_format`` with non-word characters replaced by ``_``, then
    ``"_"`` plus the lowercased first entry of ``list(route.methods)``.

    Raises AssertionError if ``route.methods`` is empty (the ``assert`` is
    stripped under ``python -O``, in which case indexing raises IndexError).

    NOTE(review): if ``route.methods`` is a set (its type is not visible
    here), the method chosen for a multi-method route depends on set iteration
    order — confirm.
    """
    operation_id = f"{route.name}{route.path_format}"
    operation_id = re.sub(r"\W", "_", operation_id)
    assert route.methods
    operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
    return operation_id


def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
    """Recursively merge ``update_dict`` into ``main_dict`` in place.

    For each key in ``update_dict``:

    * dict in both -> merged recursively;
    * list in both -> replaced by a new list, ``main_dict``'s items first;
    * otherwise -> ``main_dict[key]`` is set to the update value. Nested
      containers are stored by reference, not copied.

    Returns ``None``; only ``main_dict`` is mutated.
    """
    for key, value in update_dict.items():
        if (
            key in main_dict
            and isinstance(main_dict[key], dict)
            and isinstance(value, dict)
        ):
            deep_dict_update(main_dict[key], value)
        elif (
            key in main_dict
            and isinstance(main_dict[key], list)
            and isinstance(update_dict[key], list)
        ):
            # `+` builds a new list, so the original list object is not mutated.
            main_dict[key] = main_dict[key] + update_dict[key]
        else:
            main_dict[key] = value


def get_value_or_default(
    first_item: Union[DefaultPlaceholder, DefaultType],
    *extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
    """
    Pass items or `DefaultPlaceholder`s by descending priority.

    Return the first argument that is _not_ a `DefaultPlaceholder`.

    If every argument is a `DefaultPlaceholder`, return `first_item` (the
    highest-priority placeholder).
    """
    items = (first_item,) + extra_items
    for item in items:
        if not isinstance(item, DefaultPlaceholder):
            return item
    return first_item