# Source: Hugging Face repo "rome", file app.py (uploaded by Pie31415)
# Commit f16a6ab ("updated app"), 2.16 kB
import argparse
import importlib
import os
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
# ---- Bootstrap the ROME sources ----------------------------------------------
# Check out git submodules (presumably the ROME code under ./rome) before
# importing from it. This is best-effort: the exit status of `git` is not
# checked (e.g. when the app is not running from a git checkout), and if the
# `git` executable itself is missing, subprocess.run raises FileNotFoundError.
# That error used to abort the app even when ./rome was already present.
# A genuinely missing checkout still surfaces as an ImportError on the
# imports below.
try:
    subprocess.run(["git", "submodule", "update", "--init", "--recursive"])
except FileNotFoundError:
    print("git executable not found; skipping `git submodule update`", file=sys.stderr)
# Debug output: the working directory and its contents, to help diagnose
# import failures on a hosted deployment.
print(os.getcwd())
print(os.listdir('.'))
# Make modules that live directly inside ./rome (e.g. `infer`, `src.rome`,
# imported further down) importable. The path is relative to the current
# working directory, not to this file.
sys.path.append("./rome")
from rome.src.utils import args as args_utils
from rome.src.utils.processing import process_black_shape, tensor2image
# ---- Model weights -------------------------------------------------------------
# Both checkpoints live in the Hugging Face model repo 'Pie31415/rome'.
# hf_hub_download fetches each file into the local Hugging Face cache (or reuses
# the cached copy) and returns its *local file path*. The previous code used
# hf_hub_url, which only builds a remote URL string. These values become the
# defaults of --modnet_path / --model_checkpoint, which are consumed downstream
# as checkpoint locations and are presumably opened as local files (TODO confirm
# in Infer).
from huggingface_hub import hf_hub_download, hf_hub_url  # hf_hub_url kept for compatibility
default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
default_model_path = hf_hub_download('Pie31415/rome', 'models/rome.pth')
# ---- Command-line configuration ------------------------------------------------
# App-level options are declared first. ROME then registers its own options on
# the same parser (ROME.add_argparse_args), and the combined parser is parsed
# strictly. conflict_handler='resolve' lets a later add_argument with the same
# option string replace the earlier one instead of raising.
parser = argparse.ArgumentParser(conflict_handler='resolve')
parser.add_argument('--save_dir', default='.', type=str)
# String defaults are passed through `type` by argparse, so 'True'/'False'
# become booleans via str2bool before `choices` is checked.
parser.add_argument('--save_render', default='True', type=args_utils.str2bool, choices=[True, False])
parser.add_argument('--model_checkpoint', default=default_model_path, type=str)
parser.add_argument('--modnet_path', default=default_modnet_path, type=str)
parser.add_argument('--random_seed', default=0, type=int)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--verbose', default='False', type=args_utils.str2bool, choices=[True, False])
# Tolerant pre-parse: flags that belong to ROME are not registered yet, so
# unknown arguments are ignored here. The result is replaced by the strict
# parse below.
args, _ = parser.parse_known_args()
# Requires the module-level `import importlib`. 'src.rome' is presumably
# resolved through the './rome' entry appended to sys.path.
parser = importlib.import_module('src.rome').ROME.add_argparse_args(parser)
args = parser.parse_args()
# Relative location of the DECA assets; presumably resolved against the
# working directory — TODO confirm.
args.deca_path = 'DECA'
# ---- Model construction ----------------------------------------------------------
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# `infer` is presumably the module found via the './rome' entry on sys.path.
from infer import Infer

# Build the inference wrapper from the parsed options and move it to `device`.
infer = Infer(args).to(device)
def predict(source_img, driver_img):
    """Run ROME inference on a source/driver image pair and return one panel.

    Args:
        source_img: the source image, as delivered by the first Gradio input
            (a PIL image, since that input uses type="pil").
        driver_img: the driver image, from the second Gradio input.

    Returns:
        The output of ``tensor2image`` on four tensors concatenated along
        dim 2 (source, target, masked render, predicted target shape), with the
        last axis reversed. The tensors are presumably (C, H, W), which would
        place the panels side by side. The axis reversal presumably turns BGR
        into RGB for display — TODO confirm tensor2image's channel order.
    """
    result = infer.evaluate(source_img, driver_img, crop_center=False)
    data_dict = result['source_information']['data_dict']

    # Moved to the CPU in display order; index [0] takes the first batch
    # element. 'render_masked' is used as-is, without indexing.
    panels = [
        data_dict['source_img'][0].cpu(),
        data_dict['target_img'][0].cpu(),
        result['render_masked'].cpu(),
        result['pred_target_shape_img'][0].cpu(),
    ]
    image = tensor2image(torch.cat(panels, dim=2))
    return image[..., ::-1]
import gradio as gr

# Web UI: two PIL image inputs, passed positionally to predict(source_img,
# driver_img), and one image output showing the returned panel.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    outputs=gr.Image(),
    examples=[],
)
demo.launch()