|
import gradio as gr |
|
from multi_view import MultiViewDiffusion |
|
from vision_llm import VisionLLM |
|
from llama_mesh import LLaMAMesh |
|
from mast3r import MASt3R |
|
from utils import apply_gradient_color |
|
from utils import create_image_grid |
|
import os |
|
import torch |
|
|
|
# HTML header shown at the top of the app (rendered by gr.Markdown(DESCRIPTION) in the UI below).
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">TimeForge: Temporal Mesh Synthesis</h1>
<p> This demo showcases a fusion of state-of-the-art generative models to create 3D representations with temporal variations. </p>
</div>
'''
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Optional Hugging Face access token (None when the variable is unset); only VisionLLM receives it.
HF_TOKEN = os.getenv("HF_TOKEN")

# Every model is built once at import time; process_input and the UI callbacks reuse these instances.
mv_diff = MultiViewDiffusion(device=DEVICE)
vllm = VisionLLM(device=DEVICE, use_auth_token=HF_TOKEN)
llama_mesh = LLaMAMesh(device=DEVICE)
mast3r = MASt3R(device=DEVICE)
|
|
|
|
|
@torch.no_grad()
def process_input(input_prompt, num_views=4, guidance_scale=5, num_inference_steps=30, elevation=0):
    """Run the TimeForge pipeline for a single text prompt.

    The pipeline has three stages:
      1. Generate ``num_views`` images of the prompt with the multi-view diffusion model
         and tile them into one grid image.
      2. Ask the vision-language model to describe the views. Append "future" keywords
         to the first description and send the result to LLaMA-Mesh.
      3. Build a "past" point cloud with MASt3R from the first generated view.

    Args:
        input_prompt: Text prompt entered in the UI.
        num_views: Number of views to generate (cast to ``int``).
        guidance_scale: Guidance scale forwarded unchanged to the diffusion model.
        num_inference_steps: Inference step count for the diffusion model (cast to ``int``).
        elevation: Camera elevation forwarded unchanged to the diffusion model.

    Returns:
        Tuple ``(multi_view_image_grid, future_mesh, past_point_cloud)``, in the order
        expected by the ``run_button.click`` outputs. The items come from
        ``create_image_grid``, ``LLaMAMesh.generate_mesh`` and
        ``MASt3R.generate_point_cloud`` respectively.

    Raises:
        gr.Error: If the diffusion model returns no images or the vision-language
            model returns no description.
    """
    # The UI sliders use step=1, but Gradio may still deliver them as floats (e.g. 4.0).
    # These values are presumably used as counts downstream, so normalise them to int.
    num_views = int(num_views)
    num_inference_steps = int(num_inference_steps)

    # Stage 1: multi-view generation.
    multi_view_images = mv_diff.generate_views(
        input_prompt, num_views, guidance_scale, num_inference_steps, elevation
    )
    # len() rather than truthiness, so this also works if the views come back as a tensor/array.
    if len(multi_view_images) == 0:
        raise gr.Error("Multi-view generation returned no images.")
    multi_view_image_grid = create_image_grid(multi_view_images)

    # Stage 2: describe the views. Only the first description drives the "future" mesh.
    descriptions = vllm.describe_images(
        multi_view_images,
        "Describe the object in the image, highlight its textures, material, and shape, "
        "and it's context, like environment and lighting:",
    )
    if len(descriptions) == 0:
        raise gr.Error("The vision-language model returned no description.")
    refined_future_prompt = descriptions[0] + " futuristic, advanced, streamlined, evolved, modern "
    future_mesh = llama_mesh.generate_mesh(refined_future_prompt)

    # Stage 3: the "past" branch is image-driven. No text prompt is involved; MASt3R
    # receives only the first generated view.
    # NOTE(review): a single view is passed; confirm MASt3R.generate_point_cloud is
    # meant to reconstruct from one image rather than from all views.
    past_point_cloud = mast3r.generate_point_cloud([multi_view_images[0]])

    return multi_view_image_grid, future_mesh, past_point_cloud
|
|
|
# Gradio UI. Inside each `with` context, components appear on screen in creation order.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    # Main pipeline: prompt and generation settings on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=3):
            input_prompt = gr.Textbox(lines=2, placeholder="Enter prompt (e.g., 'A futuristic cyber-temple, once an ancient ruin')", label="Input Prompt")
            # Slider defaults mirror the keyword defaults of process_input (4, 5, 30, 0).
            num_views = gr.Slider(minimum=2, maximum=8, value=4, step=1, label="Number of Views")
            guidance_scale = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Inference Steps")
            elevation = gr.Slider(minimum=-90, maximum=90, value=0, step=1, label="Elevation")
            run_button = gr.Button("Run")

        with gr.Column(scale=4):
            multi_view_grid_out = gr.Image(label = "Multi-view Images Output", height=300)
            with gr.Tab("Future Mesh"):
                future_mesh_output = gr.Model3D(label = "Future 3D Mesh output")
            with gr.Tab("Past Point Cloud"):
                # The point cloud is offered as a downloadable file, not rendered in 3D here.
                past_point_cloud_output = gr.File(label = "Past 3D Point Cloud")

    # `inputs` order matches the positional parameters of process_input; `outputs` order
    # matches its returned tuple (grid image, future mesh, past point cloud).
    run_button.click(
        fn=process_input,
        inputs=[input_prompt, num_views, guidance_scale, num_inference_steps, elevation],
        outputs=[multi_view_grid_out, future_mesh_output, past_point_cloud_output],
    )

    # Manual visualisation of a "past" result: the user pastes a file path as text.
    gr.Markdown("## Mesh Visualization (Past)")
    with gr.Row():
        with gr.Column():
            past_mesh_input = gr.Textbox(
                label="Past Point Cloud Input",
                placeholder="Paste your MASt3R file path here...",
                lines=2,
            )
            visualize_past_mesh_button = gr.Button("Visualize Past Mesh")
        with gr.Column():
            past_mesh_output = gr.Model3D(label = "Past 3D Visualization")

    # NOTE(review): the textbox value is a free-form, user-supplied server path handed to
    # apply_gradient_color. If that helper opens the path, any file readable by the server
    # process could be loaded. Consider restricting it to files produced by the pipeline.
    visualize_past_mesh_button.click(
        fn=apply_gradient_color,
        inputs=[past_mesh_input],
        outputs=[past_mesh_output]
    )

    # Manual visualisation of a "future" result: the user pastes OBJ text.
    gr.Markdown("## Mesh Visualization (Future)")
    with gr.Row():
        with gr.Column():
            future_mesh_input = gr.Textbox(
                label="Future Mesh Input",
                placeholder="Paste your 3D mesh in OBJ format here...",
                lines=2,
            )
            visualize_future_mesh_button = gr.Button("Visualize Future Mesh")
        with gr.Column():
            future_mesh_output_2 = gr.Model3D(label = "Future 3D Visualization")

    # NOTE(review): apply_gradient_color is wired to a file path in the "past" section and to
    # raw OBJ text here. Confirm it really accepts both kinds of input.
    visualize_future_mesh_button.click(
        fn=apply_gradient_color,
        inputs=[future_mesh_input],
        outputs=[future_mesh_output_2]
    )