# Spaces: Runtime error
import os | |
import gradio as gr | |
import plotly.graph_objects as go | |
import torch | |
import json | |
import glob | |
import numpy as np | |
from PIL import Image | |
import time | |
import tqdm | |
import copy | |
import sys | |
# Mesh imports | |
from pytorch3d.io import load_objs_as_meshes | |
from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene | |
from pytorch3d.transforms import Transform3d, RotateAxisAngle, Translate, Rotate | |
from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def transform_mesh(mesh, transform, scale=1.0):
    """Return a transformed copy of a PyTorch3D mesh.

    The input mesh is cloned, so it is left untouched. Each vertex is first
    multiplied by ``scale`` and then mapped through
    ``transform.transform_points``.

    Args:
        mesh: Mesh to transform (cloned, not modified).
        transform: Object exposing ``transform_points`` (e.g. a ``Transform3d``).
        scale: Uniform factor applied to the vertices before ``transform``.

    Returns:
        The transformed clone.
    """
    moved = mesh.clone()
    old_verts = moved.verts_packed()
    new_verts = transform.transform_points(old_verts * scale)
    # offset_verts_ takes per-vertex deltas, not absolute positions.
    moved.offset_verts_(new_verts - old_verts)
    return moved
def get_input_pose_fig():
    """Build the Plotly figure previewing the object, a ground plane and the camera.

    Reads the module-level ``obj_filename``, ``plane_trans`` and
    ``curr_camera_dict`` globals. They are only read, so no ``global``
    statement is needed.

    Returns:
        plotly.graph_objects.Figure: A 512x512 figure containing the object
        mesh (scaled by 0.75) and a semi-transparent gray plane. The layout
        entries of ``curr_camera_dict`` (camera / aspect settings) are applied
        to the figure.
    """
    plane_filename = 'assets/plane.obj'
    mesh_scale = 0.75

    mesh = load_objs_as_meshes([obj_filename], device=device)
    mesh.scale_verts_(mesh_scale)

    # Ground plane: rotate 90 degrees about X, then shift along Y by the
    # per-category offset (scaled like the object mesh).
    plane = load_objs_as_meshes([plane_filename], device=device)
    rotate_x = RotateAxisAngle(angle=90.0, axis='X', device=device)
    plane = transform_mesh(plane, rotate_x)
    translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)
    plane = transform_mesh(plane, translate_y)

    fig = plot_scene(
        {"plot": {"object": mesh}},
        axis_args=AxisArgs(showgrid=True, backgroundcolor='#cccde0'),
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        zaxis=dict(range=[-1, 1]),
    )

    # Plotly's data validators expect lists / NumPy arrays, so convert the
    # CPU tensors explicitly instead of handing torch tensors to go.Mesh3d.
    plane = plane.detach().cpu()
    verts = plane.verts_packed().numpy()
    faces = plane.faces_packed().numpy()
    fig.add_trace(
        go.Mesh3d(
            x=verts[:, 0],
            y=verts[:, 1],
            z=verts[:, 2],
            i=faces[:, 0],
            j=faces[:, 1],
            k=faces[:, 2],
            opacity=0.7,
            color='gray',
            hoverinfo='skip',
        ),
    )

    print("fig: curr camera dict")
    print(curr_camera_dict)

    # Visible, tick-labelled axes with a black grid on a light-gray background.
    # update_layout merges recursively, so one call replaces the two former
    # scene updates.
    axis_style = dict(
        showticklabels=True,
        visible=True,
        showgrid=True,
        gridwidth=1,
        gridcolor='black',
    )
    fig.update_layout(scene=dict(
        xaxis=axis_style,
        yaxis=axis_style,
        zaxis=axis_style,
        bgcolor='#dedede',
    ))
    # curr_camera_dict holds dotted layout paths such as "scene.camera".
    fig.update_layout(curr_camera_dict, width=512, height=512)
    return fig
def run_inference(cam_pose_json, prompt, scale_im, scale, steps, seed):
    """Generate one image for the given camera pose and prompt.

    Uses the module-level ``current_model`` / ``current_data`` loaded by the
    "Load Model" button.

    Args:
        cam_pose_json: Camera pose string from the hidden pose textbox, passed
            unchanged to ``sample`` as ``camera_json``.
        prompt: Text prompt.
        scale_im: Image guidance scale.
        scale: Text guidance scale.
        steps: Number of sampling steps (coerced to ``int``).
        seed: Random seed. The UI supplies it as text, so it is coerced to
            ``int``.

    Returns:
        PIL.Image.Image: The first sampled image as an 8-bit image.

    Raises:
        gr.Error: If no model has been loaded or the seed is not an integer.
    """
    print("prompt is ", prompt)
    if current_model is None:
        raise gr.Error("Please load a model first.")
    try:
        seed = int(seed)
    except (TypeError, ValueError) as err:
        raise gr.Error(f"Seed must be an integer, got {seed!r}.") from err

    images = sample(
        current_model, current_data,
        num_images=1,
        prompt=prompt,
        appendpath="",
        camera_json=cam_pose_json,
        train=False,
        scale=scale,
        scale_im=scale_im,
        beta=1.0,
        num_ref=8,
        skipreflater=False,
        num_steps=int(steps),
        valid=False,
        max_images=20,
        seed=seed,
    )
    result = images[0]
    print(result.shape)
    # (x + 1) / 2 maps [-1, 1] to [0, 1]. Permute CHW -> HWC for PIL, then clip
    # and scale to uint8. .float() keeps .numpy() working for half/bfloat16.
    pixels = ((result.detach().float() + 1.0) / 2.0).permute(1, 2, 0).cpu().numpy()
    result = Image.fromarray((np.clip(pixels, 0., 1.) * 255).astype(np.uint8))
    print('result obtained')
    return result
def update_curr_camera_dict(camera_json):
    """Parse a camera pose from the front end and store it as ``curr_camera_dict``.

    Accepts strict JSON, or the single-quoted Python-literal form produced
    when a dict is rendered with ``str()``. The old blanket ``'`` -> ``"``
    replacement corrupted values that contain apostrophes. When
    ``camera_json`` is ``None`` or blank, a copy of ``prev_camera_dict`` is
    used instead.

    Args:
        camera_json: Pose as a JSON / Python-literal string, or ``None``.

    Returns:
        str: The stored pose re-serialized as JSON.

    Raises:
        ValueError: If the text cannot be parsed or does not describe a dict.
    """
    # TODO: this does not always update the figure, also there's always flashes
    import ast

    global curr_camera_dict

    if camera_json is None or not str(camera_json).strip():
        # JSON round-trip yields an independent copy of the fallback pose.
        parsed = json.loads(json.dumps(prev_camera_dict))
    else:
        try:
            parsed = json.loads(camera_json)
        except json.JSONDecodeError:
            try:
                parsed = ast.literal_eval(camera_json)
            except (ValueError, SyntaxError) as err:
                raise ValueError(f"Could not parse camera pose: {camera_json!r}") from err
    if not isinstance(parsed, dict):
        raise ValueError(f"Camera pose must be a JSON object, got {type(parsed).__name__}")

    curr_camera_dict = parsed
    print("update curr camera dict")
    print(curr_camera_dict)
    return json.dumps(curr_camera_dict)
MODELS_DIR = "pretrained-models/"


def select_and_load_model(category, category_single_id):
    """Load the customized (delta) checkpoint for one object and make it current.

    Finds ``MODELS_DIR/*{category}{category_single_id}*/checkpoints/step=*.ckpt``.
    It then takes the last (sorted) YAML file from that run's ``configs/``
    folder and applies the checkpoint on top of a deep copy of the base model.
    The base model is loaded on the first call and reused afterwards.

    Args:
        category: Object category selected in the UI (e.g. "car").
        category_single_id: Object ID within the category (e.g. "0").

    Returns:
        tuple[str, str]: Status markdown for the UI and the default prompt
        ``"photo of a <new1> {category}"``.

    Raises:
        gr.Error: If no matching checkpoint or config file exists.
    """
    global current_data, current_model, base_model

    # Resolve files first so a missing checkpoint fails fast, before any
    # expensive model work, and leaves the currently loaded model in place.
    ckpt_pattern = os.path.join(
        MODELS_DIR, f"*{category}{category_single_id}*", "checkpoints", "step=*.ckpt"
    )
    # Sorted for a deterministic choice; glob() order is arbitrary.
    ckpt_paths = sorted(glob.glob(ckpt_pattern))
    if not ckpt_paths:
        raise gr.Error(f"No checkpoint found matching {ckpt_pattern}")
    delta_ckpt = ckpt_paths[0]
    print(f"Loading model from {delta_ckpt}")

    # The run directory is the parent of the "checkpoints" folder.
    logdir = os.path.dirname(os.path.dirname(delta_ckpt))
    config_dir = os.path.join(logdir, "configs")
    config_paths = sorted(glob.glob(os.path.join(config_dir, "*.yaml")))
    if not config_paths:
        raise gr.Error(f"No config file found in {config_dir}")
    config = config_paths[-1]

    # The eager base-model load at import time is commented out, so the global
    # may be undefined (or None): load it once here on first use.
    if globals().get("base_model") is None:
        start_time = time.time()
        base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
        print(f"Time taken to load base model: {time.time() - start_time:.2f}s")

    # Drop the reference to the previous customized model before copying the
    # base weights, so the old copy is no longer held during the deepcopy.
    current_model = None
    current_model = copy.deepcopy(base_model)

    start_time = time.time()
    current_model, current_data = load_and_return_model_and_data(
        config, current_model, delta_ckpt=delta_ckpt
    )
    print(f"Time taken to load delta model: {time.time() - start_time:.2f}s")
    print("!!! model loaded")

    input_prompt = f"photo of a <new1> {category}"
    return "### Model loaded!", input_prompt
# Model state shared by the Gradio callbacks; filled in by select_and_load_model.
current_data = None
current_model = None

BASE_CONFIG = "configs/train_co3d_concept.yaml"
BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"
# NOTE: eager loading of the base model is disabled, so `base_model` is NOT
# defined at import time; code that reads it must load it first.
# base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)

# Initial camera pose for the preview (same values as the default "car" category).
curr_camera_dict = {
    "scene.camera": {
        "up": {"x": -0.13227683305740356,
               "y": -0.9911391735076904,
               "z": -0.013464212417602539},
        "center": {"x": -0.005292057991027832,
                   "y": 0.020704858005046844,
                   "z": 0.0873757004737854},
        "eye": {"x": 0.8585731983184814,
                "y": -0.08790968358516693,
                "z": -0.40458938479423523},
    },
    "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
    "scene.aspectmode": "manual"
}
# Fallback pose used when no pose JSON is received.
prev_camera_dict = copy.deepcopy(curr_camera_dict)

# Mesh shown in the preview and the ground-plane offset for the default category.
obj_filename = "assets/car0_mesh_centered_flipped.obj"
plane_trans = 0.16

my_fig = get_input_pose_fig()

# Custom JavaScript passed to the Gradio page.
with open("scripts.js", "r", encoding="utf-8") as _scripts_file:
    scripts = _scripts_file.read()
# Per-category presets: available object IDs, the initial camera pose for the
# preview, and the Y offset of the ground plane.
_CATEGORY_PRESETS = {
    "car": {
        "object_ids": ["0"],
        "camera": {
            "scene.camera": {
                "up": {"x": -0.13227683305740356,
                       "y": -0.9911391735076904,
                       "z": -0.013464212417602539},
                "center": {"x": -0.005292057991027832,
                           "y": 0.020704858005046844,
                           "z": 0.0873757004737854},
                "eye": {"x": 0.8585731983184814,
                        "y": -0.08790968358516693,
                        "z": -0.40458938479423523},
            },
            "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.16,
    },
    "chair": {
        "object_ids": ["191"],
        "camera": {
            "scene.camera": {
                "up": {"x": 1.0477e-04, "y": -9.9995e-01, "z": 1.0288e-02},
                "center": {"x": 0.0539, "y": 0.0015, "z": 0.0007},
                "eye": {"x": 0.0410, "y": -0.0091, "z": -0.9991},
            },
            "scene.aspectratio": {"x": 0.9084, "y": 0.9084, "z": 0.9084},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.38,
    },
    "motorcycle": {
        "object_ids": ["12"],
        "camera": {
            "scene.camera": {
                "up": {"x": 0.0308, "y": -0.9994, "z": -0.0147},
                "center": {"x": 0.0240, "y": -0.0310, "z": -0.0016},
                "eye": {"x": -0.0580, "y": -0.0188, "z": -0.9981},
            },
            "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.16,
    },
    "teddybear": {
        "object_ids": ["31"],
        "camera": {
            "scene.camera": {
                "up": {"x": 0.4304, "y": -0.9023, "z": -0.0221},
                "center": {"x": -0.0658, "y": 0.2081, "z": 0.0175},
                "eye": {"x": -0.4456, "y": 0.0493, "z": -0.8939},
            },
            "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.23,
    },
}


def update_category_single_id(category):
    """Switch the preview state to ``category`` and refresh the object-ID dropdown.

    Sets the module-level ``curr_camera_dict`` and ``prev_camera_dict`` (fresh
    copies of the preset pose), ``plane_trans`` and ``obj_filename``.

    Args:
        category: One of the keys of ``_CATEGORY_PRESETS``.

    Returns:
        gr.Dropdown: Object-ID dropdown whose value is the first ID.

    Raises:
        ValueError: If ``category`` has no preset. Previously this case crashed
            with a TypeError on ``None[0]``.
    """
    global curr_camera_dict
    global prev_camera_dict
    global obj_filename
    global plane_trans

    preset = _CATEGORY_PRESETS.get(category)
    if preset is None:
        raise ValueError(
            f"Unknown category {category!r}; expected one of {sorted(_CATEGORY_PRESETS)}"
        )

    choices = list(preset["object_ids"])
    # Deep copies keep the shared preset table immutable from the callers' side.
    curr_camera_dict = copy.deepcopy(preset["camera"])
    plane_trans = preset["plane_trans"]
    obj_filename = f"assets/{category}{choices[0]}_mesh_centered_flipped.obj"
    prev_camera_dict = copy.deepcopy(curr_camera_dict)
    return gr.Dropdown(choices=choices, label="Object ID", value=choices[0])
# Extra <head> markup: load Plotly.js from the CDN.
head = (
    "\n"
    '<script src="https://cdn.plot.ly/plotly-2.30.0.min.js" charset="utf-8"></script>\n'
)

ORIGINAL_SPACE_ID = 'customdiffusion360'
# NOTE(review): Hugging Face documents SPACE_ID as "<owner>/<space-name>", so an
# ID without an owner prefix may never match it; confirm the intended value.
SPACE_ID = os.environ.get('SPACE_ID')
# Markdown banner shown on the shared Space, with a "Duplicate Space" button.
SHARED_UI_WARNING = (
    "## Attention - the demo requires at least 40GB VRAM for inference. "
    "Please clone this repository to run on your own machine.\n"
    '<center><a class="duplicate-button" style="display:inline-block" '
    'target="_blank" '
    f'href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">'
    '<img style="margin-top:0;margin-bottom:0" '
    'src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" '
    'alt="Duplicate Space"></a></center>\n'
)
with gr.Blocks(head=head,
               css="style.css",
               js=scripts,
               title="Customizing Text-to-Image Diffusion with Camera Viewpoint Control") as demo:
    gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h2><a href='https://customdiffusion360.github.io/index.html'>Customizing Text-to-Image Diffusion with Camera Viewpoint Control</a></h2>
</div>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href='https://customdiffusion360.github.io/index.html' style="padding: 10px;">
<img src='https://img.shields.io/badge/Project%20Page-8A2BE2'>
</a>
<a class="link" href='https://github.com/customdiffusion360/custom-diffusion360' style="padding: 10px;">
<img src='https://img.shields.io/badge/Github-%23121011.svg'>
</a>
</div>
<hr></hr>
""",
            visible=True,
            )
    if SPACE_ID == ORIGINAL_SPACE_ID:
        gr.Markdown(SHARED_UI_WARNING)

    with gr.Row():
        # Column 1: choose and load a customized model.
        with gr.Column(min_width=150):
            gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")
            category = gr.Dropdown(choices=["car", "chair", "motorcycle", "teddybear"], label="Category", value="car")
            category_single_id = gr.Dropdown(label="Object ID", choices=["0"], type="value", value="0", visible=False)
            category.change(update_category_single_id, [category], [category_single_id])
            load_model_btn = gr.Button(value="Load Model", elem_id="load_model_button")
            load_model_status = gr.Markdown(elem_id="load_model_status", value="### Please select and load a model.")

        # Column 2: camera-pose preview and generation settings.
        with gr.Column(min_width=512):
            gr.Markdown("## 2. CAMERA POSE VISUALIZATION")
            # TODO ? don't use gradio plotly element so we can remove menu buttons
            # Named pose_plot (not `map`) to avoid shadowing the builtin.
            pose_plot = gr.Plot(value=my_fig, min_width=512, elem_id="map")

            # Hidden elements (presumably driven by scripts.js through their elem_id).
            update_pose_btn = gr.Button(value="Update Camera Pose", visible=False, elem_id="update_pose_button")
            # Serialize as JSON: Gradio renders a non-string Textbox value with
            # str(), which yields single-quoted Python syntax, not valid JSON.
            input_pose = gr.TextArea(value=json.dumps(curr_camera_dict), label="Input Camera Pose", visible=False, elem_id="input_pose", interactive=False)
            check_pose_btn = gr.Button(value="Check Camera Pose", visible=False, elem_id="check_pose_button")
            ## TODO: track init_camera_dict and with js?

            # Visible generation settings.
            input_prompt = gr.Textbox(value="photo of a <new1> car", label="Prompt", interactive=True)
            scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
            scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
            steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
            seed = gr.Textbox(value=42, label="Seed")

        # NOTE(review): Gradio 4 documents Column `scale` as an int; 0.3 may
        # trigger a warning. Confirm the intended layout before changing it.
        with gr.Column(min_width=50, elem_id="column_process", scale=0.3):
            run_btn = gr.Button(value="Run", elem_id="run_button", min_width=50)

        # Column 3: generated image.
        with gr.Column(min_width=512):
            gr.Markdown("## 3. OUR OUTPUT")
            result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")

    load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
    load_model_btn.click(get_input_pose_fig, [], [pose_plot])
    update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose])  # js=send_js_camera_to_gradio)
    # check_pose_btn.click(check_curr_camera_dict, [], [input_pose])
    run_btn.click(run_inference, [input_pose, input_prompt, scale_im, scale, steps, seed], result)
    # NOTE(review): `scripts` is also passed as gr.Blocks(js=...), which Gradio
    # runs on page load, so it may execute twice; confirm this is intended.
    demo.load(js=scripts)
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server in debug mode.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)