JAMESPARK3's picture
Upload folder using huggingface_hub
1380717 verified
raw
history blame
7.95 kB
import re
import warnings
from dataclasses import is_dataclass
from typing import (
TYPE_CHECKING,
Any,
Dict,
MutableMapping,
Optional,
Set,
Type,
Union,
cast,
)
from weakref import WeakKeyDictionary
import fastapi
from fastapi._compat import (
PYDANTIC_V2,
BaseConfig,
ModelField,
PydanticSchemaGenerationError,
Undefined,
UndefinedType,
Validator,
lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing_extensions import Literal
if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute
# Module-wide cache used by `create_cloned_field` when the caller does not pass
# its own `cloned_types` mapping. It maps an original model class to the clone
# created for it. Keys are held weakly (WeakKeyDictionary).
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
    WeakKeyDictionary()
)
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
    """
    Tell whether a response with the given status code may carry a body.

    `None`, `"default"` and the range wildcards (`"1XX"` ... `"5XX"`) always
    allow a body. Any other value is converted with `int()`; codes below 200
    and the codes 204, 205 and 304 do not allow a body.
    """
    if status_code is None:
        return True
    # Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
    wildcard_codes = {"default", "1XX", "2XX", "3XX", "4XX", "5XX"}
    if status_code in wildcard_codes:
        return True
    numeric_code = int(status_code)
    if numeric_code < 200:
        return False
    return numeric_code not in {204, 205, 304}
def get_path_param_names(path: str) -> Set[str]:
    """
    Collect the distinct parameter names written as `{name}` in a path.

    Each name is the shortest run of text between a `{` and the next `}`.
    """
    return {match.group(1) for match in re.finditer("{(.*?)}", path)}
def create_model_field(
    name: str,
    type_: Any,
    class_validators: Optional[Dict[str, Validator]] = None,
    default: Optional[Any] = Undefined,
    required: Union[bool, UndefinedType] = Undefined,
    model_config: Type[BaseConfig] = BaseConfig,
    field_info: Optional[FieldInfo] = None,
    alias: Optional[str] = None,
    mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
    """
    Build a `ModelField`, passing the arguments the active Pydantic major
    version expects.

    When `PYDANTIC_V2` is true, only `name`, `field_info` and `mode` are
    forwarded; `type_`, `default` and `alias` are used solely to build a
    `FieldInfo` when none is supplied. Otherwise `mode` is not forwarded and
    all the remaining arguments are passed to `ModelField` directly.

    Raises `fastapi.exceptions.FastAPIError` when constructing the field
    raises `RuntimeError` or `PydanticSchemaGenerationError`.
    """
    if not class_validators:
        class_validators = {}
    field_kwargs: Dict[str, Any]
    if PYDANTIC_V2:
        if not field_info:
            field_info = FieldInfo(annotation=type_, default=default, alias=alias)
        field_kwargs = {"name": name, "field_info": field_info, "mode": mode}
    else:
        if not field_info:
            field_info = FieldInfo()
        field_kwargs = {
            "name": name,
            "field_info": field_info,
            "type_": type_,
            "class_validators": class_validators,
            "default": default,
            "required": required,
            "model_config": model_config,
            "alias": alias,
        }
    try:
        return ModelField(**field_kwargs)  # type: ignore[arg-type]
    except (RuntimeError, PydanticSchemaGenerationError):
        # `from None` hides the underlying error so only the hint below is shown.
        raise fastapi.exceptions.FastAPIError(
            "Invalid args for response field! Hint: "
            f"check that {type_} is a valid Pydantic field type. "
            "If you are using a return type annotation that is not a valid Pydantic "
            "field (e.g. Union[Response, dict, None]) you can disable generating the "
            "response model from the type annotation with the path operation decorator "
            "parameter response_model=None. Read more: "
            "https://fastapi.tiangolo.com/tutorial/response-model/"
        ) from None
def create_cloned_field(
    field: ModelField,
    *,
    cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
    """
    Return a copy of `field` whose `BaseModel` types are replaced by clones.

    When `PYDANTIC_V2` is true the field is returned unchanged. Otherwise a
    new field is created and the original's attributes are copied onto it,
    with sub-fields and the key field cloned recursively.

    `cloned_types` maps already-cloned model classes to their clones; it
    supports recursive models and avoids cloning the same class twice. When
    omitted, the module-level `_CLONED_TYPES_CACHE` is used.
    """
    if PYDANTIC_V2:
        return field
    types_cache = _CLONED_TYPES_CACHE if cloned_types is None else cloned_types
    source_type = field.type_
    if is_dataclass(source_type) and hasattr(source_type, "__pydantic_model__"):
        source_type = source_type.__pydantic_model__
    target_type = source_type
    if lenient_issubclass(source_type, BaseModel):
        source_type = cast(Type[BaseModel], source_type)
        target_type = types_cache.get(source_type)
        if target_type is None:
            target_type = create_model(source_type.__name__, __base__=source_type)
            # Register the clone before recursing so that a model referring
            # back to itself finds the cached clone instead of cloning again.
            types_cache[source_type] = target_type
            for model_field in source_type.__fields__.values():
                target_type.__fields__[model_field.name] = create_cloned_field(
                    model_field, cloned_types=types_cache
                )
    clone = create_model_field(name=field.name, type_=target_type)
    # Plain attribute copies, in the original order.
    for attr_name in (
        "has_alias",
        "alias",
        "class_validators",
        "default",
        "required",
        "model_config",
        "field_info",
        "allow_none",
        "validate_always",
    ):
        setattr(clone, attr_name, getattr(field, attr_name))
    if field.sub_fields:  # type: ignore[attr-defined]
        clone.sub_fields = [  # type: ignore[attr-defined]
            create_cloned_field(sub_field, cloned_types=types_cache)
            for sub_field in field.sub_fields  # type: ignore[attr-defined]
        ]
    if field.key_field:  # type: ignore[attr-defined]
        clone.key_field = create_cloned_field(  # type: ignore[attr-defined]
            field.key_field,  # type: ignore[attr-defined]
            cloned_types=types_cache,
        )
    for attr_name in (
        "validators",
        "pre_validators",
        "post_validators",
        "parse_json",
        "shape",
    ):
        setattr(clone, attr_name, getattr(field, attr_name))
    clone.populate_validators()  # type: ignore[attr-defined]
    return clone
def generate_operation_id_for_path(
    *, name: str, path: str, method: str
) -> str:  # pragma: nocover
    """
    Deprecated: build an operation ID from a name, a path and an HTTP method.

    Emits a `DeprecationWarning`, then returns `name` and `path` joined
    together with every non-word character replaced by `_`, followed by `_`
    and the lower-cased `method`.
    """
    warnings.warn(
        "fastapi.utils.generate_operation_id_for_path() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    sanitized = re.sub(r"\W", "_", f"{name}{path}")
    return f"{sanitized}_{method.lower()}"
def generate_unique_id(route: "APIRoute") -> str:
    """
    Build an operation ID for a route.

    The ID is the route's `name` followed by its `path_format`, with every
    non-word character replaced by `_`, then `_` and the lower-cased first
    method obtained by iterating `route.methods`. Asserts that
    `route.methods` is truthy.
    """
    sanitized = re.sub(r"\W", "_", f"{route.name}{route.path_format}")
    assert route.methods
    first_method = next(iter(route.methods))
    return f"{sanitized}_{first_method.lower()}"
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
    """
    Merge `update_dict` into `main_dict` in place.

    For a key present in both: two dicts are merged recursively, two lists
    are replaced by a new list holding the existing items followed by the
    incoming ones, and anything else is overwritten by the incoming value.
    Keys missing from `main_dict` are added as-is.
    """
    for key, incoming in update_dict.items():
        if key not in main_dict:
            main_dict[key] = incoming
            continue
        current = main_dict[key]
        if isinstance(current, dict) and isinstance(incoming, dict):
            deep_dict_update(current, incoming)
        elif isinstance(current, list) and isinstance(incoming, list):
            # Concatenate into a new list; the existing list object is not mutated.
            main_dict[key] = current + incoming
        else:
            main_dict[key] = incoming
def get_value_or_default(
    first_item: Union[DefaultPlaceholder, DefaultType],
    *extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
    """
    Pick the highest-priority item that is an actual value.

    Items are given in descending priority. The first one that is not a
    `DefaultPlaceholder` is returned; if every item is a
    `DefaultPlaceholder`, `first_item` is returned.
    """
    candidates = (first_item, *extra_items)
    return next(
        (
            candidate
            for candidate in candidates
            if not isinstance(candidate, DefaultPlaceholder)
        ),
        first_item,
    )