#!/usr/bin/env python
# Gradio demo that runs the CelebAMask-HQ face-parsing U-Net on an uploaded
# image and shows the predicted label map plus a 50/50 overlay on the input.

from __future__ import annotations

import pathlib
import sys

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download

# `unet` and `utils` are loaded from the CelebAMask-HQ checkout, so that
# directory has to be on sys.path *before* the two imports below.
sys.path.insert(0, "CelebAMask-HQ/face_parsing")

from unet import unet
from utils import generate_label

TITLE = "CelebAMask-HQ Face Parsing"
DESCRIPTION = "This is an unofficial demo for the model provided in https://github.com/switchablenorms/CelebAMask-HQ."

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Preprocessing: resize to 512x512, convert to a tensor, then shift/scale each
# of the three channels with mean 0.5 and std 0.5.
transform = T.Compose(
    [
        T.Resize((512, 512), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Weights are fetched from the Hugging Face Hub and loaded on CPU first; the
# model is moved to `device` afterwards.
path = hf_hub_download("public-data/CelebAMask-HQ-Face-Parsing", "models/model.pth")
state_dict = torch.load(path, map_location="cpu")
model = unet()
model.load_state_dict(state_dict)
model.eval()
model.to(device)


@spaces.GPU
@torch.inference_mode()
def predict(image: PIL.Image.Image) -> tuple[np.ndarray, np.ndarray]:
    """Run face parsing on *image* and build the two demo outputs.

    Args:
        image: The uploaded picture. Any PIL mode is accepted; it is converted
            to RGB because the preprocessing normalizes exactly three channels.

    Returns:
        A ``(labels, overlay)`` pair of ``uint8`` arrays: the label image
        produced by ``generate_label`` scaled to 0-255, and an equal-weight
        blend of that label image with the input resized to 512x512.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading anything.
        raise gr.Error("Please provide an input image.")

    # RGBA / grayscale / palette uploads would otherwise fail in the
    # three-channel Normalize step and in the blend with the label image.
    image = image.convert("RGB")

    data = transform(image)
    data = data.unsqueeze(0).to(device)  # add the batch dimension
    out = model(data)
    out = generate_label(out, 512)
    # NOTE(review): assumes generate_label returns a channel-first batch with
    # values in [0, 1]; the transpose moves channels last for display.
    out = out[0].cpu().numpy().transpose(1, 2, 0)
    out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)

    # Blend the resized input and the label image with equal weights.
    resized = np.asarray(image.resize((512, 512))).astype(float)
    res = resized * 0.5 + out.astype(float) * 0.5
    res = np.clip(np.round(res), 0, 255).astype(np.uint8)
    return out, res


examples = sorted(pathlib.Path("images").glob("*.jpg"))

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input", type="pil"),
    outputs=[
        gr.Image(label="Predicted Labels"),
        gr.Image(label="Masked"),
    ],
    examples=examples,
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()