|
import gradio as gr |
|
from loadimg import load_img |
|
|
|
from transformers import AutoModelForImageSegmentation |
|
import torch |
|
from torchvision import transforms |
|
import moviepy.editor as mp |
|
from pydub import AudioSegment |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
import tempfile |
|
import uuid |
|
import devicetorch |
|
|
|
# Let PyTorch use its "medium" float32 matmul precision setting.
torch.set_float32_matmul_precision("medium")

# Device selection is delegated to devicetorch.
device = devicetorch.get(torch)

# Segmentation model; the checkpoint ships its own modelling code, hence
# trust_remote_code=True.
_MODEL_REPO = "ZhengPeng7/BiRefNet"
birefnet = AutoModelForImageSegmentation.from_pretrained(
    _MODEL_REPO,
    trust_remote_code=True,
)
birefnet.to(device)

# Preprocessing applied to every frame before it is fed to the model:
# fixed-size resize, tensor conversion, per-channel normalisation.
_MODEL_INPUT_SIZE = (1024, 1024)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform_image = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
|
|
def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
    """Replace the background of every frame of a video and re-encode it.

    Args:
        vid: Path of the input video (opened with ``mp.VideoFileClip``).
        bg_type: ``"Color"``, ``"Image"`` or ``"Video"``. Any other value
            leaves the frames untouched.
        bg_image: Background image (file path or PIL image), used when
            ``bg_type == "Image"``.
        bg_video: Background video path, used when ``bg_type == "Video"``.
        color: ``#RRGGBB`` string, used when ``bg_type == "Color"``.
        fps: Output frame rate; ``0`` (or ``None``) inherits the input's rate.
        video_handling: How a background video shorter than the input is
            extended: ``"slow_down"`` stretches it to the input's duration,
            any other value repeats it.

    Returns:
        ``(last_processed_frame, output_path)`` on success, where the output
        is an H.264 ``.mp4`` written under the relative ``temp`` directory.
        On any failure: ``(None, "Error processing video: <reason>")``.
    """
    video = None
    background_source = None
    try:
        video = mp.VideoFileClip(vid)
        if not fps:
            fps = video.fps

        # Resolve the background once, before the per-frame loop.
        background_frames = None
        static_background = None
        if bg_type == "Video":
            if not bg_video:
                raise ValueError("a background video is required when the background type is 'Video'")
            background_source = mp.VideoFileClip(bg_video)
            background_clip = _fit_background_clip(background_source, video.duration, video_handling)
            background_frames = list(background_clip.iter_frames(fps=fps))
            if not background_frames:
                raise ValueError("the background video contains no frames")
        elif bg_type == "Image":
            if bg_image is None:
                raise ValueError("a background image is required when the background type is 'Image'")
            if isinstance(bg_image, Image.Image):
                static_background = bg_image
            else:
                # Decode the file a single time instead of once per frame.
                with Image.open(bg_image) as opened:
                    static_background = opened.convert("RGBA")
        elif bg_type == "Color":
            static_background = color

        processed_image = None
        processed_frames = []
        for index, frame in enumerate(video.iter_frames(fps=fps)):
            pil_image = Image.fromarray(frame)
            if background_frames is not None:
                # Wrap around so a background with fewer frames keeps cycling.
                background_frame = background_frames[index % len(background_frames)]
                processed_image = process(pil_image, Image.fromarray(background_frame))
            elif static_background is not None:
                processed_image = process(pil_image, static_background)
            else:
                # Unknown background type: pass the frame through unchanged.
                processed_image = pil_image
            processed_frames.append(np.array(processed_image))

        if not processed_frames:
            raise ValueError("the input video contains no frames")

        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        # Carry the original soundtrack over to the re-encoded clip.
        processed_video = processed_video.set_audio(video.audio)

        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        temp_filepath = os.path.join(temp_dir, str(uuid.uuid4()) + ".mp4")
        processed_video.write_videofile(temp_filepath, codec="libx264")

        return processed_image, temp_filepath

    except Exception as e:
        print(f"Error: {e}")
        # NOTE(review): the second value is wired to a Video output in the UI;
        # confirm that an error string is what that component should receive.
        return None, f"Error processing video: {e}"

    finally:
        # The source clips hold reader resources; release them only after the
        # output (which reads the input's audio) has been written.
        for clip in (background_source, video):
            if clip is not None:
                clip.close()


def _fit_background_clip(background_clip, target_duration, video_handling):
    """Return a clip at least ``target_duration`` long, stretching or repeating a shorter one."""
    if background_clip.duration >= target_duration:
        return background_clip
    if video_handling == "slow_down":
        # speedx multiplies playback speed by `factor`, so a factor below 1
        # slows the clip down until it lasts exactly target_duration.
        return background_clip.fx(mp.vfx.speedx, factor=background_clip.duration / target_duration)
    repeats = int(target_duration / background_clip.duration + 1)
    return mp.concatenate_videoclips([background_clip] * repeats)
|
|
|
|
|
|
|
def process(image, bg):
    """Composite the segmented foreground of ``image`` over a new background.

    Args:
        image: PIL image to segment; its size is the size of the result.
        bg: The replacement background, one of
            * a string starting with ``#``, read as ``#RRGGBB``;
            * a PIL image;
            * anything else, which is handed to ``Image.open``.

    Returns:
        The PIL image produced by ``Image.composite`` of the input, the
        background and the predicted mask.
    """
    target_size = image.size

    # Single-image batch on the model's device.
    batch = transform_image(image).unsqueeze(0).to(device)

    # Inference only, so no gradients are tracked.
    with torch.no_grad():
        probabilities = birefnet(batch)[-1].sigmoid().cpu()

    # The prediction comes back at the model's resolution; bring it to the
    # input image's size so it can act as the compositing mask.
    mask = transforms.ToPILImage()(probabilities[0].squeeze()).resize(target_size)

    if isinstance(bg, str) and bg.startswith("#"):
        channels = [int(bg[start:start + 2], 16) for start in (1, 3, 5)]
        background = Image.new("RGBA", target_size, (*channels, 255))
    else:
        source = bg if isinstance(bg, Image.Image) else Image.open(bg)
        background = source.convert("RGBA").resize(target_size)

    return Image.composite(image, background, mask)
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Video Background Remover & Changer\n### You can replace image background with any color, image or video.\nNOTE: As this Space is running on ZERO GPU it has limit. It can handle approx 200frmaes at once. So, if you have big video than use small chunks or Duplicate this space.")

    # Input video next to the result; the still-frame output stays hidden.
    with gr.Row():
        in_video = gr.Video(label="Input Video", interactive=True)
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")
    submit_button = gr.Button("Change Background", interactive=True)

    # Processing options. Only the picker matching the selected background
    # type is visible at any time.
    with gr.Row():
        fps_slider = gr.Slider(
            minimum=0,
            maximum=60,
            step=1,
            value=0,
            label="Output FPS (0 will inherit the original fps value)",
            interactive=True,
        )
        bg_type = gr.Radio(
            ["Color", "Image", "Video"],
            label="Background Type",
            value="Color",
            interactive=True,
        )
        color_picker = gr.ColorPicker(
            label="Background Color",
            value="#00FF00",
            visible=True,
            interactive=True,
        )
        bg_image = gr.Image(
            label="Background Image",
            type="filepath",
            visible=False,
            interactive=True,
        )
        bg_video = gr.Video(label="Background Video", visible=False, interactive=True)
        with gr.Column(visible=False) as video_handling_options:
            video_handling_radio = gr.Radio(
                ["slow_down", "loop"],
                label="Video Handling",
                value="slow_down",
                interactive=True,
            )

    def update_visibility(bg_type):
        """Return visibility updates for (color picker, image, video, video-handling options)."""
        show_color = bg_type == "Color"
        show_image = bg_type == "Image"
        show_video = bg_type == "Video"
        # The video-handling options are only relevant alongside the
        # background-video picker, so they share its flag.
        return (
            gr.update(visible=show_color),
            gr.update(visible=show_image),
            gr.update(visible=show_video),
            gr.update(visible=show_video),
        )

    bg_type.change(
        fn=update_visibility,
        inputs=bg_type,
        outputs=[color_picker, bg_image, bg_video, video_handling_options],
    )

    # Each example row supplies: input video, background type, background
    # image, background video. The remaining parameters of fn keep their
    # defaults.
    examples = gr.Examples(
        examples=[
            ["rickroll-2sec.mp4", "Video", None, "background.mp4"],
            ["rickroll-2sec.mp4", "Image", "images.webp", None],
            ["rickroll-2sec.mp4", "Color", None, None],
        ],
        inputs=[in_video, bg_type, bg_image, bg_video],
        outputs=[stream_image, out_video],
        fn=fn,
    )

    submit_button.click(
        fn=fn,
        inputs=[in_video, bg_type, bg_image, bg_video, color_picker, fps_slider, video_handling_radio],
        outputs=[stream_image, out_video],
    )


if __name__ == "__main__":
    demo.launch(show_error=True)
|
|