# TimeForge / app.py
# Last commit by Ryukijano: 3a54a7a "Update app.py" (verified)
import gradio as gr
from multi_view import MultiViewDiffusion
from vision_llm import VisionLLM
from llama_mesh import LLaMAMesh
from mast3r import MASt3R
from utils import apply_gradient_color
from utils import create_image_grid
import os
import torch
# HTML banner rendered at the top of the Gradio page.
DESCRIPTION = """
<div>
<h1 style="text-align: center;">TimeForge: Temporal Mesh Synthesis</h1>
<p> This demo showcases a fusion of state-of-the-art generative models to create 3D representations with temporal variations. </p>
</div>
"""

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Optional Hugging Face access token; evaluates to None when the variable is
# unset. Only the vision LLM is given this token.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Build every model once, at import time, on the selected device; the request
# handler below refers to these module-level instances.
mv_diff = MultiViewDiffusion(device=DEVICE)
vllm = VisionLLM(device=DEVICE, use_auth_token=HF_TOKEN)
llama_mesh = LLaMAMesh(device=DEVICE)
mast3r = MASt3R(device=DEVICE)
# Inference only: disable autograd bookkeeping for the whole pipeline.
@torch.no_grad()
def process_input(input_prompt, num_views=4, guidance_scale=5, num_inference_steps=30, elevation=0):
    """Run the TimeForge pipeline for a single text prompt.

    Uses the module-level models ``mv_diff``, ``vllm``, ``llama_mesh`` and
    ``mast3r``:

      1. ``mv_diff`` generates ``num_views`` images for the prompt, which are
         tiled into one grid image with ``create_image_grid``.
      2. ``vllm`` describes the generated images; the first description is
         extended with "future" keywords.
      3. ``llama_mesh`` generates a mesh from that future prompt.
      4. ``mast3r`` builds a point cloud from the first generated view only.

    Args:
        input_prompt: Text prompt describing the object.
        num_views: Number of views to generate. Coerced to ``int`` because
            Gradio sliders may deliver numeric values as floats.
        guidance_scale: Forwarded unchanged to ``mv_diff.generate_views``.
        num_inference_steps: Number of inference steps; coerced to ``int``.
        elevation: Forwarded unchanged to ``mv_diff.generate_views``.

    Returns:
        ``(multi_view_image_grid, future_mesh, past_point_cloud)``, in the
        order wired to the Run button's ``outputs``.

    Raises:
        gr.Error: If the prompt is empty/blank or the vision model returns no
            description. Gradio shows the message to the user instead of a
            raw traceback.
    """
    # Fail fast instead of running every model on an empty prompt.
    if not input_prompt or not input_prompt.strip():
        raise gr.Error("Please enter a prompt.")

    # Counts must be integers for the diffusion pipeline.
    num_views = int(num_views)
    num_inference_steps = int(num_inference_steps)

    # Multi-view diffusion
    multi_view_images = mv_diff.generate_views(
        input_prompt, num_views, guidance_scale, num_inference_steps, elevation
    )
    multi_view_image_grid = create_image_grid(multi_view_images)

    # Vision LLM analysis
    descriptions = vllm.describe_images(
        multi_view_images,
        "Describe the object in the image, highlight its textures, material, and shape, and it's context, like environment and lighting:",
    )
    if not descriptions:
        raise gr.Error("The vision model returned no description for the generated views.")

    # Only the first view's description drives the text-to-mesh step. The
    # "past" output is produced from images (MASt3R) rather than from a text
    # prompt, so no past prompt is built here.
    refined_future_prompt = descriptions[0] + " futuristic, advanced, streamlined, evolved, modern "

    # LLaMA-Mesh generation
    future_mesh = llama_mesh.generate_mesh(refined_future_prompt)

    # MASt3R point cloud from the first generated view
    past_point_cloud = mast3r.generate_point_cloud([multi_view_images[0]])

    return multi_view_image_grid, future_mesh, past_point_cloud
def _add_mesh_visualizer(heading, input_label, placeholder, button_label, output_label):
    """Add a "paste input -> apply_gradient_color -> 3D viewer" section.

    Must be called inside an active ``gr.Blocks`` context. Renders a Markdown
    heading and then a row with two columns: a two-line textbox plus a button
    on the left, and a ``gr.Model3D`` viewer on the right. Clicking the button
    passes the textbox contents to ``apply_gradient_color`` and shows the
    result in the viewer.

    Returns:
        The created ``(textbox, button, model3d)`` components.
    """
    gr.Markdown(heading)
    with gr.Row():
        with gr.Column():
            source_box = gr.Textbox(
                label=input_label,
                placeholder=placeholder,
                lines=2,
            )
            trigger = gr.Button(button_label)
        with gr.Column():
            viewer = gr.Model3D(label=output_label)
    trigger.click(
        fn=apply_gradient_color,
        inputs=[source_box],
        outputs=[viewer],
    )
    return source_box, trigger, viewer


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    # Main pipeline: prompt and generation settings on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=3):
            input_prompt = gr.Textbox(
                lines=2,
                placeholder="Enter prompt (e.g., 'A futuristic cyber-temple, once an ancient ruin')",
                label="Input Prompt",
            )
            num_views = gr.Slider(minimum=2, maximum=8, value=4, step=1, label="Number of Views")
            guidance_scale = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Inference Steps")
            elevation = gr.Slider(minimum=-90, maximum=90, value=0, step=1, label="Elevation")
            run_button = gr.Button("Run")
        with gr.Column(scale=4):
            multi_view_grid_out = gr.Image(label="Multi-view Images Output", height=300)
            with gr.Tab("Future Mesh"):
                future_mesh_output = gr.Model3D(label="Future 3D Mesh output")
            with gr.Tab("Past Point Cloud"):
                past_point_cloud_output = gr.File(label="Past 3D Point Cloud")

    # The outputs list must follow the order of the tuple process_input returns.
    run_button.click(
        fn=process_input,
        inputs=[input_prompt, num_views, guidance_scale, num_inference_steps, elevation],
        outputs=[multi_view_grid_out, future_mesh_output, past_point_cloud_output],
    )

    # Standalone visualizers; both colour the pasted input via apply_gradient_color.
    past_mesh_input, visualize_past_mesh_button, past_mesh_output = _add_mesh_visualizer(
        heading="## Mesh Visualization (Past)",
        input_label="Past Point Cloud Input",
        placeholder="Paste your MASt3R file path here...",
        button_label="Visualize Past Mesh",
        output_label="Past 3D Visualization",
    )
    future_mesh_input, visualize_future_mesh_button, future_mesh_output_2 = _add_mesh_visualizer(
        heading="## Mesh Visualization (Future)",
        input_label="Future Mesh Input",
        placeholder="Paste your 3D mesh in OBJ format here...",
        button_label="Visualize Future Mesh",
        output_label="Future 3D Visualization",
    )