import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import uuid
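# moviepy 1.x exposes the editor module used below; moviepy 2.x removed this
# entry point (from moviepy import VideoFileClip, ...), so pin moviepy<2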
import moviepy.editor as mp
# Trade float32 matmul precision for speed ("medium" permits bfloat16-backed matmuls)
torch.set_float32_matmul_precision("medium")
# Set the device to use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pre-trained image segmentation model
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
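# transformers puts the model in eval mode on load, so no explicit
# birefnet.eval() call is needed before inference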
# Preprocessing: resize to the model's expected 1024x1024 input and
# normalize with the standard ImageNet mean/std
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
def process(image, color):
    """Segment the foreground of a PIL image and composite it over a solid color."""
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to(device)
    # Predict the segmentation mask
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Convert the mask prediction to a PIL image at the original resolution
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    # Convert a "#RRGGBB" hex color string to an RGB tuple
    color_rgb = tuple(int(color[i:i+2], 16) for i in (1, 3, 5))
    background = Image.new("RGBA", image_size, color_rgb + (255,))
    # Composite the image onto the solid background using the mask
    image = Image.composite(image.convert("RGBA"), background, mask)
    return image
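# Hypothetical standalone usage of process(), outside the Space's UI flow:
#   frame = Image.open("frame.png").convert("RGB")
#   composited = process(frame, "#00FF00")  # returns an RGBA PIL image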
def fn(vid, color="#00FF00", fps=0):
    try:
        # Load the video using moviepy
        video = mp.VideoFileClip(vid)
        # Inherit the source FPS when the slider is left at 0
        if fps == 0:
            fps = video.fps
        # Keep the original audio track so it can be re-attached later
        audio = video.audio
        # Extract frames at the requested FPS
        frames = video.iter_frames(fps=fps)
        processed_frames = []
        # Show the streaming image, hide the final video while processing
        yield gr.update(visible=True), gr.update(visible=False)
        # Process each frame for background removal
        for frame in frames:
            pil_image = Image.fromarray(frame)
            processed_image = process(pil_image, color)
            # Drop the alpha channel: libx264 encodes 3-channel RGB frames
            processed_frames.append(np.array(processed_image.convert("RGB")))
            yield processed_image, None
        # Create a new video from the processed frames
        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        # Add the original audio back to the processed video
        processed_video = processed_video.set_audio(audio)
        # Save the processed video to a uniquely named temporary file
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        unique_filename = str(uuid.uuid4()) + ".mp4"
        temp_filepath = os.path.join(temp_dir, unique_filename)
        processed_video.write_videofile(temp_filepath, codec="libx264")
        # Hide the streaming image and reveal the final video
        yield gr.update(visible=False), gr.update(visible=True)
        yield None, temp_filepath  # Return the final video path
    except Exception as e:
        print(f"Error: {e}")
        yield gr.update(visible=False), gr.update(visible=True)
        # Surface the error in the UI instead of handing the Video component
        # a non-path string, which it could not load
        raise gr.Error(f"Error processing video: {e}")
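# fn() is a generator: each yield pushes an update to (stream_image, out_video),
# so frames stream to the UI while the final video is still being encoded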
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Video Background Remover & Changer\n### You can replace the video background with any solid color.")
    with gr.Row():
        in_video = gr.Video(label="Input Video", interactive=True)
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")
    submit_button = gr.Button("Change Background", interactive=True)
    with gr.Row():
        fps_slider = gr.Slider(
            minimum=0,
            maximum=60,
            step=1,
            value=0,
            label="Output FPS (0 will inherit the original fps value)",
            interactive=True,
        )
        color_picker = gr.ColorPicker(label="Background Color", value="#00FF00", interactive=True)
    submit_button.click(
        fn,
        inputs=[in_video, color_picker, fps_slider],
        outputs=[stream_image, out_video],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)