Spaces:
Running
Running
import re | |
import warnings | |
from dataclasses import is_dataclass | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Dict, | |
MutableMapping, | |
Optional, | |
Set, | |
Type, | |
Union, | |
cast, | |
) | |
from weakref import WeakKeyDictionary | |
import fastapi | |
from fastapi._compat import ( | |
PYDANTIC_V2, | |
BaseConfig, | |
ModelField, | |
PydanticSchemaGenerationError, | |
Undefined, | |
UndefinedType, | |
Validator, | |
lenient_issubclass, | |
) | |
from fastapi.datastructures import DefaultPlaceholder, DefaultType | |
from pydantic import BaseModel, create_model | |
from pydantic.fields import FieldInfo | |
from typing_extensions import Literal | |
if TYPE_CHECKING: # pragma: nocover | |
from .routing import APIRoute | |
# Cache for `create_cloned_field`: maps an original Pydantic model class to the
# clone that was created for it. It is the mapping `create_cloned_field` falls
# back to when the caller does not pass its own `cloned_types`.
# A `WeakKeyDictionary` references its keys weakly, so an entry here does not
# by itself keep the original model class alive.
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
    WeakKeyDictionary()
)
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
    """Report whether a response with ``status_code`` may include a body.

    Args:
        status_code: A numeric status code (``int`` or numeric ``str``), one
            of the OpenAPI range patterns (``"default"``, ``"1XX"`` ...
            ``"5XX"``), or ``None``.

    Returns:
        ``True`` for ``None`` and for the OpenAPI range patterns. For numeric
        codes, ``False`` when the code is below 200 or is one of 204, 205 or
        304, and ``True`` otherwise.

    Raises:
        ValueError: If ``status_code`` is a string that is neither a known
            pattern nor convertible with ``int()``.
    """
    if status_code is None:
        return True
    # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
    openapi_range_patterns = {"default", "1XX", "2XX", "3XX", "4XX", "5XX"}
    if status_code in openapi_range_patterns:
        return True
    numeric_code = int(status_code)
    if numeric_code < 200:
        return False
    return numeric_code not in {204, 205, 304}
def get_path_param_names(path: str) -> Set[str]:
    """Collect the distinct ``{...}`` placeholder contents found in ``path``.

    The match is non-greedy, so each placeholder ends at the first ``}`` that
    follows its opening ``{``. The text between the braces is returned
    verbatim (for example ``"file_path:path"`` for ``{file_path:path}``).

    Args:
        path: The route path template to scan.

    Returns:
        A set with one entry per distinct placeholder content; empty when the
        path has no placeholders.
    """
    return {match.group(1) for match in re.finditer("{(.*?)}", path)}
def create_model_field(
    name: str,
    type_: Any,
    class_validators: Optional[Dict[str, Validator]] = None,
    default: Optional[Any] = Undefined,
    required: Union[bool, UndefinedType] = Undefined,
    model_config: Type[BaseConfig] = BaseConfig,
    field_info: Optional[FieldInfo] = None,
    alias: Optional[str] = None,
    mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
    """Build a ``ModelField`` with the arguments the active Pydantic version takes.

    When ``PYDANTIC_V2`` is true, only ``name``, ``field_info`` and ``mode``
    are passed to ``ModelField``; a missing ``field_info`` is created from
    ``type_``, ``default`` and ``alias``. Otherwise ``mode`` is not passed, a
    missing ``field_info`` is an empty ``FieldInfo()``, and ``type_``,
    ``class_validators``, ``default``, ``required``, ``model_config`` and
    ``alias`` are forwarded as keyword arguments.

    Raises:
        fastapi.exceptions.FastAPIError: If ``ModelField`` raises
            ``RuntimeError`` or ``PydanticSchemaGenerationError``. The
            original exception context is suppressed.
    """
    if PYDANTIC_V2:
        resolved_info = field_info or FieldInfo(
            annotation=type_, default=default, alias=alias
        )
        field_kwargs = {"name": name, "field_info": resolved_info, "mode": mode}
    else:
        resolved_info = field_info or FieldInfo()
        field_kwargs = {
            "name": name,
            "field_info": resolved_info,
            "type_": type_,
            "class_validators": class_validators or {},
            "default": default,
            "required": required,
            "model_config": model_config,
            "alias": alias,
        }
    try:
        return ModelField(**field_kwargs)  # type: ignore[arg-type]
    except (RuntimeError, PydanticSchemaGenerationError):
        raise fastapi.exceptions.FastAPIError(
            "Invalid args for response field! Hint: "
            f"check that {type_} is a valid Pydantic field type. "
            "If you are using a return type annotation that is not a valid Pydantic "
            "field (e.g. Union[Response, dict, None]) you can disable generating the "
            "response model from the type annotation with the path operation decorator "
            "parameter response_model=None. Read more: "
            "https://fastapi.tiangolo.com/tutorial/response-model/"
        ) from None
def create_cloned_field(
    field: ModelField,
    *,
    cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
    """Return a copy of ``field`` whose ``BaseModel`` type is replaced by a clone.

    When ``PYDANTIC_V2`` is true the field is returned unchanged. Otherwise a
    new field is built with ``create_model_field`` and the attributes of
    ``field`` are copied onto it; sub-fields and the key field are cloned
    recursively.

    Args:
        field: The field to clone.
        cloned_types: Mapping from original model classes to their clones.
            Defaults to the module-level ``_CLONED_TYPES_CACHE``. It is
            consulted before cloning and updated with every new clone.

    Returns:
        ``field`` itself under Pydantic v2, otherwise the newly built field.
    """
    if PYDANTIC_V2:
        return field
    # The mapping of already cloned types supports recursive models and avoids
    # cloning the same model more than once.
    type_cache = _CLONED_TYPES_CACHE if cloned_types is None else cloned_types

    source_type = field.type_
    if is_dataclass(source_type) and hasattr(source_type, "__pydantic_model__"):
        source_type = source_type.__pydantic_model__

    target_type = source_type
    if lenient_issubclass(source_type, BaseModel):
        source_type = cast(Type[BaseModel], source_type)
        target_type = type_cache.get(source_type)
        if target_type is None:
            target_type = create_model(source_type.__name__, __base__=source_type)
            # Register the clone before descending into its fields so that a
            # model referring to itself finds the cache entry.
            type_cache[source_type] = target_type
            for model_field in source_type.__fields__.values():
                target_type.__fields__[model_field.name] = create_cloned_field(
                    model_field, cloned_types=type_cache
                )

    clone = create_model_field(name=field.name, type_=target_type)

    # Attributes copied before the nested fields are cloned (order preserved).
    for attr_name in (
        "has_alias",
        "alias",
        "class_validators",
        "default",
        "required",
        "model_config",
        "field_info",
        "allow_none",
        "validate_always",
    ):
        setattr(clone, attr_name, getattr(field, attr_name))

    if field.sub_fields:  # type: ignore[attr-defined]
        clone.sub_fields = [  # type: ignore[attr-defined]
            create_cloned_field(nested, cloned_types=type_cache)
            for nested in field.sub_fields  # type: ignore[attr-defined]
        ]
    if field.key_field:  # type: ignore[attr-defined]
        clone.key_field = create_cloned_field(  # type: ignore[attr-defined]
            field.key_field,  # type: ignore[attr-defined]
            cloned_types=type_cache,
        )

    # Attributes copied after the nested fields, immediately before
    # populate_validators() is called on the clone.
    for attr_name in (
        "validators",
        "pre_validators",
        "post_validators",
        "parse_json",
        "shape",
    ):
        setattr(clone, attr_name, getattr(field, attr_name))

    clone.populate_validators()  # type: ignore[attr-defined]
    return clone
def generate_operation_id_for_path(
    *, name: str, path: str, method: str
) -> str:  # pragma: nocover
    """Build an operation ID from a route name, path and HTTP method.

    Emits a ``DeprecationWarning`` on every call.

    Args:
        name: Route name; used as the prefix of the ID.
        path: Route path; appended directly to ``name``.
        method: HTTP method; lower-cased and appended after an underscore.

    Returns:
        ``name + path`` with every non-word character replaced by ``_``,
        followed by ``_`` and the lower-cased method.
    """
    warnings.warn(
        "fastapi.utils.generate_operation_id_for_path() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    sanitized_prefix = re.sub(r"\W", "_", f"{name}{path}")
    return f"{sanitized_prefix}_{method.lower()}"
def generate_unique_id(route: "APIRoute") -> str:
    """Build a unique operation ID for ``route``.

    The ID is ``route.name + route.path_format`` with every non-word
    character replaced by ``_``, followed by ``_`` and one lower-cased HTTP
    method of the route.

    When the route has several methods, the alphabetically smallest one is
    used. Taking "the first element" of an unordered collection of strings
    would depend on string hashing, which Python randomizes per process, so
    the generated ID could differ between runs of the same application.
    Routes with a single method produce exactly the same ID as before.

    Args:
        route: Object exposing ``name``, ``path_format`` and a non-empty
            ``methods`` collection of strings.

    Returns:
        The generated operation ID.

    Raises:
        AssertionError: If ``route.methods`` is empty or ``None``.
    """
    sanitized_prefix = re.sub(r"\W", "_", f"{route.name}{route.path_format}")
    assert route.methods
    # min() gives a stable choice regardless of the collection's iteration order.
    chosen_method = min(route.methods)
    return f"{sanitized_prefix}_{chosen_method.lower()}"
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
    """Merge ``update_dict`` into ``main_dict`` in place, recursing into dicts.

    For each key of ``update_dict``:

    * if both sides hold a dict, the two dicts are merged recursively;
    * if both sides hold a list, the entry becomes a new list made of the
      existing items followed by the incoming ones;
    * otherwise (missing key or differing kinds) the incoming value replaces
      whatever ``main_dict`` had.

    Args:
        main_dict: Dictionary that is modified.
        update_dict: Dictionary providing the new values; it is only read.

    Returns:
        ``None``; the result is visible through ``main_dict``.
    """
    for key, incoming in update_dict.items():
        if key not in main_dict:
            main_dict[key] = incoming
            continue
        existing = main_dict[key]
        if isinstance(existing, dict) and isinstance(incoming, dict):
            deep_dict_update(existing, incoming)
        elif isinstance(existing, list) and isinstance(incoming, list):
            # Concatenation builds a new list; the old list is not mutated.
            main_dict[key] = existing + incoming
        else:
            main_dict[key] = incoming
def get_value_or_default(
    first_item: Union[DefaultPlaceholder, DefaultType],
    *extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
    """
    Pick the highest-priority item that is not a `DefaultPlaceholder`.

    Items are given in descending priority. The first one that is _not_ a
    `DefaultPlaceholder` instance is returned. If every item is a
    `DefaultPlaceholder`, the first item is returned.
    """
    candidates = (first_item, *extra_items)
    explicit_values = (
        candidate
        for candidate in candidates
        if not isinstance(candidate, DefaultPlaceholder)
    )
    return next(explicit_values, first_item)