# Source: Hugging Face Space by customdiffusion360
# Last commit: a24f25c ("fix requirements.txt"); file size 14.7 kB
import os
import gradio as gr
import plotly.graph_objects as go
import torch
import json
import glob
import numpy as np
from PIL import Image
import time
import copy
import sys
# Mesh imports
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene
from pytorch3d.transforms import Transform3d, RotateAxisAngle, Translate, Rotate
from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def transform_mesh(mesh, transform, scale=1.0):
    """Return a transformed copy of ``mesh``; the input mesh is not modified.

    Every vertex is first multiplied by ``scale`` and then mapped through
    ``transform.transform_points``. The result is written into the copy as a
    per-vertex offset via ``offset_verts_``.
    """
    result = result_src = mesh.clone()
    original_verts = result_src.verts_packed()
    moved_verts = transform.transform_points(original_verts * scale)
    # offset_verts_ expects displacements, not absolute positions.
    result.offset_verts_(moved_verts - original_verts)
    return result
def get_input_pose_fig():
    """Build the 512x512 Plotly figure used to preview the camera pose.

    Reads the module-level ``obj_filename``, ``plane_trans`` and
    ``curr_camera_dict`` (all read-only here). The object mesh is loaded from
    ``obj_filename`` and scaled by 0.75. A ground plane is loaded from
    ``assets/plane.obj``, rotated 90 degrees about X and shifted along Y by
    ``plane_trans * 0.75``. The layout is then updated with
    ``curr_camera_dict``.
    """
    scale_factor = 0.75

    # Object mesh, uniformly scaled in place.
    obj_mesh = load_objs_as_meshes([obj_filename], device=device)
    obj_mesh.scale_verts_(scale_factor)

    # Ground plane: rotate about X, then offset along Y relative to the object.
    ground = load_objs_as_meshes(['assets/plane.obj'], device=device)
    ground = transform_mesh(ground, RotateAxisAngle(angle=90.0, axis='X', device=device))
    ground = transform_mesh(ground, Translate(0, plane_trans * scale_factor, 0, device=device))

    fig = plot_scene(
        {"plot": {"object": obj_mesh}},
        axis_args=AxisArgs(showgrid=True, backgroundcolor='#cccde0'),
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        zaxis=dict(range=[-1, 1]),
    )

    # The plane is added as a separate semi-transparent Mesh3d trace without hover.
    ground = ground.detach().cpu()
    plane_verts = ground.verts_packed()
    plane_faces = ground.faces_packed()
    fig.add_trace(
        go.Mesh3d(
            x=plane_verts[:, 0], y=plane_verts[:, 1], z=plane_verts[:, 2],
            i=plane_faces[:, 0], j=plane_faces[:, 1], k=plane_faces[:, 2],
            opacity=0.7,
            color='gray',
            hoverinfo='skip',
        ),
    )

    print("fig: curr camera dict")
    print(curr_camera_dict)

    # Visible, ticked axes with a black grid on a light-grey scene background.
    axis_style = dict(showticklabels=True, visible=True,
                      showgrid=True, gridwidth=1, gridcolor='black')
    fig.update_layout(scene=dict(
        xaxis=axis_style,
        yaxis=axis_style,
        zaxis=axis_style,
        bgcolor='#dedede',
    ))
    # curr_camera_dict uses dotted keys ("scene.camera", ...) passed as dict1.
    fig.update_layout(curr_camera_dict, width=512, height=512)
    return fig
def run_inference(cam_pose_json, prompt, scale_im, scale, steps, seed):
    """Sample one image from the currently loaded customized model.

    Args:
        cam_pose_json: Camera pose text from the hidden ``input_pose``
            textarea, forwarded to ``sample`` as ``camera_json``.
        prompt: Text prompt, e.g. ``"photo of a <new1> car"``.
        scale_im: Image guidance scale (UI slider).
        scale: Text guidance scale (UI slider).
        steps: Number of inference steps; cast to ``int``.
        seed: Seed as typed in the UI textbox; parsed to ``int``.

    Returns:
        PIL.Image.Image: the first sampled image. Values are mapped from
        [-1, 1] to [0, 255] (clipped), and the tensor is permuted
        (C, H, W) -> (H, W, C) before conversion.

    Raises:
        gr.Error: if no model has been loaded yet, or the seed is not an integer.
    """
    print("prompt is ", prompt)
    # Both globals start as None until "Load Model" succeeds. Without this
    # check, None would be handed to sample() and fail with an obscure error.
    if current_model is None or current_data is None:
        raise gr.Error("No model loaded. Select a category and click 'Load Model' first.")
    # The seed comes from a free-text box, so validate it before sampling.
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        raise gr.Error(f"Seed must be an integer, got {seed!r}.") from None

    images = sample(
        current_model, current_data,
        num_images=1,
        prompt=prompt,
        appendpath="",
        camera_json=cam_pose_json,
        train=False,
        scale=scale,
        scale_im=scale_im,
        beta=1.0,
        num_ref=8,
        skipreflater=False,
        num_steps=int(steps),  # slider values may arrive as floats
        valid=False,
        max_images=20,
        seed=seed,
    )
    image = images[0]
    print(image.shape)
    # [-1, 1] -> [0, 1], CHW -> HWC, clip, then scale to 8-bit.
    pixels = ((image + 1.0) / 2.0).permute(1, 2, 0).cpu().numpy()
    pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    result = Image.fromarray(pixels)
    print('result obtained')
    return result
def update_curr_camera_dict(camera_json):
    """Store the camera pose sent from the page in the global ``curr_camera_dict``.

    ``camera_json`` is the text of the hidden pose textarea. When it is None,
    the last preset pose (``prev_camera_dict``) is serialized instead. Single
    quotes are swapped for double quotes so Python-repr style text parses as
    JSON. Returns that normalized text, which the UI writes back into the
    textarea.
    """
    # TODO: this does not always update the figure, also there's always flashes
    global curr_camera_dict
    text = json.dumps(prev_camera_dict) if camera_json is None else camera_json
    text = text.replace("'", '"')
    curr_camera_dict = json.loads(text)
    print("update curr camera dict")
    print(curr_camera_dict)
    return text
# Root directory searched for customized (delta) model runs.
MODELS_DIR = "pretrained-models/"


def select_and_load_model(category, category_single_id):
    """Load the customized model for ``category`` / ``category_single_id``.

    The delta checkpoint is looked up as
    ``MODELS_DIR/*<category><id>*/checkpoints/step=*.ckpt``. The run config is
    taken from ``<run dir>/configs/*.yaml``. Both are applied to a deep copy
    of the global ``base_model`` via ``load_and_return_model_and_data``.

    Returns:
        tuple[str, str]: a status markdown string and the default prompt for
        ``category``.

    Raises:
        gr.Error: if no checkpoint or no config file matches.
    """
    global current_data, current_model

    # Release the previous model/data before copying the base model again.
    # On failure both stay None, so run_inference can detect "not loaded".
    current_model = None
    current_data = None

    ckpt_pattern = os.path.join(
        MODELS_DIR, f"*{category}{category_single_id}*", "checkpoints", "step=*.ckpt")
    # glob order is filesystem-dependent; sort so the choice is reproducible.
    # NOTE: lexicographic order, not numeric order of the step value.
    ckpts = sorted(glob.glob(ckpt_pattern))
    if not ckpts:
        raise gr.Error(f"No checkpoint found for '{category}{category_single_id}' "
                       f"(pattern: {ckpt_pattern}).")
    delta_ckpt = ckpts[-1]
    print(f"Loading model from {delta_ckpt}")

    # Run directory = parent of the "checkpoints" folder holding the ckpt.
    logdir = os.path.dirname(os.path.dirname(delta_ckpt))
    configs_dir = os.path.join(logdir, "configs")
    configs = sorted(glob.glob(os.path.join(configs_dir, "*.yaml")))
    if not configs:
        raise gr.Error(f"No config (*.yaml) found in {configs_dir}.")
    config = configs[-1]

    model = copy.deepcopy(base_model)
    start_time = time.time()
    model, data = load_and_return_model_and_data(config, model, delta_ckpt=delta_ckpt)
    print(f"Time taken to load delta model: {time.time() - start_time:.2f}s")
    print("!!! model loaded")
    # Publish only after a successful load.
    current_model, current_data = model, data

    input_prompt = f"photo of a <new1> {category}"
    return "### Model loaded!", input_prompt
# Model state shared by the Gradio callbacks; None until a model is loaded.
current_data = None
current_model = None

# The base model is loaded once at startup; customized models are built on copies of it.
BASE_CONFIG = "configs/train_co3d_concept.yaml"
BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"
start_time = time.time()
base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
print(f"Time taken to load base model: {time.time() - start_time:.2f}s")

# Initial preview state; matches the default "car" category selected in the UI.
curr_camera_dict = {
    "scene.camera": {
        "up": {"x": -0.13227683305740356,
               "y": -0.9911391735076904,
               "z": -0.013464212417602539},
        "center": {"x": -0.005292057991027832,
                   "y": 0.020704858005046844,
                   "z": 0.0873757004737854},
        "eye": {"x": 0.8585731983184814,
                "y": -0.08790968358516693,
                "z": -0.40458938479423523},
    },
    "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
    "scene.aspectmode": "manual"
}
prev_camera_dict = copy.deepcopy(curr_camera_dict)
obj_filename = "assets/car0_mesh_centered_flipped.obj"
plane_trans = 0.16
my_fig = get_input_pose_fig()

# Custom front-end JavaScript; read with a context manager so the handle is closed.
with open("scripts.js", "r", encoding="utf-8") as _scripts_file:
    scripts = _scripts_file.read()
# Per-category preview presets: selectable object ids, the initial Plotly
# scene camera/aspect settings, and the ground-plane offset (plane_trans).
_CATEGORY_PRESETS = {
    "car": {
        "choices": ["0"],
        "camera": {
            "scene.camera": {
                "up": {"x": -0.13227683305740356,
                       "y": -0.9911391735076904,
                       "z": -0.013464212417602539},
                "center": {"x": -0.005292057991027832,
                           "y": 0.020704858005046844,
                           "z": 0.0873757004737854},
                "eye": {"x": 0.8585731983184814,
                        "y": -0.08790968358516693,
                        "z": -0.40458938479423523},
            },
            "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.16,
    },
    "chair": {
        "choices": ["191"],
        "camera": {
            "scene.camera": {
                "up": {"x": 1.0477e-04,
                       "y": -9.9995e-01,
                       "z": 1.0288e-02},
                "center": {"x": 0.0539,
                           "y": 0.0015,
                           "z": 0.0007},
                "eye": {"x": 0.0410,
                        "y": -0.0091,
                        "z": -0.9991},
            },
            "scene.aspectratio": {"x": 0.9084, "y": 0.9084, "z": 0.9084},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.38,
    },
    "motorcycle": {
        "choices": ["12"],
        "camera": {
            "scene.camera": {
                "up": {"x": 0.0308,
                       "y": -0.9994,
                       "z": -0.0147},
                "center": {"x": 0.0240,
                           "y": -0.0310,
                           "z": -0.0016},
                "eye": {"x": -0.0580,
                        "y": -0.0188,
                        "z": -0.9981},
            },
            "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.16,
    },
    "teddybear": {
        "choices": ["31"],
        "camera": {
            "scene.camera": {
                "up": {"x": 0.4304,
                       "y": -0.9023,
                       "z": -0.0221},
                "center": {"x": -0.0658,
                           "y": 0.2081,
                           "z": 0.0175},
                "eye": {"x": -0.4456,
                        "y": 0.0493,
                        "z": -0.8939},
            },
            "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
            "scene.aspectmode": "manual",
        },
        "plane_trans": 0.23,
    },
}


def update_category_single_id(category):
    """Switch the pose-preview state to ``category`` and refresh the Object ID dropdown.

    Sets the module-level ``curr_camera_dict``, ``prev_camera_dict``,
    ``plane_trans`` and ``obj_filename`` from the category preset. The mesh
    path follows ``assets/<category><id>_mesh_centered_flipped.obj``.

    Returns:
        gr.Dropdown: the Object ID dropdown, with the category's ids and the first one selected.

    Raises:
        gr.Error: if ``category`` has no preset (e.g. the dropdown was cleared).
    """
    global curr_camera_dict
    global prev_camera_dict
    global obj_filename
    global plane_trans

    preset = _CATEGORY_PRESETS.get(category)
    if preset is None:
        # Previously this crashed with a TypeError on choices[0].
        raise gr.Error(f"Unknown category: {category!r}")

    choices = list(preset["choices"])
    # Copy so later edits to the current pose can never alter the preset table.
    curr_camera_dict = copy.deepcopy(preset["camera"])
    plane_trans = preset["plane_trans"]
    obj_filename = f"assets/{category}{choices[0]}_mesh_centered_flipped.obj"
    prev_camera_dict = copy.deepcopy(curr_camera_dict)
    return gr.Dropdown(choices=choices, label="Object ID", value=choices[0])
# Extra markup injected into the page <head>: loads Plotly from its CDN
# (presumably for the code in scripts.js - confirm).
head = """
<script src="https://cdn.plot.ly/plotly-2.30.0.min.js" charset="utf-8"></script>
"""
# The VRAM warning banner is shown only when SPACE_ID equals this value.
# NOTE(review): Hugging Face sets SPACE_ID as "<owner>/<space-name>", so a bare
# owner name may never match - confirm the intended full id.
ORIGINAL_SPACE_ID = 'customdiffusion360'
# None when the environment variable is unset (e.g. running locally).
SPACE_ID = os.getenv('SPACE_ID')
# Markdown banner with a "Duplicate Space" badge linking to the current Space.
SHARED_UI_WARNING = f'''## Attention - the demo requires at least 40GB VRAM for inference. Please clone this repository to run on your own machine.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
'''
# Three-step UI: (1) pick and load a customized model, (2) pose the camera
# on the 3D preview and set the prompt/sampling options, (3) view the output.
with gr.Blocks(head=head,
               css="style.css",
               js=scripts,
               title="Customizing Text-to-Image Diffusion with Camera Viewpoint Control") as demo:
    # Title and links (project page, arXiv, GitHub).
    gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h2><a href='https://customdiffusion360.github.io/index.html'>Customizing Text-to-Image Diffusion with Camera Viewpoint Control</a></h2>
</div>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href='https://customdiffusion360.github.io/index.html' style="padding: 10px;">
<img src='https://img.shields.io/badge/Project%20Page-8A2BE2'>
</a>
<a href='https://arxiv.org/abs/2404.12333'>
<img src="https://img.shields.io/badge/arXiv-2404.12333-red">
</a>
<a class="link" href='https://github.com/customdiffusion360/custom-diffusion360' style="padding: 10px;">
<img src='https://img.shields.io/badge/Github-%23121011.svg'>
</a>
</div>
<hr></hr>
""",
            visible=True
            )
    if SPACE_ID == ORIGINAL_SPACE_ID:
        gr.Markdown(SHARED_UI_WARNING)

    with gr.Row():
        with gr.Column(min_width=150):
            gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")
            category = gr.Dropdown(choices=["car", "chair", "motorcycle", "teddybear"], label="Category", value="car")
            category_single_id = gr.Dropdown(label="Object ID", choices=["0"], type="value", value="0", visible=False)
            category.change(update_category_single_id, [category], [category_single_id])
            load_model_btn = gr.Button(value="Load Model", elem_id="load_model_button")
            load_model_status = gr.Markdown(elem_id="load_model_status", value="### Please select and load a model.")

        with gr.Column(min_width=512):
            gr.Markdown("## 2. CAMERA POSE VISUALIZATION")
            # TODO ? don't use gradio plotly element so we can remove menu buttons
            # Named pose_plot (not `map`) to avoid shadowing the builtin.
            pose_plot = gr.Plot(value=my_fig, min_width=512, elem_id="map")

            ### hidden elements
            update_pose_btn = gr.Button(value="Update Camera Pose", visible=False, elem_id="update_pose_button")
            # Seed with real JSON: a dict value would be rendered via str(), i.e.
            # Python repr with single quotes, which strict JSON parsers reject.
            input_pose = gr.TextArea(value=json.dumps(curr_camera_dict), label="Input Camera Pose", visible=False, elem_id="input_pose", interactive=False)
            check_pose_btn = gr.Button(value="Check Camera Pose", visible=False, elem_id="check_pose_button")
            ## TODO: track init_camera_dict and with js?

            ### visible elements
            input_prompt = gr.Textbox(value="photo of a <new1> car", label="Prompt", interactive=True)
            scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
            scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
            steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
            seed = gr.Textbox(value=42, label="Seed")

        with gr.Column(min_width=50, elem_id="column_process", scale=0.3):
            run_btn = gr.Button(value="Run", elem_id="run_button", min_width=50)

        with gr.Column(min_width=512):
            gr.Markdown("## 3. OUR OUTPUT")
            result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")

    # One click both loads the selected model and redraws the pose preview
    # from the category state set by update_category_single_id.
    load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
    load_model_btn.click(get_input_pose_fig, [], [pose_plot])
    # The hidden button/textarea carry the camera pose from the page to Python.
    update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose],)
    run_btn.click(run_inference, [input_pose, input_prompt, scale_im, scale, steps, seed], result)
    # NOTE(review): `scripts` is also passed as Blocks(js=...) above, so it may
    # execute twice on page load - confirm whether both are needed.
    demo.load(js=scripts)
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with debug output.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)