Spaces:
Build error
Build error
import asyncio | |
import enum | |
import io | |
import json | |
import mimetypes | |
import os | |
import warnings | |
from abc import ABC, abstractmethod | |
from itertools import chain | |
from typing import ( | |
IO, | |
TYPE_CHECKING, | |
Any, | |
ByteString, | |
Dict, | |
Final, | |
Iterable, | |
Optional, | |
TextIO, | |
Tuple, | |
Type, | |
Union, | |
) | |
from multidict import CIMultiDict | |
from . import hdrs | |
from .abc import AbstractStreamWriter | |
from .helpers import ( | |
_SENTINEL, | |
content_disposition_header, | |
guess_filename, | |
parse_mimetype, | |
sentinel, | |
) | |
from .streams import StreamReader | |
from .typedefs import JSONEncoder, _CIMultiDict | |
# Names exported by ``from <this module> import *``.
# NOTE(review): Order, LookupError, register_payload and StreamReaderPayload
# are defined in this module but not listed here — confirm that is intended.
__all__ = (
    "PAYLOAD_REGISTRY",
    "get_payload",
    "payload_type",
    "Payload",
    "BytesPayload",
    "StringPayload",
    "IOBasePayload",
    "BytesIOPayload",
    "BufferedReaderPayload",
    "TextIOPayload",
    "StringIOPayload",
    "JsonPayload",
    "AsyncIterablePayload",
)
# Raw bytes bodies larger than this many bytes make BytesPayload emit a
# ResourceWarning suggesting an io.BytesIO object instead.
TOO_LARGE_BYTES_BODY: Final[int] = 1024 * 1024  # 1 MiB
if TYPE_CHECKING: | |
from typing import List | |
class LookupError(Exception):
    """Raised by the payload registry when no registered factory matches a value.

    This class shadows the builtin ``LookupError`` name inside this module and
    is unrelated to it: it derives from ``Exception`` directly.
    """
class Order(str, enum.Enum):
    """Lookup bucket a payload factory is registered into.

    The registry consults ``try_first`` entries, then ``normal`` ones, then
    ``try_last`` ones; members compare equal to their string values.
    """

    normal = "normal"
    try_first = "try_first"
    try_last = "try_last"
def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Wrap *data* in a Payload using the module-wide ``PAYLOAD_REGISTRY``.

    All extra positional and keyword arguments are forwarded unchanged to
    ``PAYLOAD_REGISTRY.get``.
    """
    registry = PAYLOAD_REGISTRY
    return registry.get(data, *args, **kwargs)
def register_payload(
    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register *factory* in ``PAYLOAD_REGISTRY`` for values matching *type*.

    *order* selects the lookup bucket the entry is appended to.
    """
    registry = PAYLOAD_REGISTRY
    registry.register(factory, type, order=order)
class payload_type:
    """Class decorator registering the decorated Payload factory.

    Usage::

        @payload_type(SomeType, order=Order.try_first)
        class SomePayload(Payload): ...

    The decorated class is passed to ``register_payload`` together with the
    stored *type* and *order*, then returned unchanged.
    """

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        self.type, self.order = type, order

    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
        register_payload(factory, self.type, order=self.order)
        return factory
# A Payload class used as a factory: called as factory(data, *args, **kwargs).
PayloadType = Type["Payload"]
# One registry entry: (factory, type or tuple of types used with isinstance()).
_PayloadRegistryItem = Tuple[PayloadType, Any]
class PayloadRegistry:
    """Payload registry.

    Maps value types to the Payload factory that wraps them. Entries live in
    three buckets (try_first, normal, try_last); lookup walks the buckets in
    that order, each in registration order, and the first isinstance() match
    wins.

    note: we need zope.interface for more efficient adapter search
    """

    def __init__(self) -> None:
        self._first: List[_PayloadRegistryItem] = []
        self._normal: List[_PayloadRegistryItem] = []
        self._last: List[_PayloadRegistryItem] = []

    def get(
        self,
        data: Any,
        *args: Any,
        _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
        **kwargs: Any,
    ) -> "Payload":
        """Return a Payload for *data*.

        *data* is returned as-is when it already is a Payload. Otherwise the
        first matching factory is called as ``factory(data, *args, **kwargs)``.
        Raises this module's LookupError when nothing matches.
        """
        if isinstance(data, Payload):
            return data
        candidates = _CHAIN(self._first, self._normal, self._last)
        for factory, registered_type in candidates:
            if isinstance(data, registered_type):
                return factory(data, *args, **kwargs)
        raise LookupError()

    def register(
        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
    ) -> None:
        """Append ``(factory, type)`` to the bucket selected by *order*.

        *order* is matched by identity against the Order members; anything
        else raises ValueError.
        """
        buckets = (
            (Order.try_first, self._first),
            (Order.normal, self._normal),
            (Order.try_last, self._last),
        )
        for member, bucket in buckets:
            if order is member:
                bucket.append((factory, type))
                return
        raise ValueError(f"Unsupported order {order!r}")
class Payload(ABC):
    """Abstract base for body payloads.

    A payload holds a value plus its item headers (always including
    ``Content-Type``), an optional filename and an optional encoding.
    Subclasses implement :meth:`write` to send the value to a stream writer.
    """

    _default_content_type: str = "application/octet-stream"
    _size: Optional[int] = None

    def __init__(
        self,
        value: Any,
        headers: Optional[
            Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
        ] = None,
        content_type: Union[str, None, _SENTINEL] = sentinel,
        filename: Optional[str] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store *value* and build the item headers.

        Content-Type resolution: an explicit, non-None *content_type* wins;
        otherwise it is guessed from *filename* via ``mimetypes``; otherwise
        (or if the guess fails) ``_default_content_type`` is used. Entries in
        *headers* are applied last and so override the computed Content-Type.
        Extra keyword arguments are accepted and ignored.
        """
        self._encoding = encoding
        self._filename = filename
        self._headers: _CIMultiDict = CIMultiDict()
        self._value = value
        if content_type is not sentinel and content_type is not None:
            self._headers[hdrs.CONTENT_TYPE] = content_type
        elif self._filename is not None:
            content_type = mimetypes.guess_type(self._filename)[0]
            if content_type is None:
                content_type = self._default_content_type
            self._headers[hdrs.CONTENT_TYPE] = content_type
        else:
            self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
        self._headers.update(headers or {})

    # The accessors below are properties: the class itself reads them as
    # attributes (``self.headers.items()``), and IOBasePayload tests
    # ``... in self.headers``; as plain methods both fail.
    @property
    def size(self) -> Optional[int]:
        """Size of the payload, or None when unknown."""
        return self._size

    @property
    def filename(self) -> Optional[str]:
        """Filename of the payload."""
        return self._filename

    @property
    def headers(self) -> _CIMultiDict:
        """Custom item headers"""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        """Headers serialized as UTF-8 ``Name: value\\r\\n`` lines plus a blank line."""
        return (
            "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode(
                "utf-8"
            )
            + b"\r\n"
        )

    @property
    def encoding(self) -> Optional[str]:
        """Payload encoding"""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Content type"""
        return self._headers[hdrs.CONTENT_TYPE]

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: Any,
    ) -> None:
        """Sets ``Content-Disposition`` header."""
        self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, **params
        )

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write payload.

        writer is an AbstractStreamWriter instance:
        """
class BytesPayload(Payload):
    """Payload wrapping an in-memory ``bytes``/``bytearray``/``memoryview``.

    The content type defaults to ``application/octet-stream``. A
    ResourceWarning is emitted when the body exceeds TOO_LARGE_BYTES_BODY.
    """

    def __init__(
        self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
    ) -> None:
        # The annotation lists the exact runtime-accepted types: the old
        # typing.ByteString is deprecated and does not cover memoryview.
        if not isinstance(value, (bytes, bytearray, memoryview)):
            raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

        if "content_type" not in kwargs:
            kwargs["content_type"] = "application/octet-stream"

        super().__init__(value, *args, **kwargs)

        # len() of a memoryview counts items, not bytes; nbytes is the byte size.
        if isinstance(value, memoryview):
            self._size = value.nbytes
        else:
            self._size = len(value)

        if self._size > TOO_LARGE_BYTES_BODY:
            # Pass source= directly instead of rebinding the kwargs parameter.
            warnings.warn(
                "Sending a large body directly with raw bytes might"
                " lock the event loop. You should probably pass an "
                "io.BytesIO object instead",
                ResourceWarning,
                source=self,
            )

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send the whole buffer with a single writer.write() call."""
        await writer.write(self._value)
class StringPayload(BytesPayload):
    """Payload holding text, encoded to bytes at construction time.

    Encoding/content-type resolution:

    * *encoding* given: it is used; a missing *content_type* becomes
      ``text/plain; charset=<encoding>``.
    * only *content_type* given: its ``charset`` parameter is used
      (``utf-8`` when absent).
    * neither given: ``utf-8`` and ``text/plain; charset=utf-8``.
    """

    def __init__(
        self,
        value: str,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            real_encoding = parse_mimetype(content_type).parameters.get(
                "charset", "utf-8"
            )
        else:
            real_encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        super().__init__(
            value.encode(real_encoding),
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )
class StringIOPayload(StringPayload):
    """Payload built from the remaining contents of a text stream.

    The stream is read once, eagerly, at construction time, and the text is
    then treated exactly like a StringPayload value.
    """

    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
        text = value.read()
        super().__init__(text, *args, **kwargs)
class IOBasePayload(Payload):
    """Payload backed by a synchronous file-like object.

    Reads (64 KiB at a time) and the final close() run in the event loop's
    default executor, and the object is closed even if writing fails.
    """

    _value: IO[Any]

    def __init__(
        self,
        value: IO[Any],
        disposition: Optional[str] = "attachment",
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Wrap *value*.

        The filename is guessed from *value* unless given explicitly. When a
        filename is known and *disposition* is not None, a Content-Disposition
        header is added unless the caller already supplied one.
        """
        if "filename" not in kwargs:
            kwargs["filename"] = guess_filename(value)

        super().__init__(value, *args, **kwargs)

        if self._filename is not None and disposition is not None:
            # Use the underlying dict so this works whether or not ``headers``
            # is exposed as a property.
            if hdrs.CONTENT_DISPOSITION not in self._headers:
                self.set_content_disposition(disposition, filename=self._filename)

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Stream the object's contents to *writer*, then close the object."""
        # Inside a coroutine the running loop is what we want;
        # get_event_loop() is the legacy, deprecation-prone spelling.
        loop = asyncio.get_running_loop()
        try:
            while chunk := await loop.run_in_executor(None, self._value.read, 2**16):
                await writer.write(chunk)
        finally:
            await loop.run_in_executor(None, self._value.close)
class TextIOPayload(IOBasePayload):
    """Payload backed by a text stream; text is encoded while writing.

    Encoding resolution: explicit *encoding*; else the ``charset`` of
    *content_type* (``utf-8`` when absent); else ``utf-8`` with content type
    ``text/plain; charset=utf-8``.
    """

    _value: TextIO

    def __init__(
        self,
        value: TextIO,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        if encoding is None:
            if content_type is None:
                encoding = "utf-8"
                content_type = "text/plain; charset=utf-8"
            else:
                mimetype = parse_mimetype(content_type)
                encoding = mimetype.parameters.get("charset", "utf-8")
        else:
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding

        super().__init__(
            value,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or None without a usable fileno().

        NOTE(review): st_size counts bytes on disk, while tell() on a text
        stream is an opaque cookie and write() re-encodes with self._encoding;
        this may not equal the number of bytes actually sent — confirm.
        """
        try:
            return os.fstat(self._value.fileno()).st_size - self._value.tell()
        except OSError:
            return None

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Encode the stream's text chunk by chunk, then close the stream."""
        loop = asyncio.get_running_loop()
        # str.encode() with no argument uses UTF-8, so this equals "encode with
        # self._encoding when set, else UTF-8"; decided once, not per chunk.
        codec = self._encoding or "utf-8"
        try:
            while chunk := await loop.run_in_executor(None, self._value.read, 2**16):
                await writer.write(chunk.encode(codec))
        finally:
            await loop.run_in_executor(None, self._value.close)
class BytesIOPayload(IOBasePayload):
    """Payload wrapping an in-memory binary buffer such as ``io.BytesIO``."""

    @property
    def size(self) -> int:
        """Bytes from the current position to the end; the position is restored.

        A property, like the base ``size`` is meant to be read (attribute
        access, not a call).
        """
        position = self._value.tell()
        end = self._value.seek(0, os.SEEK_END)
        self._value.seek(position)
        return end - position
class BufferedReaderPayload(IOBasePayload):
    """Payload for buffered binary files (registered for io.BufferedReader/BufferedRandom)."""

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or None if there is no file descriptor."""
        try:
            return os.fstat(self._value.fileno()).st_size - self._value.tell()
        except OSError:
            # data.fileno() is not supported, e.g.
            # io.BufferedReader(io.BytesIO(b'data'))
            return None
class JsonPayload(BytesPayload):
    """Payload holding *value* serialized as JSON.

    *value* is serialized with *dumps*, encoded with *encoding*, and sent with
    *content_type* (``application/json`` by default).
    """

    def __init__(
        self,
        value: Any,
        encoding: str = "utf-8",
        content_type: str = "application/json",
        dumps: JSONEncoder = json.dumps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        serialized = dumps(value)
        body = serialized.encode(encoding)
        super().__init__(
            body,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )
# Async-iterable aliases used by the payloads below. Type checkers see the
# typing generics parametrized with bytes chunks; at runtime the plain
# collections.abc classes are bound instead. AsyncIterable is used in
# isinstance() checks (AsyncIterablePayload, and the registry entry for it).
if TYPE_CHECKING:
    from typing import AsyncIterable, AsyncIterator
    _AsyncIterator = AsyncIterator[bytes]
    _AsyncIterable = AsyncIterable[bytes]
else:
    from collections.abc import AsyncIterable, AsyncIterator
    _AsyncIterator = AsyncIterator
    _AsyncIterable = AsyncIterable
class AsyncIterablePayload(Payload):
    """Payload that streams chunks produced by an async iterable.

    The async iterator is obtained once, in ``__init__``. Once write() has
    exhausted it, the iterator is dropped and later write() calls send nothing.
    """

    _iter: Optional[_AsyncIterator] = None

    def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
        if not isinstance(value, AsyncIterable):
            raise TypeError(
                "value argument must support "
                "collections.abc.AsyncIterable interface, "
                "got {!r}".format(type(value))
            )

        if "content_type" not in kwargs:
            kwargs["content_type"] = "application/octet-stream"

        super().__init__(value, *args, **kwargs)

        self._iter = value.__aiter__()

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send every chunk from the iterator to *writer*."""
        # Explicit None check: it prevents re-sending when the payload is
        # written twice, without misreading an iterator whose __bool__/__len__
        # happens to be falsy as "already consumed".
        if self._iter is None:
            return
        try:
            while True:
                chunk = await self._iter.__anext__()
                await writer.write(chunk)
        except StopAsyncIteration:
            self._iter = None
class StreamReaderPayload(AsyncIterablePayload):
    """Payload streaming the chunks produced by ``value.iter_any()``."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        chunks = value.iter_any()
        super().__init__(chunks, *args, **kwargs)
# Module-wide registry used by get_payload(). Within a bucket, entries are
# tried in registration order and the first isinstance() match wins, so the
# specific types below must stay ahead of the broader bases they derive from
# (io.StringIO before io.TextIOBase, io.BytesIO before io.IOBase).
PAYLOAD_REGISTRY = PayloadRegistry()

for _factory, _types in (
    (BytesPayload, (bytes, bytearray, memoryview)),
    (StringPayload, str),
    (StringIOPayload, io.StringIO),
    (TextIOPayload, io.TextIOBase),
    (BytesIOPayload, io.BytesIO),
    (BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)),
    (IOBasePayload, io.IOBase),
    (StreamReaderPayload, StreamReader),
):
    PAYLOAD_REGISTRY.register(_factory, _types)
del _factory, _types

# try_last gives more specialized async iterables a chance to override this
# generic fallback (the original comment names
# "multidict.BodyPartReaderPayload"; NOTE(review): presumably the multipart
# module's BodyPartReaderPayload — confirm).
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)