# Hugging Face Space status banner at capture time: "Running on Zero"
import gradio as gr | |
import torch | |
import spaces | |
import moviepy.editor as mp | |
from PIL import Image | |
import numpy as np | |
import tempfile | |
import time | |
import os | |
import shutil | |
import ffmpeg | |
from concurrent.futures import ThreadPoolExecutor | |
from gradio.themes.base import Base | |
from gradio.themes.utils import colors, fonts | |
from infer import lotus # Import the depth model inference function | |
# Custom theme definition
class WhiteTheme(Base):
    """Gradio theme with white surfaces and black text.

    The ``*_dark`` variables are set to the same values as their light
    counterparts, so the overrides below apply regardless of colour mode.

    Args:
        primary_hue: Accent colour forwarded to ``Base`` (orange by default).
        font: Font stack for regular text, forwarded to ``Base``.
        font_mono: Font stack for monospace text, forwarded to ``Base``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        )
    ):
        super().__init__(primary_hue=primary_hue, font=font, font_mono=font_mono)

        # Every surface and border that is painted plain white.
        overrides = dict.fromkeys(
            (
                "background_fill_secondary",
                "body_background_fill",
                "body_background_fill_dark",
                "block_background_fill",
                "block_background_fill_dark",
                "panel_background_fill",
                "panel_background_fill_dark",
                "block_border_color",
                "panel_border_color",
                "input_background_fill",
                "input_background_fill_dark",
            ),
            "white",
        )

        # Text colours are pinned to black.
        overrides.update(
            dict.fromkeys(
                (
                    "body_text_color",
                    "body_text_color_dark",
                    "block_label_text_color",
                    "block_label_text_color_dark",
                ),
                "black",
            )
        )

        # Remaining one-off values, including the two primary-hue references.
        overrides.update(
            background_fill_primary="*primary_50",
            border_color_primary="*primary_300",
            input_border_color="lightgray",
            shadow_drop="none",
        )

        self.set(**overrides)
# Pick the torch device once at import time: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def process_frame(frame, seed=0):
    """
    Run a single frame through the depth model.

    The frame is written to a temporary PNG because ``lotus`` is called with
    a file path. The temporary file is removed in all cases, including when
    saving or inference raises.

    Args:
        frame: Image data accepted by ``PIL.Image.fromarray``.
        seed: Seed value forwarded unchanged to ``lotus``.

    Returns:
        The second element of the ``lotus(..., 'depth', ...)`` result as a
        NumPy array, or ``None`` if any step failed (the error is printed).
    """
    tmp_path = None
    try:
        # Convert frame to PIL Image
        image = Image.fromarray(frame)

        # Reserve a unique path, then close the handle before writing so the
        # file can be re-opened by name (an open NamedTemporaryFile cannot be
        # re-opened on every platform).
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            tmp_path = tmp.name
        image.save(tmp_path)

        # Process through lotus model; only the second output is used here.
        _, output_d = lotus(tmp_path, 'depth', seed, device)

        # Convert depth output to numpy array
        return np.array(output_d)
    except Exception as e:
        print(f"Error processing frame: {e}")
        return None
    finally:
        # Always remove the temp file, otherwise every failed frame leaks one.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                print(f"Error removing temp file {tmp_path}: {cleanup_error}")
def process_video(video_path, fps=0, seed=0, max_workers=6):
    """
    Process a video into a depth-map frame sequence (ZIP) and an MP4.

    This is a generator. While frames are being processed it yields
    ``(depth_frame, None, None, status_text)``; at the end it yields
    ``(None, zip_path, mp4_path_or_None, status_text)``. On an unexpected
    error it yields ``(None, None, None, error_text)`` instead of raising.

    No resizing is performed here. When ``fps`` is 0 the clip's own frame
    rate is used for both frame extraction and the encoded video.

    Args:
        video_path: Path of the input video. Outputs are written to an
            ``output`` directory next to it.
        fps: Frame rate to sample and encode at; 0 means "use the clip's fps".
        seed: Seed value forwarded to ``process_frame``.
        max_workers: Number of worker threads, also passed to ffmpeg as its
            thread count. Coerced to ``int``.
    """
    temp_dir = None
    video = None
    try:
        start_time = time.time()

        # UI number inputs may deliver floats; the executor and the ffmpeg
        # command line both want an integer.
        max_workers = int(max_workers)

        video = mp.VideoFileClip(video_path)

        # Use original video FPS if not specified
        if fps == 0:
            fps = video.fps

        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)
        print(f"Processing {total_frames} frames at {fps} FPS...")

        # Create temporary directory for frame sequence
        temp_dir = tempfile.mkdtemp()
        frames_dir = os.path.join(temp_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)

        # Number saved frames consecutively rather than by source index: the
        # ffmpeg sequence input below reads consecutively numbered files, so a
        # gap left by one failed frame would cut the video short.
        saved_count = 0

        # Process frames with parallel execution; results are consumed in
        # submission order so output frames stay in sequence.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_frame, frame, seed) for frame in frames]
            for i, future in enumerate(futures):
                try:
                    result = future.result()
                    if result is not None:
                        # Save frame
                        frame_path = os.path.join(frames_dir, f"frame_{saved_count:06d}.png")
                        Image.fromarray(result).save(frame_path)
                        saved_count += 1

                        # Update preview with the newest frame only; earlier
                        # frames are already on disk and need not stay in RAM.
                        elapsed_time = time.time() - start_time
                        yield result, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"

                    if (i + 1) % 10 == 0:
                        print(f"Processed {i+1}/{total_frames} frames")
                except Exception as e:
                    print(f"Error processing frame {i+1}: {e}")

        print("Creating output files...")

        # Create output directory
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)

        # Create ZIP of frame sequence (make_archive appends ".zip" itself,
        # hence the suffix is stripped from the base name).
        zip_filename = f"depth_frames_{int(time.time())}.zip"
        zip_path = os.path.join(output_dir, zip_filename)
        shutil.make_archive(zip_path[:-4], 'zip', frames_dir)

        # Create MP4 video
        print("Creating MP4 video...")
        video_filename = f"depth_video_{int(time.time())}.mp4"
        output_video_path = os.path.join(output_dir, video_filename)

        try:
            # FFmpeg settings for high-quality MP4
            stream = ffmpeg.input(
                os.path.join(frames_dir, 'frame_%06d.png'),
                pattern_type='sequence',
                framerate=fps
            )
            stream = ffmpeg.output(
                stream,
                output_video_path,
                vcodec='libx264',
                pix_fmt='yuv420p',
                crf=17,  # High quality
                threads=max_workers
            )
            ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
            print("MP4 video created successfully!")
        except ffmpeg.Error as e:
            # Encoding failure is non-fatal: the ZIP is still returned.
            print(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
            output_video_path = None

        print("Processing complete!")
        yield None, zip_path, output_video_path, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"
    finally:
        # Release the clip's reader; it was previously left open.
        if video is not None:
            try:
                video.close()
            except Exception as e:
                print(f"Error closing video: {e}")
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")
def process_wrapper(video, fps=0, seed=0, max_workers=6):
    """
    Gradio event handler: stream ``process_video`` updates to the UI.

    Re-yields every tuple produced by ``process_video`` and converts any
    exception into a ``gr.Error``.

    Args:
        video: Path of the uploaded video, or ``None`` when nothing was uploaded.
        fps: Forwarded to ``process_video``.
        seed: Forwarded to ``process_video``.
        max_workers: Forwarded to ``process_video``.

    Returns:
        As a generator return value, the last tuple that was yielded
        (``None`` if nothing was yielded).

    Raises:
        gr.Error: If no video was supplied or processing raised.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # Track only the most recent update; keeping every yielded tuple would
        # hold every preview frame in memory for the whole run.
        last_output = None
        for output in process_video(video, fps, seed, max_workers):
            last_output = output
            yield output
        return last_output
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}") from e
# Custom CSS for styling.
# Passed to gr.Blocks(css=...). Centers the .title-container element and gives
# the #title element a bold 36px heading with a slowly animated pastel
# gradient background (15s loop defined by the gradient-animation keyframes).
custom_css = """
.title-container {
text-align: center;
padding: 10px 0;
}
#title {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
font-size: 36px;
font-weight: bold;
color: #000000;
padding: 10px;
border-radius: 10px;
display: inline-block;
background: linear-gradient(
135deg,
#e0f7fa, #e8f5e9, #fff9c4, #ffebee,
#f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
);
background-size: 400% 400%;
animation: gradient-animation 15s ease infinite;
}
@keyframes gradient-animation {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
"""
# Gradio Interface
# NOTE(review): the nesting of rows/columns below is a reconstruction; the
# available text of this file had lost its indentation. Confirm the layout
# (in particular which container holds the sliders and the Markdown note)
# against the running app.
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
    # Page title; the class/id used here are the ones styled in custom_css.
    gr.HTML('''
<div class="title-container">
<div id="title">Video Depth Estimation</div>
</div>
''')
    with gr.Row():
        # Input side: video upload, processing options, and the run button.
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
                height=360,
                width=640
            )
            with gr.Row():
                # 0 is a sentinel: process_video then uses the clip's own fps.
                fps_slider = gr.Slider(
                    minimum=0,
                    maximum=60,
                    step=1,
                    value=0,
                    label="Output FPS (0 will inherit the original fps value)",
                )
                seed_slider = gr.Slider(
                    minimum=0,
                    maximum=999999999,
                    step=1,
                    value=0,
                    label="Seed",
                )
                max_workers_slider = gr.Slider(
                    minimum=1,
                    maximum=32,
                    step=1,
                    value=6,
                    label="Max Workers",
                    info="Determines how many frames to process in parallel"
                )
            btn = gr.Button("Process Video", elem_id="submit-button")
        # Output side: live preview, downloadable results, and status text.
        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("""
### Output Information
- High-quality MP4 video output
- Original resolution and framerate are maintained
- Frame sequence provided for maximum compatibility
""")
    # The four outputs line up with the 4-tuples yielded by process_wrapper:
    # (preview frame, ZIP path, MP4 path, status text).
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox]
    )
# Enable the request queue for this app.
demo.queue()
# Stand-alone Interface exposing process_wrapper with plain numeric inputs.
# NOTE(review): `api` is not launched or mounted anywhere in the visible code;
# confirm whether it is still needed.
api = gr.Interface(
    fn=process_wrapper,
    inputs=[
        gr.Video(label="Upload Video"),
        # FPS may legitimately be fractional (e.g. 29.97), so it stays a float.
        gr.Number(label="FPS", value=0),
        # precision=0 makes these arrive as ints, matching the step=1 sliders
        # of the main UI; they feed a seed and a thread count respectively.
        gr.Number(label="Seed", value=0, precision=0),
        gr.Number(label="Max Workers", value=6, precision=0)
    ],
    outputs=[
        gr.Image(label="Preview"),
        gr.File(label="Frame Sequence"),
        gr.File(label="Video"),
        gr.Textbox(label="Status")
    ],
    title="Video Depth Estimation API",
    description="Generate depth maps from videos",
    api_name="/process_video"
)
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, with debug output and error
    # reporting enabled and no public share link.
    launch_options = {
        "debug": True,
        "show_error": True,
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)