|
import copy |
|
from argparse import ArgumentError, ArgumentParser |
|
from collections.abc import Awaitable |
|
from contextvars import ContextVar |
|
from dataclasses import dataclass, field |
|
from io import BytesIO |
|
from pathlib import Path |
|
from typing import ( |
|
IO, |
|
Any, |
|
Callable, |
|
Literal, |
|
Optional, |
|
TypeVar, |
|
Union, |
|
cast, |
|
) |
|
|
|
from pil_utils import BuildImage |
|
from pydantic import BaseModel, ValidationError |
|
|
|
from .exception import ( |
|
ArgModelMismatch, |
|
ArgParserExit, |
|
ImageNumberMismatch, |
|
OpenImageFailed, |
|
ParserExit, |
|
TextNumberMismatch, |
|
TextOrNameNotEnough, |
|
) |
|
from .utils import is_coroutine_callable, random_image, random_text, run_sync |
|
|
|
|
|
class UserInfo(BaseModel): |
|
name: str = "" |
|
gender: Literal["male", "female", "unknown"] = "unknown" |
|
|
|
|
|
class MemeArgsModel(BaseModel): |
|
user_infos: list[UserInfo] = [] |
|
|
|
|
|
ArgsModel = TypeVar("ArgsModel", bound=MemeArgsModel) |
|
|
|
MemeFunction = Union[ |
|
Callable[[list[BuildImage], list[str], ArgsModel], BytesIO], |
|
Callable[[list[BuildImage], list[str], ArgsModel], Awaitable[BytesIO]], |
|
] |
|
|
|
|
|
parser_message: ContextVar[str] = ContextVar("parser_message") |
|
|
|
|
|
class MemeArgsParser(ArgumentParser): |
|
"""`shell_like` 命令参数解析器,解析出错时不会退出程序。 |
|
|
|
用法: |
|
用法与 `argparse.ArgumentParser` 相同, |
|
参考文档: [argparse](https://docs.python.org/3/library/argparse.html) |
|
""" |
|
|
|
def _print_message(self, message: str, file: Optional[IO[str]] = None): |
|
if (msg := parser_message.get(None)) is not None: |
|
parser_message.set(msg + message) |
|
else: |
|
super()._print_message(message, file) |
|
|
|
def exit(self, status: int = 0, message: Optional[str] = None): |
|
if message: |
|
self._print_message(message) |
|
raise ParserExit(status=status, error_message=parser_message.get(None)) |
|
|
|
|
|
@dataclass |
|
class MemeArgsType: |
|
parser: MemeArgsParser |
|
model: type[MemeArgsModel] |
|
instances: list[MemeArgsModel] = field(default_factory=list) |
|
|
|
|
|
@dataclass |
|
class MemeParamsType: |
|
min_images: int = 0 |
|
max_images: int = 0 |
|
min_texts: int = 0 |
|
max_texts: int = 0 |
|
default_texts: list[str] = field(default_factory=list) |
|
args_type: Optional[MemeArgsType] = None |
|
|
|
|
|
@dataclass |
|
class Meme: |
|
key: str |
|
function: MemeFunction |
|
params_type: MemeParamsType |
|
keywords: list[str] = field(default_factory=list) |
|
patterns: list[str] = field(default_factory=list) |
|
|
|
async def __call__( |
|
self, |
|
*, |
|
images: Union[list[str], list[Path], list[bytes], list[BytesIO]] = [], |
|
texts: list[str] = [], |
|
args: dict[str, Any] = {}, |
|
) -> BytesIO: |
|
if not ( |
|
self.params_type.min_images <= len(images) <= self.params_type.max_images |
|
): |
|
raise ImageNumberMismatch( |
|
self.key, self.params_type.min_images, self.params_type.max_images |
|
) |
|
|
|
if not (self.params_type.min_texts <= len(texts) <= self.params_type.max_texts): |
|
raise TextNumberMismatch( |
|
self.key, self.params_type.min_texts, self.params_type.max_texts |
|
) |
|
|
|
if args_type := self.params_type.args_type: |
|
args_model = args_type.model |
|
else: |
|
args_model = MemeArgsModel |
|
|
|
try: |
|
model = args_model.parse_obj(args) |
|
except ValidationError as e: |
|
raise ArgModelMismatch(self.key, str(e)) |
|
|
|
imgs: list[BuildImage] = [] |
|
try: |
|
for image in images: |
|
if isinstance(image, bytes): |
|
image = BytesIO(image) |
|
imgs.append(BuildImage.open(image)) |
|
except Exception as e: |
|
raise OpenImageFailed(str(e)) |
|
|
|
values = {"images": imgs, "texts": texts, "args": model} |
|
|
|
if is_coroutine_callable(self.function): |
|
return await cast(Callable[..., Awaitable[BytesIO]], self.function)( |
|
**values |
|
) |
|
else: |
|
return await run_sync(cast(Callable[..., BytesIO], self.function))(**values) |
|
|
|
def parse_args(self, args: list[str] = []) -> dict[str, Any]: |
|
parser = ( |
|
copy.deepcopy(self.params_type.args_type.parser) |
|
if self.params_type.args_type |
|
else MemeArgsParser() |
|
) |
|
parser.add_argument("texts", nargs="*", default=[]) |
|
t = parser_message.set("") |
|
try: |
|
return vars(parser.parse_args(args)) |
|
except ArgumentError as e: |
|
raise ArgParserExit(self.key, str(e)) |
|
except ParserExit as e: |
|
raise ArgParserExit(self.key, e.error_message) |
|
finally: |
|
parser_message.reset(t) |
|
|
|
async def generate_preview(self, *, args: dict[str, Any] = {}) -> BytesIO: |
|
default_images = [random_image() for _ in range(self.params_type.min_images)] |
|
default_texts = ( |
|
self.params_type.default_texts.copy() |
|
if ( |
|
self.params_type.min_texts |
|
<= len(self.params_type.default_texts) |
|
<= self.params_type.max_texts |
|
) |
|
else [random_text() for _ in range(self.params_type.min_texts)] |
|
) |
|
|
|
async def _generate_preview(images: list[BytesIO], texts: list[str]): |
|
try: |
|
return await self.__call__(images=images, texts=texts, args=args) |
|
except TextOrNameNotEnough: |
|
texts.append(random_text()) |
|
return await _generate_preview(images, texts) |
|
|
|
return await _generate_preview(default_images, default_texts) |
|
|