customdiffusion360 committed
Commit 8eb5f81 · Parent: 43b6675
add instructions, do not load sdxl on original space

app.py CHANGED
@@ -28,7 +28,7 @@ def transform_mesh(mesh, transform, scale=1.0):
     return mesh
 
 
-def get_input_pose_fig():
+def get_input_pose_fig(category=None):
     global curr_camera_dict
     global obj_filename
     global plane_trans
@@ -44,6 +44,11 @@ def get_input_pose_fig():
     ### plane
     rotate_x = RotateAxisAngle(angle=90.0, axis='X', device=device)
     plane = transform_mesh(plane, rotate_x)
+
+    if category == "teddybear":
+        rotate_teddy = RotateAxisAngle(angle=15.0, axis='X', device=device)
+        plane = transform_mesh(plane, rotate_teddy)
+
     translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)
     plane = transform_mesh(plane, translate_y)
 
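For context, `RotateAxisAngle` and `Translate` are PyTorch3D `Transform3d` objects, and the hunk above tilts the ground plane an extra 15° for the teddybear category before shifting it along Y. The body of `transform_mesh` is not part of this diff, so the helper below is an assumption; a minimal self-contained sketch of how these transforms compose on a mesh (using the teddybear `plane_trans = 0.3` from a later hunk and an assumed `mesh_scale` of 1.0):

```python
import torch
from pytorch3d.structures import Meshes
from pytorch3d.transforms import RotateAxisAngle, Translate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the app's transform_mesh helper (its body is
# outside this diff): apply a Transform3d to every vertex of the mesh.
def transform_mesh(mesh, transform, scale=1.0):
    verts = mesh.verts_padded() * scale               # (batch, V, 3)
    return mesh.update_padded(transform.transform_points(verts))

# A one-triangle "plane" is enough to exercise the transforms.
plane = Meshes(
    verts=torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], device=device),
    faces=torch.tensor([[[0, 1, 2]]], device=device),
)

plane = transform_mesh(plane, RotateAxisAngle(angle=90.0, axis="X", device=device))
plane = transform_mesh(plane, RotateAxisAngle(angle=15.0, axis="X", device=device))  # teddybear-only tilt
plane = transform_mesh(plane, Translate(0, 0.3 * 1.0, 0, device=device))  # plane_trans * mesh_scale
```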
@@ -171,7 +176,15 @@ def select_and_load_model(category, category_single_id):
 
     print("!!! model loaded")
 
-
+    if category == "car":
+        input_prompt = "A <new1> car parked by a snowy mountain range"
+    elif category == "chair":
+        input_prompt = "A <new1> chair in a garden surrounded by flowers"
+    elif category == "motorcycle":
+        input_prompt = "A <new1> motorcycle beside a calm lake"
+    elif category == "teddybear":
+        input_prompt = "A <new1> teddy bear on the sand at the beach"
+
     return "### Model loaded!", input_prompt
 
 
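The if/elif ladder above could equally be written as a table lookup; a compact equivalent sketch, where `DEFAULT_PROMPTS`, `default_prompt`, and the fallback string are illustrative names rather than part of the app:

```python
# Equivalent table-driven form of the per-category defaults added above.
DEFAULT_PROMPTS = {
    "car": "A <new1> car parked by a snowy mountain range",
    "chair": "A <new1> chair in a garden surrounded by flowers",
    "motorcycle": "A <new1> motorcycle beside a calm lake",
    "teddybear": "A <new1> teddy bear on the sand at the beach",
}

def default_prompt(category):
    # Hypothetical fallback; the diff itself only covers the four categories.
    return DEFAULT_PROMPTS.get(category, "A <new1> object")
```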
@@ -184,9 +197,15 @@ global base_model
 BASE_CONFIG = "custom-diffusion360/configs/train_co3d_concept.yaml"
 BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"
 
-start_time = time.time()
-base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
-print(f"Time taken to load base model: {time.time() - start_time:.2f}s")
+base_model = None
+
+ORIGINAL_SPACE_ID = "customdiffusion360/customdiffusion360"
+SPACE_ID = os.getenv("SPACE_ID")
+
+if SPACE_ID != ORIGINAL_SPACE_ID:
+    start_time = time.time()
+    base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
+    print(f"Time taken to load base model: {time.time() - start_time:.2f}s")
 
 global curr_camera_dict
 curr_camera_dict = {
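This hunk is the "do not load sdxl on original space" half of the commit: the Spaces runtime sets the `SPACE_ID` environment variable to `<owner>/<name>`, so the base model (the demo needs roughly 40GB of VRAM) is only loaded on duplicated Spaces and on local machines, where the variable is unset or different. A consequence is that `base_model` can now be `None`, so any inference handler needs a guard along these lines; this is a hypothetical sketch, not code from the commit, and `run_inference`/`sample` are placeholder names:

```python
import gradio as gr

def run_inference(prompt, *args):
    # On the original shared Space, base_model stays None; surface a visible
    # UI error instead of crashing deep inside the sampling pipeline.
    if base_model is None:
        raise gr.Error(
            "Inference is disabled on the shared Space. Duplicate the Space "
            "and upgrade the GPU to run the demo."
        )
    return sample(base_model, prompt, *args)  # sample() is a placeholder name
```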
@@ -280,7 +299,7 @@ def update_category_single_id(category):
             "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
             "scene.aspectmode": "manual"
         }
-        plane_trans = 0.
+        plane_trans = 0.2
 
     elif category == "teddybear":
         choices = ["31"]
@@ -299,7 +318,7 @@
             "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
             "scene.aspectmode": "manual",
         }
-        plane_trans = 0.
+        plane_trans = 0.3
 
     obj_filename = f"assets/{category}{choices[0]}_mesh_centered_flipped.obj"
     prev_camera_dict = copy.deepcopy(curr_camera_dict)
@@ -310,13 +329,6 @@ head = """
 <script src="https://cdn.plot.ly/plotly-2.30.0.min.js" charset="utf-8"></script>
 """
 
-ORIGINAL_SPACE_ID = 'customdiffusion360'
-SPACE_ID = os.getenv('SPACE_ID')
-
-SHARED_UI_WARNING = f'''## Attention - the demo requires at least 40GB VRAM for inference. Please clone this repository to run on your own machine.
-<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
-'''
-
 with gr.Blocks(head=head,
                css="style.css",
                js=scripts,
@@ -339,14 +351,21 @@ with gr.Blocks(head=head,
             <img src='https://img.shields.io/badge/Github-%23121011.svg'>
             </a>
             </div>
+            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+            <p>
+            This is a demo for <a href='https://github.com/customdiffusion360/custom-diffusion360'>Custom Diffusion 360</a>.
+            Please duplicate this space and upgrade the GPU to A10G Large in Settings to run the demo.
+            </p>
+            </div>
+            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+            <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/customdiffusion360/customdiffusion360?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>
+            </div>
             <hr></hr>
             """,
             visible=True
         )
 
-
-    gr.Markdown(SHARED_UI_WARNING)
-
+
     with gr.Row():
         with gr.Column(min_width=150):
             gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")
@@ -375,7 +394,7 @@
             ## TODO: track init_camera_dict and with js?
 
             ### visible elements
-            input_prompt = gr.Textbox(value="
+            input_prompt = gr.Textbox(value="A <new1> car parked by a snowy mountain range", label="Prompt", interactive=True)
             scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
             scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
             steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
@@ -389,8 +408,18 @@
             gr.Markdown("## 3. OUR OUTPUT")
             result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")
 
+            gr.Markdown("### Camera Pose Controls:")
+            gr.Markdown("* Orbital rotation: Left-click and drag.")
+            gr.Markdown("* Zoom: Mouse wheel scroll.")
+            gr.Markdown("* Pan (translate the camera): Right-click and drag.")
+            gr.Markdown("* Tilt camera: Tilt mouse wheel left/right.")
+            gr.Markdown("* Reset to initial camera pose: Hover over the top right corner of the plot and click the camera icon.")
+            gr.Markdown("### Note:")
+            gr.Markdown("The models only work within a range of elevation angles and distances near the initial camera pose.")
+
+
     load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
-    load_model_btn.click(get_input_pose_fig, [], [map])
+    load_model_btn.click(get_input_pose_fig, [category], [map])
 
     update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose],) # js=send_js_camera_to_gradio)
     # check_pose_btn.click(check_curr_camera_dict, [], [input_pose])
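The wiring change in the last hunk passes the selected `category` into `get_input_pose_fig` so the pose plot can apply the category-specific plane tilt. Note that Gradio runs every listener registered on the same event, so a single click fires both callbacks in order. A minimal, self-contained sketch of that pattern; the component names and stub callbacks are illustrative, not the app's:

```python
import gradio as gr

def load_model(category):
    # Stand-in for select_and_load_model: returns a status string.
    return f"### Model loaded for {category}!"

def make_pose_fig(category):
    # Stand-in for get_input_pose_fig: the real app returns a Plotly figure.
    return f"(camera-pose plot for {category})"

with gr.Blocks() as demo:
    category = gr.Dropdown(["car", "chair", "motorcycle", "teddybear"], value="car", label="Category")
    load_btn = gr.Button("Load model")
    status = gr.Markdown()
    pose = gr.Textbox(label="Pose figure placeholder")

    # Two listeners on the same button: Gradio fires both on one click,
    # which is how app.py chains model loading with redrawing the pose plot.
    load_btn.click(load_model, [category], [status])
    load_btn.click(make_pose_fig, [category], [pose])

demo.launch()
```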