deepkyu's picture
get NotoSans file from local
2551121
raw
history blame
No virus
4.97 kB
import os
import argparse
from pathlib import Path
from typing import Optional, Union, Tuple, List
import subprocess
from itertools import chain
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf, DictConfig
from inference import InferenceServicer
PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")
MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)
LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"

# --- Asset preparation (runs at import time) ---------------------------------
# When MODEL_CHECKPOINT_PATH is set, download it to LOCAL_CHECKPOINT_PATH.
# The command is an argument list (no shell), so the URL read from the
# environment is never interpreted by a shell. check=True makes a failed
# download raise here; otherwise `wget -O` can leave an empty file behind that
# the existence check below would accept.
if MODEL_CHECKPOINT_PATH is not None:
    Path(LOCAL_CHECKPOINT_PATH).parent.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", LOCAL_CHECKPOINT_PATH, MODEL_CHECKPOINT_PATH],
        check=True,
    )

# Extract the bundled NotoSans archive into its own parent directory
# (data/NotoSans.zip -> data/). `-o` overwrites without prompting, so a restart
# with an already-extracted directory does not stop at unzip's confirmation
# prompt. The archive is skipped when absent, so a pre-extracted directory
# still works. unzip's exit status is not checked because it returns non-zero
# for mere warnings; the directory check below is the real success criterion.
_noto_zip = Path(LOCAL_NOTO_ZIP_PATH)
if _noto_zip.exists():
    subprocess.run(["unzip", "-o", str(_noto_zip), "-d", str(_noto_zip.parent)])

# Fail fast with a descriptive error. Explicit raises (instead of `assert`)
# keep these checks active under `python -O`.
if not Path(LOCAL_CHECKPOINT_PATH).exists():
    raise FileNotFoundError(f"Model checkpoint not found: {LOCAL_CHECKPOINT_PATH}")
if not _noto_zip.with_suffix("").exists():
    raise FileNotFoundError(f"NotoSans content images not found: {_noto_zip.with_suffix('')}")

# Example style fonts offered in the UI dropdown: every .ttf/.otf file directly
# under example_fonts/, sorted by path.
EXAMPLE_FONTS = sorted([str(x) for x in chain(Path("example_fonts").glob("*.ttf"), Path("example_fonts").glob("*.otf"))])
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the font style transfer demo.

    Unknown arguments are ignored rather than rejected, because
    ``parse_known_args`` is used and its leftovers are discarded.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, matching the previous zero-argument behaviour.

    Returns:
        Namespace with ``docs`` (Path), ``config`` (Path), ``local`` (bool)
        and ``port`` (int). argparse converts the string defaults for
        ``docs``/``config`` through ``type=Path``.
    """
    parser = argparse.ArgumentParser(description="Multilingual Font Style Transfer demo")
    # -------- User arguments ----------------------------------------
    parser.add_argument(
        '--docs', type=Path, default=PATH_DOCS,
        help="Docs string file")
    parser.add_argument(
        '--config', type=Path, default=MODEL_CONFIG,
        help="Config for model")
    parser.add_argument(
        '--local', action='store_true',
        help="Whether to run in local environment or not")
    parser.add_argument(
        '--port', type=int, default=50003,
        help="Service port (only applicable when running on local server)")

    args, _ = parser.parse_known_args(argv)
    return args
class InferenceServiceResolver(InferenceServicer):
    """Adapter exposing ``InferenceServicer.inference`` as a Gradio callback.

    ``generate`` flattens the inference output into one flat list of images.
    It converts any failure into ``gr.Error`` so the message is reported
    through Gradio.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        # Pure pass-through to the base class; kept to make the accepted
        # constructor arguments explicit at this call site.
        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)

    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
        """Run style transfer for one character and return images for display.

        Args:
            content_char: Character to render (the UI asks for a single character).
            style_font: Path of the font file whose style is to be transferred.

        Returns:
            ``[content_image, *style_images, result]``, built from the three
            values returned by ``self.inference``.

        Raises:
            gr.Error: Wraps any exception raised while running or unpacking
                the inference result, with the original exception chained as
                its cause.
        """
        try:
            content_image, style_images, result = self.inference(content_char=content_char, style_font=style_font)
            return [content_image, *style_images, result]
        except Exception as e:
            # Some exceptions stringify to "" (e.g. a bare AssertionError);
            # fall back to the type name so the user never sees a blank error.
            message = str(e) or type(e).__name__
            raise gr.Error(message) from e
def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path, is_local: bool, port: Optional[int] = None):
    """Build the font style transfer Gradio UI and start serving it.

    Args:
        docs_path: Markdown file rendered at the top of the page.
        hp: Model configuration forwarded to ``InferenceServiceResolver``.
        checkpoint_path: Model checkpoint forwarded to ``InferenceServiceResolver``.
        content_image_dir: Content image directory forwarded to
            ``InferenceServiceResolver``.
        is_local: If True, bind to 0.0.0.0 on ``port``; otherwise call
            ``demo.launch()`` with Gradio's defaults.
        port: Server port; only used when ``is_local`` is True.
    """
    # Number of style image slots; must match the number of style images
    # returned by InferenceServiceResolver.generate.
    num_style_images = 5

    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)
    # Avoid an IndexError at startup when example_fonts/ contains no fonts.
    default_font = EXAMPLE_FONTS[0] if EXAMPLE_FONTS else None

    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
        # Explicit UTF-8 so rendering does not depend on the host locale.
        gr.Markdown(docs_path.read_text(encoding="utf-8"))

        with gr.Row(equal_height=True):
            character_input = gr.Textbox(max_lines=1, value="7", info="Only single character is acceptable (e.g. '간', '7', or 'ជ')")
            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=default_font)
            run_button = gr.Button(value="Generate", variant='primary')

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Content character</h3></center>")
                    content_char = gr.Image(label="Content character", show_label=False)
            with gr.Column(scale=5):
                with gr.Group():
                    gr.Markdown("<center><h3>Style font images</h3></center>")
                    with gr.Row(equal_height=True):
                        style_chars = [
                            gr.Image(label=f"Style #{idx}", show_label=False)
                            for idx in range(1, num_style_images + 1)
                        ]
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Generated font image</h3></center>")
                    generated_font = gr.Image(label="Generated font image", show_label=False)

        # Output order must mirror generate(): content, styles..., result.
        outputs = [content_char, *style_chars, generated_font]
        run_inputs = [character_input, style_font]
        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)

    if is_local:
        demo.launch(server_name="0.0.0.0", server_port=port)
    else:
        demo.launch()
if __name__ == "__main__":
    # Local asset locations; both are prepared and verified at import time.
    # "data/NotoSans.zip" -> "data/NotoSans", the directory the archive yields.
    ckpt_file = Path(LOCAL_CHECKPOINT_PATH)
    noto_dir = Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")

    # CLI options decide the docs page, model config and serving mode.
    cli = parse_args()
    launch_gradio(
        cli.docs,
        OmegaConf.load(cli.config),
        ckpt_file,
        noto_dir,
        cli.local,
        cli.port,
    )