"""Gradio demo for multilingual font style transfer.

On import this module optionally downloads the model checkpoint (when the
``MODEL_CHECKPOINT_PATH`` environment variable is set), extracts the Noto Sans
content-image archive if it is not extracted yet, verifies both assets exist,
and collects the example fonts offered in the UI. Running it as a script
parses CLI arguments and launches the Gradio app.
"""
import os
import argparse
from pathlib import Path
from typing import Optional, Union, Tuple, List
import subprocess
from itertools import chain

import gradio as gr
from PIL import Image
from omegaconf import OmegaConf, DictConfig

from inference import InferenceServicer

PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")
MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)

LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"


def _run(command: List[str]) -> None:
    """Run an external command without a shell, raising if it exits non-zero.

    Raises:
        subprocess.CalledProcessError: if the command returns a non-zero status.
    """
    # An argument list (no shell=True) keeps URLs/paths taken from the
    # environment from being interpreted by the shell.
    subprocess.run(command, check=True)


def _prepare_assets() -> None:
    """Download the checkpoint / extract Noto Sans when needed, then verify both exist.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``unzip`` fails.
        FileNotFoundError: if the checkpoint file or the extracted font directory is missing.
    """
    checkpoint = Path(LOCAL_CHECKPOINT_PATH)
    noto_zip = Path(LOCAL_NOTO_ZIP_PATH)
    # "data/NotoSans" -- the same directory __main__ passes to the model.
    noto_dir = noto_zip.with_suffix("")

    if MODEL_CHECKPOINT_PATH is not None:
        # `wget -O` does not create missing parent directories.
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): certificate verification stays disabled as before; consider removing the flag.
        _run(["wget", "--no-check-certificate", "-O", str(checkpoint), MODEL_CHECKPOINT_PATH])

    # Extract only when needed: re-running unzip over existing files makes it
    # prompt for overwrite confirmation.
    if not noto_dir.exists() and noto_zip.exists():
        _run(["unzip", str(noto_zip), "-d", str(noto_zip.parent)])

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for required, what in ((checkpoint, "Model checkpoint"), (noto_dir, "Extracted Noto Sans directory")):
        if not required.exists():
            raise FileNotFoundError(f"{what} not found: {required}")


_prepare_assets()

EXAMPLE_FONT_DIR = Path("example_fonts")
EXAMPLE_FONTS = sorted(
    str(font) for font in chain(EXAMPLE_FONT_DIR.glob("*.ttf"), EXAMPLE_FONT_DIR.glob("*.otf"))
)


def parse_args():
    """Parse known CLI arguments (unknown ones are ignored) and return the namespace."""
    parser = argparse.ArgumentParser(description="Multilingual font style transfer demo")

    # -------- User arguments ----------------------------------------

    parser.add_argument(
        '--docs', type=Path, default=PATH_DOCS,
        help="Docs string file")

    parser.add_argument(
        '--config', type=Path, default=MODEL_CONFIG,
        help="Config for model")

    parser.add_argument(
        '--local', action='store_true',
        help="Whether to run in local environment or not")

    parser.add_argument(
        '--port', type=int, default=50003,
        help="Service port (only applicable when running on local server)")

    # parse_known_args so extra arguments injected by a hosting launcher do not abort startup.
    args, _ = parser.parse_known_args()
    return args


class InferenceServiceResolver(InferenceServicer):
    """Adapter exposing ``InferenceServicer.inference`` as a Gradio click handler."""

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)

    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
        """Run inference for one character and one style font.

        Args:
            content_char: the character to render in the target style.
            style_font: path to the example font file selected in the UI.

        Returns:
            ``[content_image, *style_images, result]``, in the same order as the
            ``outputs`` components wired up in ``launch_gradio`` (presumably five
            style images, one per UI slot -- confirm against ``inference``).

        Raises:
            gr.Error: wrapping any failure from ``inference`` so Gradio shows it to the user.
        """
        try:
            content_image, style_images, result = self.inference(content_char=content_char, style_font=style_font)
        except Exception as e:
            # Chain the cause so the original traceback is kept in server logs.
            raise gr.Error(str(e)) from e
        return [content_image, *style_images, result]


def _section_header(title: str):
    """Render a section title above an image group and return the Markdown component."""
    # NOTE(review): these headers were written as multi-line f-strings (a SyntaxError);
    # the visible text is reproduced -- confirm whether additional markup was intended.
    return gr.Markdown(f"\n\n{title}\n\n")


def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path,
                  is_local: bool, port: Optional[int] = None):
    """Build the demo UI and launch it.

    Args:
        docs_path: Markdown file rendered at the top of the page.
        hp: model configuration passed to the inference servicer.
        checkpoint_path: model checkpoint file.
        content_image_dir: directory of content-character assets (extracted Noto Sans).
        is_local: bind to 0.0.0.0 on ``port`` instead of Gradio's default launch.
        port: server port, used only when ``is_local`` is true.
    """
    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)

    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
        # Explicit UTF-8: the docs may contain non-ASCII examples, and the platform default may not be UTF-8.
        gr.Markdown(docs_path.read_text(encoding="utf-8"))

        with gr.Row(equal_height=True):
            character_input = gr.Textbox(max_lines=1, value="7",
                                         info="Only single character is acceptable (e.g. '간', '7', or 'ជ')")
            # Avoid IndexError when example_fonts/ holds no fonts.
            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS,
                                     value=EXAMPLE_FONTS[0] if EXAMPLE_FONTS else None)
            run_button = gr.Button(value="Generate", variant='primary')

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    _section_header("Content character")
                    content_char = gr.Image(label="Content character", show_label=False)
            with gr.Column(scale=5):
                with gr.Group():
                    _section_header("Style font images")
                    with gr.Row(equal_height=True):
                        style_chars = [gr.Image(label=f"Style #{i}", show_label=False) for i in range(1, 6)]
            with gr.Column(scale=1):
                with gr.Group():
                    _section_header("Generated font image")
                    generated_font = gr.Image(label="Generated font image", show_label=False)

        # Order must match the list returned by InferenceServiceResolver.generate.
        outputs = [content_char, *style_chars, generated_font]
        run_inputs = [character_input, style_font]
        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)

    if is_local:
        demo.launch(server_name="0.0.0.0", server_port=port)
    else:
        demo.launch()


if __name__ == "__main__":
    args = parse_args()
    hp = OmegaConf.load(args.config)
    checkpoint_path = Path(LOCAL_CHECKPOINT_PATH)
    content_image_dir = Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")
    launch_gradio(args.docs, hp, checkpoint_path, content_image_dir, args.local, args.port)