import os
# os.system('pip uninstall -y gradio_fake3d')
# os.system('pip install gradio_fake3d-0.0.2-py3-none-any.whl')
import gradio as gr
import re
from gradio_fake3d import Fake3D
from PIL import Image
from Rodin import Generator, crop_image
from constant import *
# Module-level Rodin client shared by all UI events (see button_generate wiring below).
# NOTE(review): USER and PASSWORD are not defined in this file; presumably they are
# provided by `from constant import *` — confirm.
generator = Generator(USER, PASSWORD)
# Browser-side JS passed as `js=` to the "generate succeeded" step: relabels the
# element with id "button_generate" to "Redo" and returns an empty string (written
# to the hidden dummy textbox).
change_button_name = """
function updateButton(input) {
var buttonGenerate = document.getElementById('button_generate');
buttonGenerate.innerText = 'Redo';
return '';
}
"""
# Browser-side JS passed as `js=` when a new image is uploaded: restores the
# "Generate" label on the element with id "button_generate".
reset_button_name = """
function updateButton(input) {
var buttonGenerate = document.getElementById('button_generate');
buttonGenerate.innerText = 'Generate';
return '';
}
"""
# Browser-side JS for the model-card dropdown: if the chosen option contains
# "OpenClay", open the OpenCLAY GitHub repository in a new tab. It always returns
# "Rodin Gen-1(0525)", which is then fed to handle_selection so the dropdown
# snaps back to the only working model.
jump_to_rodin = """
function redirectToGithub(input) {
if (input.includes('OpenClay')) {
window.open("https://github.com/CLAY-3D/OpenCLAY", "_blank");
}
return "Rodin Gen-1(0525)";
}
"""
# Header HTML rendered at the top of the page via gr.HTML; currently only a newline.
html_content = """
"""
# Choices for the "Model Card" dropdown. Only the first entry is usable; the
# OpenClay entries are labelled "Coming soon".
options = [
    "Rodin Gen-1(0525)",
    "OpenClay(600M) - Coming soon",
    "OpenClay(200M) - Coming soon"
]
def do_nothing(text):
    """Ignore *text* and hand back an empty string.

    Serves as the Python callback for events whose visible effect comes from a
    ``js=`` snippet; the empty string is written back to the hidden dummy output.
    """
    del text  # deliberately unused
    return str()
def handle_selection(selection):
    """Map any model-card choice back to the Rodin model.

    Bound to the "Model Card" dropdown's change event. The OpenClay options are
    not available yet, so whatever the user picked, the dropdown is reset to
    "Rodin Gen-1(0525)".

    Args:
        selection: the option chosen in the dropdown (ignored).

    Returns:
        The label of the only selectable model.
    """
    del selection  # every choice currently resolves to the same model
    rodin_card = "Rodin Gen-1(0525)"
    return rodin_card
def hint_in_prompt(hint, prompt):
    """Return True if *hint*, minus its final character, occurs in *prompt*.

    The last character of the hint (presumably its terminating ".") is dropped,
    so the hint is still recognised when it is followed by different punctuation
    or sits at the end of the prompt.

    The hint is matched as literal text. The previous implementation interpolated
    it into a regular expression unescaped, so characters such as ``(``, ``[`` or
    ``+`` were treated as regex syntax. That produced false positives and
    negatives, or raised ``re.error`` for unbalanced brackets.

    Args:
        hint: the hint text as stored in the hint list.
        prompt: the current prompt text.

    Returns:
        True if the hint body is a substring of *prompt*.
    """
    # A literal substring test is exactly equivalent to re.search(re.escape(...)).
    return hint[:-1] in prompt
def prompt_remove_hint(prompt, hint):
    """Remove every occurrence of *hint* (minus its last character) from *prompt*.

    Any whitespace before each occurrence, and any run of "." / "," after it,
    is removed as well. This keeps the remaining sentences cleanly separated.

    The hint is escaped before being placed in the pattern, so regex
    metacharacters in a hint are matched literally instead of changing what gets
    removed (or raising ``re.error``). A hint whose body is empty (length 0 or 1)
    is ignored. Without that guard the pattern would reduce to
    ``\\s*[.,]*`` and strip every space and punctuation run from the prompt.

    Args:
        prompt: the current prompt text.
        hint: the hint text to remove.

    Returns:
        The prompt with the hint removed.
    """
    body = hint[:-1]
    if not body:
        return prompt
    pattern = r"\s*" + re.escape(body) + r"[.,]*"
    return re.sub(pattern, "", prompt)
def handle_hint_change(prompt: str, prompt_hint):
    """Rewrite the prompt so it agrees with the checked hint boxes.

    The prompt is trimmed, and a "." is appended if it is non-empty and does not
    already end with one. Then each entry in ``PROMPT_HINT_LIST`` is visited in
    order:
    - A checked hint that is missing from the prompt is appended after a space.
    - An unchecked hint is stripped from the prompt.

    Args:
        prompt: current contents of the prompt textbox.
        prompt_hint: the hint values currently selected in the CheckboxGroup.

    Returns:
        The updated, whitespace-trimmed prompt.
    """
    text = prompt.strip()
    if text and text[-1] != ".":
        text += "."
    for _label, hint_value in PROMPT_HINT_LIST:
        if hint_value not in prompt_hint:
            text = prompt_remove_hint(text, hint_value)
        elif not hint_in_prompt(hint_value, text):
            text = text + " " + hint_value
    return text.strip()
def handle_prompt_change(prompt):
    """Return the hint values from ``PROMPT_HINT_LIST`` already present in *prompt*.

    Each entry is unpacked as a 2-item pair, and its second item is the hint
    text. The result keeps list order and is used to re-tick the hint
    CheckboxGroup whenever the prompt textbox changes.
    """
    return [
        hint_value
        for _label, hint_value in PROMPT_HINT_LIST
        if hint_in_prompt(hint_value, prompt)
    ]
def clear_task_uuid():
    """Return an empty string used to reset the hidden task-UUID field.

    Chained after a new image upload, so the next generation does not reuse the
    previous task identifier.
    """
    empty_uuid = ""
    return empty_uuid
def return_render(image):
    """Convert an array render to a PIL image and pair it with its DEFAULT crop.

    Args:
        image: array-like data accepted by ``PIL.Image.fromarray``.

    Returns:
        A tuple ``(pil_image, crop_image(pil_image, DEFAULT))``.
    """
    pil_image = Image.fromarray(image)
    default_view = crop_image(pil_image, DEFAULT)
    return pil_image, default_view
def crop_image_default(image):
    """Return ``crop_image(image, DEFAULT)``; wired to the "Default" view button."""
    view_mode = DEFAULT
    return crop_image(image, view_mode)
def crop_image_metal(image):
    """Return ``crop_image(image, METAL)``; wired to the "Metal" view button."""
    view_mode = METAL
    return crop_image(image, view_mode)
def crop_image_contrast(image):
    """Return ``crop_image(image, CONTRAST)``; wired to the "Contrast" view button."""
    view_mode = CONTRAST
    return crop_image(image, view_mode)
def crop_image_normal(image):
    """Return ``crop_image(image, NORMAL)``; wired to the "Normal" view button."""
    view_mode = NORMAL
    return crop_image(image, view_mode)
# Build the Gradio UI: an input column (image, model card, prompt + hints) and an
# output column (3D preview, generate button, view-mode buttons, download link),
# then wire up the events.
# NOTE(review): SOURCE had its indentation stripped; the layout nesting below is a
# reconstruction — confirm against the deployed UI.
with gr.Blocks() as demo:
    gr.HTML(html_content)
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            # type="filepath": the uploaded image reaches generator.preprocess as a path.
            block_image = gr.Image(height=256, image_mode="RGB", sources="upload", elem_classes="elem_imageupload", type="filepath")
            block_model_card = gr.Dropdown(choices=options, label="Model Card", value="Rodin Gen-1(0525)", interactive=True)
            # block_image_scale = gr.Slider(minimum=0, maximum=1, value=1, label="scale", interactive=True)
            # block_model_card = gr.Radio(
            # choices=["Rodin Gen-1(0525)", "OpenClay(600M)", "OpenClay(200M)"],
            # label="Model Selection",
            # value="Rodin Gen-1(0525)" # select the first option by default
            # )
            with gr.Group():
                # with gr.Row(equal_height=True):
                # Prompt text; preprocess may overwrite it (it is one of preprocess's outputs).
                block_prompt = gr.Textbox(
                    value="",
                    placeholder="Auto generated description of 3d geometry",
                    lines=1,
                    show_label=True,
                    label="Prompt",
                )
                # Hint checkboxes kept in sync with the prompt text (see events below).
                # NOTE(review): value="Labels" looks like it was meant to be label="Labels";
                # a CheckboxGroup value should be a list of option values — confirm
                # against PROMPT_HINT_LIST.
                block_prompt_hint = gr.CheckboxGroup(value="Labels", choices=PROMPT_HINT_LIST)
        # Right column: preview and actions.
        with gr.Column():
            with gr.Group():
                fake3d = Fake3D(interactive=False, label="3D Preview")
                with gr.Row():
                    # elem_id is the DOM id the JS snippets use to relabel this button.
                    button_generate = gr.Button(value="Generate", variant="primary", elem_id="button_generate")
                    with gr.Column(min_width=200, scale=20):
                        # View-mode buttons: each re-crops cache_raw_image into fake3d.
                        with gr.Row():
                            block_default = gr.Button("Default", min_width=0)
                            block_metal = gr.Button("Metal", min_width=0)
                        with gr.Row():
                            block_contrast = gr.Button("Contrast", min_width=0)
                            block_normal = gr.Button("Normal", min_width=0)
            # NOTE(review): rodin_url is not defined in this file; presumably from `constant`.
            button_more = gr.Button(value="Download", variant="primary", link=rodin_url)
    # Hidden state holders.
    # Full render returned by generator.generate_mesh; the view buttons crop it.
    cache_raw_image = gr.Image(visible=False, type="pil")
    # Image data produced by generator.preprocess and consumed by generate_mesh.
    cache_image_base64 = gr.Text(visible=False)
    # (sic) Dummy textbox used as input/output for events whose work is done in JS.
    cacha_empty = gr.Text(visible=False)
    # Task id returned by generate_mesh and passed back to it on the next click;
    # presumably this lets "Redo" reuse the same task — confirm in Rodin.Generator.
    cache_task_uuid = gr.Text(value="", visible=False)
    # button_generate.click(fn=return_render, inputs=[block_image], outputs=[raw_image, fake3d])
    # New image: relabel the button back to "Generate" (JS), then clear the task id.
    block_image.change(
        fn=do_nothing,
        js=reset_button_name,
        inputs=[cacha_empty],
        outputs=[cacha_empty]
    ).then(fn=clear_task_uuid, outputs=[cache_task_uuid], show_progress="hidden")
    # Generate pipeline: preprocess -> (on success) generate_mesh -> (on success)
    # relabel the button to "Redo" via JS.
    button_generate.click(
        fn=generator.preprocess,
        inputs=[block_prompt, block_image],
        outputs=[block_prompt, cache_image_base64],
        show_progress="minimal"
    ).success(
        fn=generator.generate_mesh,
        inputs=[block_prompt, cache_image_base64, cache_task_uuid],
        outputs=[cache_raw_image, cache_task_uuid, fake3d],
    ).success(
        fn=do_nothing,
        js=change_button_name,
        inputs=[cacha_empty],
        outputs=[cacha_empty]
    )
    # View-mode buttons: re-crop the cached raw render into the preview.
    block_default.click(fn=crop_image_default, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
    block_metal.click(fn=crop_image_metal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
    block_contrast.click(fn=crop_image_contrast, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
    block_normal.click(fn=crop_image_normal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
    # NOTE(review): registers a click listener with no callback; the navigation
    # presumably comes from link=rodin_url on the button itself.
    button_more.click()
    # Checkbox toggled by the user -> add/remove the hint text in the prompt.
    # Presumably `.input` (user-initiated changes only) is used so the programmatic
    # checkbox update from block_prompt.change below does not loop back here.
    block_prompt_hint.input(
        fn=handle_hint_change, inputs=[block_prompt, block_prompt_hint], outputs=[block_prompt],
        show_progress="hidden",
        queue=False,
    )
    # Prompt edited -> re-tick the hints whose text appears in the prompt.
    block_prompt.change(
        fn=handle_prompt_change,
        inputs=[block_prompt],
        outputs=[block_prompt_hint],
        trigger_mode="always_last",
        show_progress="hidden",
    )
    # Model card changed: JS opens the OpenCLAY repo for OpenClay choices, and the
    # dropdown is reset to "Rodin Gen-1(0525)".
    block_model_card.change(fn=handle_selection, inputs=[block_model_card], outputs=[block_model_card], show_progress="hidden", js=jump_to_rodin)
if __name__ == "__main__":
    demo.launch(show_api=False)