import math
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from omegaconf import OmegaConf
import numpy as np
import torch
from torch import nn
from PIL import Image, ImageDraw, ImageFont

import models

# Key prefix under which the checkpoint's 'state_dict' stores generator weights.
GENERATOR_PREFIX = "networks.g."
# Max value of an 8-bit channel; used to normalise pixel values to [0, 1].
WHITE = 255
# Characters rendered from the style font as the few-shot style references.
EXAMPLE_CHARACTERS = ['A', 'B', 'C', 'D', 'E']


class InferenceServicer:
    """Generate a content character in the style of a given font file.

    The content glyph is looked up as a pre-rendered PNG (keyed by its hex code
    point); ``EXAMPLE_CHARACTERS`` are rendered from the style font and both are
    fed to ``models.Generator``.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        """Build the generator, load its weights and index the content images.

        Args:
            hp: Config object; ``hp.models.G`` is passed to ``models.Generator``.
            checkpoint_path: Checkpoint file whose ``'state_dict'`` entry holds the
                generator weights under ``GENERATOR_PREFIX``.
            content_image_dir: Directory searched recursively for ``*.png`` files;
                each file's stem must be the hex code point produced by
                ``get_hex_from_char``.
            imsize: Side length of the square images fed to the model.
            gpu_id: CUDA device index; ``None`` means ``cuda:0``. The CPU is used
                when CUDA is unavailable.
        """
        self.hp = hp
        self.imsize = imsize
        self.device = self._select_device(gpu_id)

        model_config = self.hp.models.G
        self.model: nn.Module = models.Generator(model_config)

        # Load Generator model weight.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        model_state_dict_pl = torch.load(checkpoint_path, map_location='cpu')
        generator_state_dict = self.convert_generator_state_dict(model_state_dict_pl)
        self.model.load_state_dict(generator_state_dict)
        self.model.to(device=self.device)
        self.model.eval()

        # Setting Content font files
        self.content_character_dict = self.load_content_character_dict(Path(content_image_dir))

    @staticmethod
    def _select_device(gpu_id: Optional[str]) -> torch.device:
        """Return ``cuda:<gpu_id>`` (``cuda:0`` for ``None``), or the CPU without CUDA."""
        if not torch.cuda.is_available():
            if gpu_id is not None:
                warnings.warn(f"CUDA is not available; ignoring gpu_id={gpu_id!r} and running on CPU.")
            return torch.device('cpu')
        index = 0 if gpu_id is None else gpu_id
        return torch.device(f'cuda:{index}')

    @staticmethod
    def convert_generator_state_dict(model_state_dict_pl):
        """Extract the generator weights from a checkpoint dict.

        Keeps only entries of ``model_state_dict_pl['state_dict']`` whose key starts
        with ``GENERATOR_PREFIX`` and strips that prefix, so the result can be
        loaded directly into the generator. (The ``_pl`` naming suggests a
        PyTorch Lightning checkpoint -- TODO confirm.)
        """
        prefix_len = len(GENERATOR_PREFIX)
        return {
            module_name[prefix_len:]: module_state
            for module_name, module_state in model_state_dict_pl['state_dict'].items()
            if module_name.startswith(GENERATOR_PREFIX)
        }

    @staticmethod
    def load_content_character_dict(content_image_dir: Path) -> Dict[str, Path]:
        """Map every ``*.png`` stem under ``content_image_dir`` (recursively) to its path.

        Paths are visited in sorted order so that, when two files share a stem,
        the one that wins is deterministic (the last in sorted order).
        """
        return {filepath.stem: filepath for filepath in sorted(content_image_dir.glob("**/*.png"))}

    @staticmethod
    def center_align(bg_img: Image.Image, item_img: Image.Image, fit=False) -> Image.Image:
        """Paste ``item_img`` centred onto a copy of ``bg_img``.

        Args:
            bg_img: Background image; not modified.
            item_img: Image to paste; not modified.
            fit: If True, first scale ``item_img`` (keeping its aspect ratio) so it
                fits entirely inside ``bg_img``.

        Returns:
            A new image with ``bg_img``'s size and mode.
        """
        bg_img = bg_img.copy()
        item_img = item_img.copy()
        item_w, item_h = item_img.size
        W, H = bg_img.size
        if fit:
            item_ratio = item_w / item_h
            bg_ratio = W / H
            # Scale by the tighter dimension so the whole item fits.
            if bg_ratio > item_ratio:  # height fitting
                resize_ratio = H / item_h
            else:  # width fitting
                resize_ratio = W / item_w
            # Clamp to 1px: a very thin item would otherwise truncate to a
            # zero-sized dimension, which PIL refuses to resize to.
            new_size = (max(1, int(item_w * resize_ratio)), max(1, int(item_h * resize_ratio)))
            item_img = item_img.resize(new_size)
            item_w, item_h = item_img.size
        # NOTE(review): paste() without a mask converts item_img to bg_img's mode and
        # drops any alpha channel; if the PNGs have transparent backgrounds those
        # pixels will not become white. Confirm the input image modes.
        bg_img.paste(item_img, ((W - item_w) // 2, (H - item_h) // 2))
        return bg_img

    def set_image(self, image: Union[str, Path, Image.Image]) -> Image.Image:
        """Return ``image`` fitted and centred on a white ``imsize`` x ``imsize`` RGB canvas.

        ``image`` may be a PIL image or a path to an image file.
        """
        if isinstance(image, (str, Path)):
            # Load eagerly inside a context manager so the file handle is released.
            with Image.open(image) as opened:
                image = opened.copy()
        assert isinstance(image, Image.Image)
        bg_img = Image.new('RGB', (self.imsize, self.imsize), color='white')
        return self.center_align(bg_img, image, fit=True)

    @staticmethod
    def pil_image_to_array(blend_img: Image.Image) -> np.ndarray:
        """Average the channels of a multi-channel image and scale to [0, 1].

        Expects a multi-channel (e.g. RGB) image: the last array axis is averaged,
        giving an H x W float32 array.
        """
        normalized_array = np.mean(np.array(blend_img, dtype=np.float32), axis=-1) / WHITE  # L-only image normalized to [0, 1]
        return normalized_array

    def get_images_from_fontfile(self, font_file_path: Path, imgmode: str = 'RGB', position: tuple = (0, 0),
                                 font_size: int = 128, padding: int = 100) -> List[Image.Image]:
        """Render each of ``EXAMPLE_CHARACTERS`` in black on white using the given font.

        Args:
            font_file_path: Font file loadable by ``ImageFont.truetype``.
            imgmode: PIL mode of the returned images.
            position: Anchor point passed to ``ImageDraw.text``.
            font_size: Size passed to ``ImageFont.truetype``.
            padding: Extra pixels added to the text's width and height for the canvas.

        Returns:
            One image per example character, in ``EXAMPLE_CHARACTERS`` order.
        """
        imagefont = ImageFont.truetype(str(font_file_path), size=font_size)
        font_images: List[Image.Image] = []
        for character in EXAMPLE_CHARACTERS:
            # ImageDraw.textsize was removed in Pillow 10. The right/bottom edges of
            # getbbox (text drawn from the origin) give the extent textsize reported.
            # ceil() keeps the canvas size integral even if getbbox returns floats.
            _, _, right, bottom = imagefont.getbbox(character)
            text_w, text_h = math.ceil(right), math.ceil(bottom)
            img = Image.new(imgmode, (text_w + padding, text_h + padding), color='white')
            ImageDraw.Draw(img).text(position, text=character, font=imagefont, fill='black')
            font_images.append(img)
        return font_images

    @staticmethod
    def get_hex_from_char(char: str) -> str:
        """Return ``char``'s code point as uppercase hex, zero-padded to at least 4 digits."""
        assert len(char) == 1
        return f"{ord(char):04X}"

    @torch.no_grad()
    def inference(self, content_char: str,
                  style_font: Union[str, Path]) -> Tuple[Image.Image, List[Image.Image], Image.Image]:
        """Generate ``content_char`` in the style of ``style_font``.

        Args:
            content_char: Character to generate; only its first character is used.
            style_font: Path to the style font file.

        Returns:
            ``(content_image, style_images, generated_image)``: the preprocessed
            content glyph, the preprocessed style reference glyphs, and the
            generator output as an 8-bit grayscale image.

        Raises:
            ValueError: If ``content_char`` is empty or has no content image.
        """
        if not content_char:
            raise ValueError("content_char must be a non-empty string")
        content_char = content_char[:1]  # only get the first character if the length > 1
        char_hex = self.get_hex_from_char(content_char)
        if char_hex not in self.content_character_dict:
            raise ValueError(f"The character {content_char} (hex: {char_hex}) is not supported in this model!")

        content_image = self.set_image(self.content_character_dict[char_hex])
        style_images: List[Image.Image] = [
            self.set_image(image) for image in self.get_images_from_fontfile(Path(style_font))
        ]

        # 1 x C(=1) x H x W
        content_image_array = self.pil_image_to_array(content_image)[np.newaxis, np.newaxis, ...]
        # 1 x C(=number of style shots) x H x W
        style_images_array: np.ndarray = np.array(
            [self.pil_image_to_array(image) for image in style_images]
        )[np.newaxis, ...]

        content_input_tensor = torch.from_numpy(content_image_array).to(self.device)
        style_input_tensor = torch.from_numpy(style_images_array).to(self.device)

        generated_images: torch.Tensor = self.model((content_input_tensor, style_input_tensor))
        generated_images = torch.clip(generated_images, 0, 1)
        assert generated_images.size(0) == 1
        # NOTE(review): assumes the generator returns 1 x C x H x W; the first channel is kept.
        generated_image_numpy = (generated_images[0].cpu().numpy() * 255).astype(np.uint8)[0, ...]  # H x W
        # A 2-D uint8 array is inferred as mode 'L'; passing ``mode`` explicitly is
        # deprecated since Pillow 11.3.
        return content_image, style_images, Image.fromarray(generated_image_numpy)


if __name__ == '__main__':
    hp = OmegaConf.load("config/models/google-font.yaml")
    checkpoint_path = "epoch=199-step=257400.ckpt"
    content_image_dir = "../DATA/NotoSans"
    servicer = InferenceServicer(hp, checkpoint_path, content_image_dir)

    style_font = "example_fonts/MaShanZheng-Regular.ttf"
    content_image, style_images, result = servicer.inference("7", style_font)
    content_image.save("result_content.png")
    for idx, style_image in enumerate(style_images):
        style_image.save(f"result_style_{idx:02d}.png")
    result.save("result_generated.png")