JAMESPARK3's picture
Upload folder using huggingface_hub
1380717 verified
raw
history blame
23.9 kB
from collections import deque
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
Any,
Callable,
Deque,
Dict,
FrozenSet,
List,
Mapping,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi.exceptions import RequestErrorModel
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as P_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin
# Bind the imported version string to a module-level name so that mypy treats
# it as an explicit re-export of this module.
PYDANTIC_VERSION = P_VERSION
# True when the installed Pydantic version string starts with "2."; the
# definitions below are selected by branching on this flag.
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
# Maps each supported sequence annotation (typing alias or runtime class) to
# the concrete container type used to build a value for it.
sequence_annotation_to_type = {
    Sequence: list,
    List: list,
    list: list,
    Tuple: tuple,
    tuple: tuple,
    Set: set,
    set: set,
    FrozenSet: frozenset,
    frozenset: frozenset,
    Deque: deque,
    deque: deque,
}
# The annotation keys above, in insertion order, as a tuple suitable for
# isinstance()/issubclass() checks.
sequence_types = tuple(sequence_annotation_to_type)
if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
)
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
try:
from pydantic_core.core_schema import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
except ImportError: # pragma: no cover
from pydantic_core.core_schema import (
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
)
# Expose pydantic_core's undefined sentinel (and its type) under the names
# that the non-v2 branch imports from pydantic.fields.
Required = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
# Same public name as the helper imported from pydantic.typing on the non-v2
# branch, bound here to the v2 internal equivalent.
evaluate_forwardref = eval_type_lenient
# On this branch no Validator class is imported; the name is just an alias of Any.
Validator = Any
class BaseConfig:
    """Empty placeholder so the name ``BaseConfig`` exists on the Pydantic v2 branch."""
class ErrorWrapper(Exception):
    """Empty exception type so the name ``ErrorWrapper`` exists on the Pydantic v2 branch."""
@dataclass
class ModelField:
    """Field wrapper for Pydantic v2 built around a ``FieldInfo`` and a ``TypeAdapter``.

    The non-v2 branch of this module imports ``ModelField`` from
    ``pydantic.fields``; this class provides the same name, with validation and
    serialization delegated to a ``TypeAdapter`` created in ``__post_init__``.
    """

    field_info: FieldInfo
    name: str
    mode: Literal["validation", "serialization"] = "validation"

    def __post_init__(self) -> None:
        # Build the adapter once; validate(), serialize() and schema generation
        # all reuse it.
        annotated_type = Annotated[self.field_info.annotation, self.field_info]
        self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotated_type)

    @property
    def alias(self) -> str:
        """Return the field's alias, falling back to its name when none is set."""
        configured_alias = self.field_info.alias
        if configured_alias is None:
            return self.name
        return configured_alias

    @property
    def required(self) -> bool:
        """Return whether the underlying ``FieldInfo`` reports the field as required."""
        return self.field_info.is_required()

    @property
    def default(self) -> Any:
        """Return the same value as :meth:`get_default`."""
        return self.get_default()

    @property
    def type_(self) -> Any:
        """Return the annotation stored on the underlying ``FieldInfo``."""
        return self.field_info.annotation

    def get_default(self) -> Any:
        """Return ``Undefined`` for required fields, otherwise the field's default.

        Default factories are called to produce the value.
        """
        info = self.field_info
        if not info.is_required():
            return info.get_default(call_default_factory=True)
        return Undefined

    def validate(
        self,
        value: Any,
        values: Dict[str, Any] = {},  # noqa: B006
        *,
        loc: Tuple[Union[int, str], ...] = (),
    ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
        """Validate ``value`` and return a ``(validated_value, errors)`` pair.

        On success the pair is ``(validated_value, None)``. On a
        ``ValidationError`` it is ``(None, errors)``, where every error's
        ``loc`` has been prefixed with ``loc``. ``values`` is accepted but not
        used by this implementation.
        """
        try:
            validated = self._type_adapter.validate_python(
                value, from_attributes=True
            )
        except ValidationError as exc:
            raw_errors = exc.errors(include_url=False)
            return None, _regenerate_error_with_loc(
                errors=raw_errors, loc_prefix=loc
            )
        return validated, None

    def serialize(
        self,
        value: Any,
        *,
        mode: Literal["json", "python"] = "json",
        include: Union[IncEx, None] = None,
        exclude: Union[IncEx, None] = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """Dump ``value`` through the field's ``TypeAdapter`` with the given options."""
        # What calls this code passes a value that already called
        # self._type_adapter.validate_python(value)
        dump_options = {
            "mode": mode,
            "include": include,
            "exclude": exclude,
            "by_alias": by_alias,
            "exclude_unset": exclude_unset,
            "exclude_defaults": exclude_defaults,
            "exclude_none": exclude_none,
        }
        return self._type_adapter.dump_python(value, **dump_options)

    def __hash__(self) -> int:
        # Each ModelField is unique for our purposes, to allow making a dict from
        # ModelField to its JSON Schema.
        return id(self)
def get_annotation_from_field_info(
    annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
    """Return ``annotation`` unchanged.

    ``field_info`` and ``field_name`` are accepted so the signature matches the
    function of the same name imported on the non-v2 branch; they are not used.
    """
    return annotation
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
return errors # type: ignore[return-value]
def _model_rebuild(model: Type[BaseModel]) -> None:
    """Call ``model_rebuild()`` on the given model class."""
    model.model_rebuild()
def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Dump ``model`` via ``model_dump``, forwarding ``mode`` and any extra options."""
    dumped = model.model_dump(mode=mode, **kwargs)
    return dumped
def _get_model_config(model: BaseModel) -> Any:
    """Return the ``model_config`` attribute of ``model``."""
    return model.model_config
def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Look up the JSON Schema already generated for ``field``.

    The schema is read from ``field_mapping`` under ``(field, mode)``, where
    the mode is the field's own mode, or ``"validation"`` when
    ``separate_input_output_schemas`` is false. A schema without a ``"$ref"``
    key gets its ``"title"`` set in place before being returned.
    ``schema_generator`` and ``model_name_map`` are not used here.
    """
    # This expects that GenerateJsonSchema was already used to generate the definitions
    lookup_mode = field.mode if separate_input_output_schemas else "validation"
    json_schema = field_mapping[(field, lookup_mode)]
    if "$ref" in json_schema:
        return json_schema
    # TODO remove when deprecating Pydantic v1
    # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
    explicit_title = field.field_info.title
    json_schema["title"] = explicit_title or field.alias.title().replace("_", " ")
    return json_schema
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
    """Return an empty name map; ``fields`` is ignored on this branch."""
    return {}
def get_definitions(
    *,
    fields: List[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    Dict[str, Dict[str, Any]],
]:
    """Generate JSON Schema definitions for ``fields`` with ``schema_generator``.

    Each field contributes a ``(field, mode, core_schema)`` input, where the
    mode is the field's own mode, or ``"validation"`` when
    ``separate_input_output_schemas`` is false. Returns the pair produced by
    ``schema_generator.generate_definitions``. ``model_name_map`` is not used.
    """
    inputs = []
    for field in fields:
        mode = field.mode if separate_input_output_schemas else "validation"
        inputs.append((field, mode, field._type_adapter.core_schema))
    field_mapping, definitions = schema_generator.generate_definitions(
        inputs=inputs
    )
    return field_mapping, definitions  # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
    """Return True when the field's annotation is scalar and it is not a ``Body`` param."""
    # Imported here rather than at module level, as in the original code.
    from fastapi import params

    if not field_annotation_is_scalar(field.field_info.annotation):
        return False
    return not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
    """Return whether the field's annotation is a sequence annotation."""
    annotation = field.field_info.annotation
    return field_annotation_is_sequence(annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Return whether the field's annotation is a sequence of scalar items."""
    annotation = field.field_info.annotation
    return field_annotation_is_scalar_sequence(annotation)
def is_bytes_field(field: ModelField) -> bool:
    """Return whether the field's type is ``bytes`` or a union containing ``bytes``."""
    field_type = field.type_
    return is_bytes_or_nonable_bytes_annotation(field_type)
def is_bytes_sequence_field(field: ModelField) -> bool:
    """Return whether the field's type is a sequence annotation of ``bytes`` items."""
    field_type = field.type_
    return is_bytes_sequence_annotation(field_type)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of ``field_info`` re-targeted at ``annotation``.

    The copy's ``metadata`` and ``annotation`` are taken from a fresh instance
    that the same ``FieldInfo`` class builds via ``from_annotation(annotation)``.
    """
    derived = type(field_info).from_annotation(annotation)
    duplicate = copy(field_info)
    duplicate.metadata = derived.metadata
    duplicate.annotation = derived.annotation
    return duplicate
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Convert ``value`` to the container type that matches the field's annotation.

    The annotation's origin (or the annotation itself when it has no origin)
    selects the constructor from ``sequence_annotation_to_type``.
    """
    annotation = field.field_info.annotation
    origin_type = get_origin(annotation) or annotation
    assert issubclass(origin_type, sequence_types)  # type: ignore[arg-type]
    container_type = sequence_annotation_to_type[origin_type]
    return container_type(value)  # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
    """Build a single ``"missing"`` error dict for ``loc``, with ``"input"`` set to None."""
    line_errors = [{"type": "missing", "loc": loc, "input": {}}]
    validation_error = ValidationError.from_exception_data(
        "Field required", line_errors
    )
    error = validation_error.errors(include_url=False)[0]
    # The error is created with an empty-dict input; the returned dict carries
    # None instead.
    error["input"] = None
    return error  # type: ignore[return-value]
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Create a model class named ``model_name`` with one field per entry in ``fields``."""
    field_params = {}
    for f in fields:
        # create_model takes each field as an (annotation, FieldInfo) pair.
        field_params[f.name] = (f.field_info.annotation, f.field_info)
    BodyModel: Type[BaseModel] = create_model(model_name, **field_params)  # type: ignore[call-overload]
    return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Wrap every entry of ``model.model_fields`` in a ``ModelField``."""
    wrapped: List[ModelField] = []
    for name, field_info in model.model_fields.items():
        wrapped.append(ModelField(field_info=field_info, name=name))
    return wrapped
else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
from pydantic import ( # type: ignore[assignment]
BaseConfig as BaseConfig, # noqa: F401
)
from pydantic import ValidationError as ValidationError # noqa: F401
from pydantic.class_validators import ( # type: ignore[no-redef]
Validator as Validator, # noqa: F401
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper, # noqa: F401
)
from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined]
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
ModelField as ModelField, # noqa: F401
)
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Required as Required, # noqa: F401
)
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Undefined as Undefined,
)
from pydantic.fields import ( # type: ignore[no-redef, attr-defined]
UndefinedType as UndefinedType, # noqa: F401
)
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
model_process_schema,
)
from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.typing import ( # type: ignore[no-redef]
evaluate_forwardref as evaluate_forwardref, # noqa: F401
)
from pydantic.utils import ( # type: ignore[no-redef]
lenient_issubclass as lenient_issubclass, # noqa: F401
)
# On this branch these names are bound to loose typing aliases instead of the
# Pydantic / pydantic_core objects that the v2 branch imports under the same
# names.
GetJsonSchemaHandler = Any # type: ignore[assignment,misc]
JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
CoreSchema = Any # type: ignore[assignment,misc]
# Field shapes treated as sequences by the helpers on this branch.
sequence_shapes = {
    SHAPE_LIST,
    SHAPE_SET,
    SHAPE_FROZENSET,
    SHAPE_TUPLE,
    SHAPE_SEQUENCE,
    SHAPE_TUPLE_ELLIPSIS,
}
# Concrete container used to rebuild a value for each sequence shape.
# ``serialize_sequence_value`` indexes this mapping with ``field.shape``, so
# every member of ``sequence_shapes`` needs an entry here.
sequence_shape_to_type = {
    SHAPE_LIST: list,
    SHAPE_SET: set,
    # SHAPE_FROZENSET is a member of sequence_shapes; without this entry a
    # frozenset-shaped field would fail the lookup with KeyError.
    SHAPE_FROZENSET: frozenset,
    SHAPE_TUPLE: tuple,
    SHAPE_SEQUENCE: list,
    SHAPE_TUPLE_ELLIPSIS: list,
}
@dataclass
class GenerateJsonSchema:  # type: ignore[no-redef]
    """Minimal stand-in for the v2 schema generator: it only stores ``ref_template``."""

    ref_template: str
class PydanticSchemaGenerationError(Exception):  # type: ignore[no-redef]
    """Empty exception type so this name exists on the non-v2 branch."""
def with_info_plain_validator_function(  # type: ignore[misc]
    function: Callable[..., Any],
    *,
    ref: Union[str, None] = None,
    metadata: Any = None,
    serialization: Any = None,
) -> Any:
    """Stub for the non-v2 branch: ignore every argument and return an empty dict."""
    empty_schema: Dict[str, Any] = {}
    return empty_schema
def get_model_definitions(
    *,
    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
    """Build a ``{model name: schema}`` mapping for every model in ``flat_models``.

    Definitions reported alongside each model's schema are merged in as well.
    A schema's ``"description"`` is cut at the first form feed character.
    """
    definitions: Dict[str, Dict[str, Any]] = {}
    for model in flat_models:
        # The third element of the result is not needed here.
        m_schema, m_definitions, _ = model_process_schema(
            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        definitions.update(m_definitions)
        model_name = model_name_map[model]
        if "description" in m_schema:
            # Keep only the text before the first "\f".
            visible_part, _, _ = m_schema["description"].partition("\f")
            m_schema["description"] = visible_part
        definitions[model_name] = m_schema
    return definitions
def is_pv1_scalar_field(field: ModelField) -> bool:
    """Return True when ``field`` is a singleton, non-complex, non-``Body`` field.

    A field qualifies when its shape is ``SHAPE_SINGLETON``, its type is not a
    ``BaseModel``/``dict`` subclass, a sequence annotation or a dataclass, and
    its ``field_info`` is not a ``Body``. Any sub-fields must qualify too.
    """
    # Imported here rather than at module level, as in the original code.
    from fastapi import params

    if field.shape != SHAPE_SINGLETON:  # type: ignore[attr-defined]
        return False
    field_type = field.type_
    if lenient_issubclass(field_type, BaseModel):
        return False
    if lenient_issubclass(field_type, dict):
        return False
    if field_annotation_is_sequence(field_type) or is_dataclass(field_type):
        return False
    if isinstance(field.field_info, params.Body):
        return False
    sub_fields = field.sub_fields  # type: ignore[attr-defined]
    if sub_fields:
        return all(is_pv1_scalar_field(sub_field) for sub_field in sub_fields)
    return True
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
    """Return True when ``field`` is a sequence whose items are all scalar fields.

    A field with a sequence shape and a non-``BaseModel`` type qualifies when
    it has no sub-fields or when every sub-field is scalar. Any other field
    qualifies only if its type is itself a sequence annotation.
    """
    has_sequence_shape = field.shape in sequence_shapes  # type: ignore[attr-defined]
    if has_sequence_shape and not lenient_issubclass(field.type_, BaseModel):
        sub_fields = field.sub_fields  # type: ignore[attr-defined]
        if sub_fields is None:
            return True
        return all(is_pv1_scalar_field(sub_field) for sub_field in sub_fields)
    return bool(_annotation_is_sequence(field.type_))
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
    """Flatten ``errors`` into one list.

    ``ErrorWrapper`` items are expanded via a ``ValidationError``'s
    ``errors()``, nested lists are flattened recursively, and anything else is
    kept unchanged.
    """
    use_errors: List[Any] = []
    for error in errors:
        if isinstance(error, ErrorWrapper):
            wrapped = ValidationError(  # type: ignore[call-arg]
                errors=[error], model=RequestErrorModel
            )
            use_errors.extend(wrapped.errors())
            continue
        if isinstance(error, list):
            use_errors.extend(_normalize_errors(error))
            continue
        use_errors.append(error)
    return use_errors
def _model_rebuild(model: Type[BaseModel]) -> None:
    """Call ``update_forward_refs()`` on the given model class."""
    model.update_forward_refs()
def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Dump ``model`` via ``dict(**kwargs)``.

    ``mode`` is accepted so the signature matches the v2 variant, but it is not
    forwarded.
    """
    dumped = model.dict(**kwargs)
    return dumped
def _get_model_config(model: BaseModel) -> Any:
    """Return the ``__config__`` attribute of ``model``."""
    config = model.__config__  # type: ignore[attr-defined]
    return config
def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Return the first element of ``field_schema``'s result for ``field``.

    Only ``field`` and ``model_name_map`` are used; the remaining parameters
    keep the signature identical to the v2 variant.
    """
    # This expects that GenerateJsonSchema was already used to generate the definitions
    schema_parts = field_schema(
        field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    return schema_parts[0]  # type: ignore[no-any-return]
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
    """Collect the flat models used by ``fields`` and return their name map."""
    flat_models = get_flat_models_from_fields(fields, known_models=set())
    name_map = get_model_name_map(flat_models)
    return name_map  # type: ignore[no-any-return]
def get_definitions(
    *,
    fields: List[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    Dict[str, Dict[str, Any]],
]:
    """Return an empty field mapping plus the model definitions for ``fields``.

    ``schema_generator`` and ``separate_input_output_schemas`` are not used on
    this branch.
    """
    flat_models = get_flat_models_from_fields(fields, known_models=set())
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map
    )
    return {}, definitions
def is_scalar_field(field: ModelField) -> bool:
    """Return the result of ``is_pv1_scalar_field`` for ``field``."""
    result = is_pv1_scalar_field(field)
    return result
def is_sequence_field(field: ModelField) -> bool:
    """Return whether ``field`` has a sequence shape or a sequence type."""
    if field.shape in sequence_shapes:  # type: ignore[attr-defined]
        return True
    return _annotation_is_sequence(field.type_)
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Return the result of ``is_pv1_scalar_sequence_field`` for ``field``."""
    result = is_pv1_scalar_sequence_field(field)
    return result
def is_bytes_field(field: ModelField) -> bool:
    """Return whether the field's type is ``bytes`` or a subclass of it."""
    field_type = field.type_
    return lenient_issubclass(field_type, bytes)
def is_bytes_sequence_field(field: ModelField) -> bool:
    """Return whether ``field`` has a sequence shape and a ``bytes`` type."""
    if field.shape not in sequence_shapes:  # type: ignore[attr-defined]
        return False
    return lenient_issubclass(field.type_, bytes)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of ``field_info``; ``annotation`` is not used on this branch."""
    duplicate = copy(field_info)
    return duplicate
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Convert ``value`` to the container type registered for the field's shape."""
    container_type = sequence_shape_to_type[field.shape]  # type: ignore[attr-defined]
    return container_type(value)  # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
    """Build the error dict for a ``MissingError`` located at ``loc``."""
    wrapped_error = ErrorWrapper(MissingError(), loc=loc)  # type: ignore[call-arg]
    validation_error = ValidationError([wrapped_error], RequestErrorModel)
    first_error = validation_error.errors()[0]
    return first_error  # type: ignore[return-value]
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Create an empty model named ``model_name`` and register ``fields`` on it.

    Each field is stored in the new model's ``__fields__`` under its own name.
    """
    body_model = create_model(model_name)
    registry = body_model.__fields__
    for model_field in fields:
        registry[model_field.name] = model_field  # type: ignore[index]
    return body_model
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Return the values of ``model.__fields__`` as a new list."""
    declared_fields = model.__fields__  # type: ignore[attr-defined]
    return [*declared_fields.values()]
def _regenerate_error_with_loc(
    *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
    """Return normalized copies of ``errors`` with ``loc_prefix`` prepended to each ``loc``.

    An error without a ``"loc"`` key is treated as having an empty location.
    """
    updated_loc_errors: List[Any] = []
    for err in _normalize_errors(errors):
        relocated = {**err}
        relocated["loc"] = loc_prefix + err.get("loc", ())
        updated_loc_errors.append(relocated)
    return updated_loc_errors
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    """Return whether ``annotation`` is a sequence type other than ``str``/``bytes``."""
    # str and bytes are excluded before the sequence check.
    if not lenient_issubclass(annotation, (str, bytes)):
        return lenient_issubclass(annotation, sequence_types)
    return False
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    """Return whether ``annotation`` (or any member of a union) is a sequence.

    For non-union annotations both the annotation itself and its origin are
    checked, so parametrized generics are recognized.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(
            field_annotation_is_sequence(member) for member in get_args(annotation)
        )
    return _annotation_is_sequence(annotation) or _annotation_is_sequence(origin)
def value_is_sequence(value: Any) -> bool:
    """Return whether ``value`` is an instance of a sequence type other than ``str``/``bytes``."""
    if not isinstance(value, sequence_types):  # type: ignore[arg-type]
        return False
    return not isinstance(value, (str, bytes))
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    """Return whether ``annotation`` is a model, mapping, upload file, sequence or dataclass."""
    if lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile)):
        return True
    return _annotation_is_sequence(annotation) or is_dataclass(annotation)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    """Return whether ``annotation`` (or any member of a union) is complex.

    Besides the checks in ``_annotation_is_complex`` on the annotation and its
    origin, an origin that defines a Pydantic core-schema hook also counts as
    complex.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        for member in get_args(annotation):
            if field_annotation_is_complex(member):
                return True
        return False
    if _annotation_is_complex(annotation) or _annotation_is_complex(origin):
        return True
    return hasattr(origin, "__pydantic_core_schema__") or hasattr(
        origin, "__get_pydantic_core_schema__"
    )
def field_annotation_is_scalar(annotation: Any) -> bool:
    """Return whether ``annotation`` is ``Ellipsis`` or not a complex annotation."""
    # handle Ellipsis here to make tuple[int, ...] work nicely
    if annotation is Ellipsis:
        return True
    return not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
    """Return whether ``annotation`` is a sequence whose item annotations are all scalar.

    A union qualifies when at least one member is a scalar sequence and every
    other member is scalar.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        found_scalar_sequence = False
        for member in get_args(annotation):
            if field_annotation_is_scalar_sequence(member):
                found_scalar_sequence = True
            elif not field_annotation_is_scalar(member):
                # A complex, non-sequence member disqualifies the whole union.
                return False
        return found_scalar_sequence
    if not field_annotation_is_sequence(annotation):
        return False
    return all(
        field_annotation_is_scalar(item_annotation)
        for item_annotation in get_args(annotation)
    )
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
    """Return whether ``annotation`` is ``bytes`` or a union with a ``bytes`` member."""
    if lenient_issubclass(annotation, bytes):
        return True
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(
            lenient_issubclass(member, bytes) for member in get_args(annotation)
        )
    return False
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
    """Return whether ``annotation`` is ``UploadFile`` or a union with an ``UploadFile`` member."""
    if lenient_issubclass(annotation, UploadFile):
        return True
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(
            lenient_issubclass(member, UploadFile)
            for member in get_args(annotation)
        )
    return False
def is_bytes_sequence_annotation(annotation: Any) -> bool:
    """Return whether ``annotation`` is a sequence of ``bytes`` items.

    A union qualifies when at least one of its members does.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        # Every member is evaluated (no short-circuit), as in the original loop.
        member_results = [
            is_bytes_sequence_annotation(member) for member in get_args(annotation)
        ]
        return any(member_results)
    if not field_annotation_is_sequence(annotation):
        return False
    return all(
        is_bytes_or_nonable_bytes_annotation(item_annotation)
        for item_annotation in get_args(annotation)
    )
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
    """Return whether ``annotation`` is a sequence of ``UploadFile`` items.

    A union qualifies when at least one of its members does.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        # Every member is evaluated (no short-circuit), as in the original loop.
        member_results = [
            is_uploadfile_sequence_annotation(member)
            for member in get_args(annotation)
        ]
        return any(member_results)
    if not field_annotation_is_sequence(annotation):
        return False
    return all(
        is_uploadfile_or_nonable_uploadfile_annotation(item_annotation)
        for item_annotation in get_args(annotation)
    )
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Return ``get_model_fields(model)``, memoized per model class by ``lru_cache``."""
    model_fields = get_model_fields(model)
    return model_fields