"""Gradio demo: customized text-to-image diffusion with camera viewpoint control.

The script loads a base model at import time, shows a 3D proxy mesh in a
Plotly scene so the user can pick a camera pose, and runs the sampler with
that pose, a text prompt, and guidance settings.
"""
import os
import json
import glob
import time
import copy
import sys

import gradio as gr
import plotly.graph_objects as go
import torch
import numpy as np
from PIL import Image

# Mesh imports
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene
from pytorch3d.transforms import Transform3d, RotateAxisAngle, Translate, Rotate

from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model

# add current directory to path
# sys.path.append(os.path.dirname(os.path.realpath(__file__)))

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODELS_DIR = "pretrained-models/"
BASE_CONFIG = "configs/train_co3d_concept.yaml"
BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"

DEFAULT_CATEGORY = "car"

# Per-category demo presets: the single available object id, the vertical
# offset of the ground plane, and the initial Plotly camera/layout.
CATEGORY_PRESETS = {
    "car": {
        "object_id": "0",
        "plane_trans": 0.16,
        "camera": {
            "scene.camera": {
                "up": {"x": -0.13227683305740356, "y": -0.9911391735076904, "z": -0.013464212417602539},
                "center": {"x": -0.005292057991027832, "y": 0.020704858005046844, "z": 0.0873757004737854},
                "eye": {"x": 0.8585731983184814, "y": -0.08790968358516693, "z": -0.40458938479423523},
            },
            "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
            "scene.aspectmode": "manual",
        },
    },
    "chair": {
        "object_id": "191",
        "plane_trans": 0.38,
        "camera": {
            "scene.camera": {
                "up": {"x": 1.0477e-04, "y": -9.9995e-01, "z": 1.0288e-02},
                "center": {"x": 0.0539, "y": 0.0015, "z": 0.0007},
                "eye": {"x": 0.0410, "y": -0.0091, "z": -0.9991},
            },
            "scene.aspectratio": {"x": 0.9084, "y": 0.9084, "z": 0.9084},
            "scene.aspectmode": "manual",
        },
    },
    "motorcycle": {
        "object_id": "12",
        "plane_trans": 0.16,
        "camera": {
            "scene.camera": {
                "up": {"x": 0.0308, "y": -0.9994, "z": -0.0147},
                "center": {"x": 0.0240, "y": -0.0310, "z": -0.0016},
                "eye": {"x": -0.0580, "y": -0.0188, "z": -0.9981},
            },
            "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
            "scene.aspectmode": "manual",
        },
    },
    "teddybear": {
        "object_id": "31",
        "plane_trans": 0.23,
        "camera": {
            "scene.camera": {
                "up": {"x": 0.4304, "y": -0.9023, "z": -0.0221},
                "center": {"x": -0.0658, "y": 0.2081, "z": 0.0175},
                "eye": {"x": -0.4456, "y": 0.0493, "z": -0.8939},
            },
            "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
            "scene.aspectmode": "manual",
        },
    },
}


def transform_mesh(mesh, transform, scale=1.0):
    """Return a copy of ``mesh`` with vertices scaled, then transformed.

    Args:
        mesh: mesh object exposing ``clone``, ``verts_packed`` and
            ``offset_verts_`` (here, the result of ``load_objs_as_meshes``).
        transform: object exposing ``transform_points``.
        scale: factor applied to the vertices before ``transform``.

    Returns:
        The transformed clone; the input mesh is left untouched.
    """
    mesh = mesh.clone()
    verts = mesh.verts_packed() * scale
    verts = transform.transform_points(verts)
    # offset_verts_ takes per-vertex deltas, so pass (new - old).
    mesh.offset_verts_(verts - mesh.verts_packed())
    return mesh


def get_input_pose_fig():
    """Build the Plotly figure showing the current object mesh and ground plane.

    Reads the module-level ``obj_filename``, ``plane_trans`` and
    ``curr_camera_dict``.

    Returns:
        The 512x512 figure with the current camera layout applied.
    """
    plane_filename = 'assets/plane.obj'
    mesh_scale = 0.75

    mesh = load_objs_as_meshes([obj_filename], device=device)
    mesh.scale_verts_(mesh_scale)

    plane = load_objs_as_meshes([plane_filename], device=device)

    ### plane
    rotate_x = RotateAxisAngle(angle=90.0, axis='X', device=device)
    plane = transform_mesh(plane, rotate_x)

    translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)
    plane = transform_mesh(plane, translate_y)

    fig = plot_scene(
        {
            "plot": {
                "object": mesh,
            },
        },
        axis_args=AxisArgs(showgrid=True, backgroundcolor='#cccde0'),
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        zaxis=dict(range=[-1, 1])
    )

    # The plane is added as a separate semi-transparent trace.
    plane = plane.detach().cpu()
    verts = plane.verts_packed()
    faces = plane.faces_packed()

    fig.add_trace(
        go.Mesh3d(
            x=verts[:, 0],
            y=verts[:, 1],
            z=verts[:, 2],
            i=faces[:, 0],
            j=faces[:, 1],
            k=faces[:, 2],
            opacity=0.7,
            color='gray',
            hoverinfo='skip',
        ),
    )

    print("fig: curr camera dict")
    print(curr_camera_dict)

    fig.update_layout(scene=dict(
        xaxis=dict(showticklabels=True, visible=True),
        yaxis=dict(showticklabels=True, visible=True),
        zaxis=dict(showticklabels=True, visible=True),
    ))
    # show grid
    fig.update_layout(scene=dict(
        xaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        yaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        zaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        bgcolor='#dedede',
    ))

    # The camera dict uses dotted layout paths ("scene.camera", ...), so it is
    # passed as the positional layout-update dict.
    fig.update_layout(
        curr_camera_dict,
        width=512, height=512,
    )

    return fig


def run_inference(cam_pose_json, prompt, scale_im, scale, steps, seed):
    """Sample one image from the currently loaded model and return it.

    Args:
        cam_pose_json: camera pose string forwarded to ``sample`` as
            ``camera_json``.
        prompt: text prompt.
        scale_im: image guidance scale.
        scale: text guidance scale.
        steps: number of inference steps.
        seed: seed forwarded to ``sample`` unchanged.
            NOTE(review): the UI supplies this from a Textbox, so it arrives
            as a string; confirm ``sample`` accepts that.

    Returns:
        A ``PIL.Image`` built from the first sampled image.

    Raises:
        gr.Error: if no model has been loaded yet.
    """
    print("prompt is ", prompt)

    if current_model is None or current_data is None:
        # Fail with a readable UI message instead of an error deep in sample().
        raise gr.Error("Please select and load a model before running inference.")

    # run model
    images = sample(
        current_model, current_data,
        num_images=1,
        prompt=prompt,
        appendpath="",
        camera_json=cam_pose_json,
        train=False,
        scale=scale,
        scale_im=scale_im,
        beta=1.0,
        num_ref=8,
        skipreflater=False,
        num_steps=steps,
        valid=False,
        max_images=20,
        seed=seed
    )

    result = images[0]
    print(result.shape)

    # Map from [-1, 1] to [0, 1], move channels last, then quantize to uint8.
    pixels = ((result + 1.0) / 2.0).permute(1, 2, 0).cpu().numpy()
    result = Image.fromarray((np.clip(pixels, 0., 1.) * 255).astype(np.uint8))
    print('result obtained')
    return result


def update_curr_camera_dict(camera_json):
    """Parse ``camera_json`` into the module-level ``curr_camera_dict``.

    Args:
        camera_json: camera pose string, or ``None`` to fall back to
            ``prev_camera_dict``. Single quotes are replaced with double
            quotes before parsing.

    Returns:
        The (quote-normalized) JSON string that was parsed.
    """
    # TODO: this does not always update the figure, also there's always flashes
    global curr_camera_dict

    if camera_json is None:
        camera_json = json.dumps(prev_camera_dict)
    camera_json = camera_json.replace("'", "\"")
    curr_camera_dict = json.loads(camera_json)  # ["scene.camera"]
    print("update curr camera dict")
    print(curr_camera_dict)

    return camera_json


def select_and_load_model(category, category_single_id):
    """Load the customized model for ``category`` + ``category_single_id``.

    Copies the base model, locates the matching checkpoint and its config
    under ``MODELS_DIR``, and stores the loaded model and data in the
    module-level ``current_model`` / ``current_data``.

    Returns:
        ``(status_markdown, prompt)`` for the status label and prompt box.

    Raises:
        gr.Error: if no matching checkpoint or config file is found.
    """
    global current_data, current_model

    # Drop the reference to any previous model before copying the base model.
    # NOTE(review): presumably to let it be freed first; confirm.
    current_model = None
    current_model = copy.deepcopy(base_model)

    ### choose model checkpoint and config
    ckpt_pattern = os.path.join(
        MODELS_DIR, f"*{category}{category_single_id}*", "checkpoints", "step=*.ckpt"
    )
    ckpt_matches = glob.glob(ckpt_pattern)
    if not ckpt_matches:
        raise gr.Error(f"No checkpoint found matching {ckpt_pattern}")
    delta_ckpt = ckpt_matches[0]
    print(f"Loading model from {delta_ckpt}")

    # <logdir>/checkpoints/step=*.ckpt -> <logdir>
    logdir = os.path.dirname(os.path.dirname(delta_ckpt))
    config_matches = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
    if not config_matches:
        raise gr.Error(f"No config file found under {os.path.join(logdir, 'configs')}")
    config = config_matches[-1]

    start_time = time.time()
    current_model, current_data = load_and_return_model_and_data(
        config, current_model,
        delta_ckpt=delta_ckpt
    )

    print(f"Time taken to load delta model: {time.time() - start_time:.2f}s")
    print("!!! model loaded")

    input_prompt = f"photo of a {category}"
    return "### Model loaded!", input_prompt


def _apply_category_preset(category):
    """Point the module-level camera/mesh state at ``category``'s preset.

    Returns:
        The preset entry from ``CATEGORY_PRESETS``.

    Raises:
        ValueError: if ``category`` has no preset.
    """
    global curr_camera_dict, prev_camera_dict, obj_filename, plane_trans

    try:
        preset = CATEGORY_PRESETS[category]
    except KeyError:
        raise ValueError(f"Unknown category: {category!r}") from None

    # Copy so that later changes to the live state never alter the table.
    curr_camera_dict = copy.deepcopy(preset["camera"])
    prev_camera_dict = copy.deepcopy(curr_camera_dict)
    plane_trans = preset["plane_trans"]
    obj_filename = f"assets/{category}{preset['object_id']}_mesh_centered_flipped.obj"
    return preset


def update_category_single_id(category):
    """Switch the demo state to ``category`` and refresh the object-id dropdown.

    Returns:
        A ``gr.Dropdown`` listing the category's object id, preselected.
    """
    preset = _apply_category_preset(category)
    choices = [preset["object_id"]]
    return gr.Dropdown(choices=choices, label="Object ID", value=choices[0])


# Model state shared by the Gradio callbacks; set by select_and_load_model.
current_data = None
current_model = None

start_time = time.time()
base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
print(f"Time taken to load base model: {time.time() - start_time:.2f}s")

# Initialise curr_camera_dict, prev_camera_dict, obj_filename and plane_trans.
_apply_category_preset(DEFAULT_CATEGORY)

my_fig = get_input_pose_fig()

with open("scripts.js", "r") as scripts_file:
    scripts = scripts_file.read()

head = """ """

ORIGINAL_SPACE_ID = 'customdiffusion360'
SPACE_ID = os.getenv('SPACE_ID')

SHARED_UI_WARNING = f'''## Attention - the demo requires at least 40GB VRAM for inference. Please clone this repository to run on your own machine.
Duplicate Space
'''

with gr.Blocks(head=head,
               css="style.css",
               js=scripts,
               title="Customizing Text-to-Image Diffusion with Camera Viewpoint Control") as demo:
    gr.HTML("""

Customizing Text-to-Image Diffusion with Camera Viewpoint Control


""", visible=True
    )

    if SPACE_ID == ORIGINAL_SPACE_ID:
        gr.Markdown(SHARED_UI_WARNING)

    with gr.Row():
        with gr.Column(min_width=150):
            gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")

            category = gr.Dropdown(choices=["car", "chair", "motorcycle", "teddybear"], label="Category", value="car")

            category_single_id = gr.Dropdown(label="Object ID", choices=["0"], type="value", value="0", visible=False)

            category.change(update_category_single_id, [category], [category_single_id])

            load_model_btn = gr.Button(value="Load Model", elem_id="load_model_button")

            load_model_status = gr.Markdown(elem_id="load_model_status", value="### Please select and load a model.")

        with gr.Column(min_width=512):
            gr.Markdown("## 2. CAMERA POSE VISUALIZATION")

            # TODO ? don't use gradio plotly element so we can remove menu buttons
            camera_plot = gr.Plot(value=my_fig, min_width=512, elem_id="map")

            ### hidden elements
            update_pose_btn = gr.Button(value="Update Camera Pose", visible=False, elem_id="update_pose_button")
            # Seed the text area with real JSON; a raw dict would be
            # stringified as a Python repr with single quotes.
            input_pose = gr.TextArea(value=json.dumps(curr_camera_dict), label="Input Camera Pose", visible=False, elem_id="input_pose", interactive=False)
            check_pose_btn = gr.Button(value="Check Camera Pose", visible=False, elem_id="check_pose_button")

            ## TODO: track init_camera_dict and with js?

            ### visible elements
            input_prompt = gr.Textbox(value="photo of a car", label="Prompt", interactive=True)
            scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
            scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
            steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
            seed = gr.Textbox(value=42, label="Seed")

        with gr.Column(min_width=50, elem_id="column_process", scale=0.3):
            run_btn = gr.Button(value="Run", elem_id="run_button", min_width=50)

        with gr.Column(min_width=512):
            gr.Markdown("## 3. OUR OUTPUT")
            result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")

    load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
    load_model_btn.click(get_input_pose_fig, [], [camera_plot])

    update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose],)  # js=send_js_camera_to_gradio)
    # check_pose_btn.click(check_curr_camera_dict, [], [input_pose])
    run_btn.click(run_inference, [input_pose, input_prompt, scale_im, scale, steps, seed], result)

    demo.load(js=scripts)

if __name__ == "__main__":
    demo.queue().launch(debug=True)