|
from typing import Any, Literal, Optional |
|
|
|
import filetype |
|
from fastapi import Depends, FastAPI, Form, HTTPException, Response, UploadFile |
|
from pil_utils.types import ColorType, FontStyle, FontWeight |
|
from pydantic import BaseModel, ValidationError |
|
|
|
from meme_generator.config import meme_config |
|
from meme_generator.exception import MemeGeneratorException, NoSuchMeme |
|
from meme_generator.log import LOGGING_CONFIG, setup_logger |
|
from meme_generator.manager import get_meme, get_meme_keys, get_memes |
|
from meme_generator.meme import Meme, MemeArgsModel |
|
from meme_generator.utils import TextProperties, render_meme_list |
|
|
|
app = FastAPI() |
|
|
|
|
|
# One argument accepted by a meme, as reported by ``GET /memes/{key}/info``.
# The route fills these fields from the JSON-schema "properties" entry of the
# meme's argument model.
class MemeArgsResponse(BaseModel):
    name: str  # property name in the argument model's schema
    type: str  # schema "type" of the property ("" when the schema has none)
    description: Optional[str] = None  # schema "description", if present
    default: Optional[Any] = None  # schema "default", if present
    enum: Optional[list[Any]] = None  # schema "enum" (allowed values), if present
|
|
|
|
|
# Parameter constraints of a meme, copied from ``meme.params_type`` by the
# ``GET /memes/{key}/info`` route.
class MemeParamsResponse(BaseModel):
    min_images: int  # meme.params_type.min_images
    max_images: int  # meme.params_type.max_images
    min_texts: int  # meme.params_type.min_texts
    max_texts: int  # meme.params_type.max_texts
    # Texts the generation route falls back to when the caller sends none.
    default_texts: list[str]
    # One entry per argument-model schema property (user_infos is omitted).
    args: list[MemeArgsResponse]
|
|
|
|
|
# Response body of ``GET /memes/{key}/info``: a meme's identity plus its
# parameter constraints.
class MemeInfoResponse(BaseModel):
    key: str  # unique meme key, also used in the meme's route paths
    keywords: list[str]  # meme.keywords
    patterns: list[str]  # meme.patterns (presumably regex strings — confirm)
    params: MemeParamsResponse  # constraints on images, texts and arguments
|
|
|
|
|
def register_router(meme: Meme):
    """Register ``POST /memes/{meme.key}/``, which generates ``meme``.

    The endpoint accepts uploaded ``images``, a list of ``texts`` and an
    ``args`` form field containing the JSON encoding of the meme's argument
    model. It responds with the generated bytes, using a media type sniffed
    from the content (``text/plain`` when the type cannot be recognised).

    Args:
        meme: The meme to expose over HTTP.
    """
    # Memes without a dedicated argument type use the base argument model.
    if args_type := meme.params_type.args_type:
        args_model = args_type.model
    else:
        args_model = MemeArgsModel

    def args_checker(args: Optional[str] = Form(default=str(args_model().json()))):
        """Parse the ``args`` form field into an ``args_model`` instance.

        An empty field yields the model's defaults.

        Raises:
            HTTPException: status 552 when the JSON fails validation.
        """
        if not args:
            # Build the meme's own model, not the base class, so the result
            # always satisfies the isinstance check in the endpoint.
            return args_model()
        try:
            model = args_model.parse_raw(args)
        except ValidationError as e:
            raise HTTPException(status_code=552, detail=str(e)) from e
        return model

    @app.post(f"/memes/{meme.key}/")
    async def _(
        images: list[UploadFile] = [],
        texts: list[str] = meme.params_type.default_texts,
        args: args_model = Depends(args_checker),  # type: ignore
    ):
        imgs: list[bytes] = [await image.read() for image in images]

        # Drop empty strings so blank form fields are not treated as texts.
        texts = [text for text in texts if text]

        # Type narrowing only: args_checker always returns an args_model.
        assert isinstance(args, args_model)

        try:
            result = await meme(images=imgs, texts=texts, args=args.dict())
        except MemeGeneratorException as e:
            raise HTTPException(status_code=e.status_code, detail=str(e)) from e

        content = result.getvalue()
        # guess_mime returns None for unrecognised data; wrapping it in str()
        # would give the truthy string "None" and bypass the fallback.
        media_type = filetype.guess_mime(content) or "text/plain"
        return Response(content=content, media_type=media_type)
|
|
|
|
|
# One entry of a ``POST /memes/render_list`` request: a meme key plus the text
# styling for that meme's entry. The styling fields are passed one-to-one
# into ``TextProperties`` by the render_list route.
class MemeKeyWithProperties(BaseModel):
    meme_key: str  # key looked up with get_meme()
    fill: ColorType = "black"  # TextProperties.fill
    style: FontStyle = "normal"  # TextProperties.style
    weight: FontWeight = "normal"  # TextProperties.weight
    stroke_width: int = 0  # TextProperties.stroke_width
    stroke_fill: Optional[ColorType] = None  # TextProperties.stroke_fill
|
|
|
|
|
# Every loaded meme, ordered by key, each with default text styling. Used as
# the default ``meme_list`` of a render_list request.
default_meme_list = list(
    map(
        lambda item: MemeKeyWithProperties(meme_key=item.key),
        sorted(get_memes(), key=lambda item: item.key),
    )
)
|
|
|
|
|
# Request body of ``POST /memes/render_list``. Apart from ``meme_list`` (which
# is resolved to memes + TextProperties first), every field is forwarded
# unchanged as the same-named keyword argument of ``render_meme_list``.
class RenderMemeListRequest(BaseModel):
    # Defaults to all loaded memes, sorted by key, with default styling.
    meme_list: list[MemeKeyWithProperties] = default_meme_list
    order_direction: Literal["row", "column"] = "column"
    columns: int = 4  # presumably the number of layout columns — confirm
    column_align: Literal["left", "center", "right"] = "left"
    item_padding: tuple[int, int] = (15, 2)
    image_padding: tuple[int, int] = (50, 50)
    bg_color: ColorType = "white"
    fontsize: int = 30
    fontname: str = ""
    fallback_fonts: list[str] = []
|
|
|
|
|
def register_routers():
    """Register the general routes and one generation route per meme on ``app``.

    Routes added:

    * ``POST /memes/render_list``: render an image listing memes.
    * ``GET /memes/keys``: return every meme key.
    * ``GET /memes/{key}/info``: describe a meme and its parameters.
    * ``GET /memes/{key}/preview``: return a meme's generated preview.
    * ``POST /memes/{key}/parse_args``: run the meme's argument parser.
    * ``POST /memes/{meme.key}/`` for each meme, via ``register_router``.

    Meme-generator exceptions are converted to ``HTTPException`` using the
    exception's own ``status_code``.
    """

    def image_response(content: bytes) -> Response:
        """Wrap generated bytes in a Response with a sniffed media type."""
        # guess_mime returns None for unrecognised data; wrapping it in str()
        # would give the truthy string "None" and bypass the fallback.
        return Response(
            content=content,
            media_type=filetype.guess_mime(content) or "text/plain",
        )

    @app.post("/memes/render_list")
    def _(params: RenderMemeListRequest = RenderMemeListRequest()):
        try:
            meme_list = [
                (
                    get_meme(p.meme_key),
                    TextProperties(
                        fill=p.fill,
                        style=p.style,
                        weight=p.weight,
                        stroke_width=p.stroke_width,
                        stroke_fill=p.stroke_fill,
                    ),
                )
                for p in params.meme_list
            ]
        except NoSuchMeme as e:
            raise HTTPException(status_code=e.status_code, detail=str(e)) from e

        result = render_meme_list(
            meme_list,
            order_direction=params.order_direction,
            columns=params.columns,
            column_align=params.column_align,
            item_padding=params.item_padding,
            image_padding=params.image_padding,
            bg_color=params.bg_color,
            fontsize=params.fontsize,
            fontname=params.fontname,
            fallback_fonts=params.fallback_fonts,
        )
        return image_response(result.getvalue())

    @app.get("/memes/keys")
    def _():
        return get_meme_keys()

    @app.get("/memes/{key}/info")
    def _(key: str):
        try:
            meme = get_meme(key)
        except NoSuchMeme as e:
            raise HTTPException(status_code=e.status_code, detail=str(e)) from e

        args_model = (
            meme.params_type.args_type.model
            if meme.params_type.args_type
            else MemeArgsModel
        )
        # Shallow copy so the pop below never touches the schema's own dict.
        properties: dict[str, dict[str, Any]] = (
            args_model.schema().get("properties", {}).copy()
        )
        # user_infos is not reported as a meme argument. The None default
        # keeps this from raising KeyError if a schema lacks the property.
        properties.pop("user_infos", None)
        return MemeInfoResponse(
            key=meme.key,
            keywords=meme.keywords,
            patterns=meme.patterns,
            params=MemeParamsResponse(
                min_images=meme.params_type.min_images,
                max_images=meme.params_type.max_images,
                min_texts=meme.params_type.min_texts,
                max_texts=meme.params_type.max_texts,
                default_texts=meme.params_type.default_texts,
                args=[
                    MemeArgsResponse(
                        name=name,
                        type=info.get("type", ""),
                        description=info.get("description"),
                        default=info.get("default"),
                        enum=info.get("enum"),
                    )
                    for name, info in properties.items()
                ],
            ),
        )

    @app.get("/memes/{key}/preview")
    async def _(key: str):
        try:
            meme = get_meme(key)
            result = await meme.generate_preview()
        except MemeGeneratorException as e:
            raise HTTPException(status_code=e.status_code, detail=str(e)) from e
        return image_response(result.getvalue())

    @app.post("/memes/{key}/parse_args")
    async def _(key: str, args: list[str] = []):
        try:
            meme = get_meme(key)
            return meme.parse_args(args)
        except MemeGeneratorException as e:
            raise HTTPException(status_code=e.status_code, detail=str(e)) from e

    # Register per-meme generation routes in a stable, key-sorted order.
    for meme in sorted(get_memes(), key=lambda meme: meme.key):
        register_router(meme)
|
|
|
|
|
def run_server():
    """Register every HTTP route on ``app`` and serve it with uvicorn.

    Host and port are read from ``meme_config.server``; uvicorn logging is
    configured with ``LOGGING_CONFIG``.
    """
    import uvicorn

    register_routers()

    server_settings = meme_config.server
    uvicorn.run(
        app,
        host=server_settings.host,
        port=server_settings.port,
        log_config=LOGGING_CONFIG,
    )
|
|
|
|
|
# Script entry point: set up logging first, then register routes and start
# the uvicorn server.
if __name__ == "__main__":
    setup_logger()
    run_server()
|
|