File size: 4,169 Bytes
1ad1a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b61a40
1ad1a85
7b61a40
1ad1a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b61a40
1ad1a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import uuid
import moviepy.editor as mp

# Let float32 matmuls use reduced internal precision for speed
# (see torch.set_float32_matmul_precision for what "medium" permits).
torch.set_float32_matmul_precision("medium")

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Pre-trained BiRefNet segmentation model from the Hugging Face Hub.
# trust_remote_code is required because the model class ships with the repo.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)

# Per-channel normalisation statistics (the standard ImageNet mean / std).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Input pipeline: resize to the fixed 1024x1024 model resolution, convert
# to a float tensor, then normalise each channel.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])

def _parse_color(color):
    """Convert a colour string into an ``(r, g, b)`` tuple of ints in 0..255.

    Accepts ``#RRGGBB`` (any trailing alpha digits are ignored), the
    shorthand ``#RGB``, and CSS-style ``rgb(r, g, b)`` / ``rgba(r, g, b, a)``
    strings (alpha ignored). The CSS forms are handled because some Gradio
    ColorPicker versions appear to emit them -- NOTE(review): confirm for the
    Gradio version in use.

    Raises:
        ValueError: If the string is in none of the supported formats.
    """
    value = color.strip()
    if value.startswith("#"):
        hex_digits = value[1:]
        if len(hex_digits) == 3:
            # Expand "#0F0" -> "00FF00".
            hex_digits = "".join(ch * 2 for ch in hex_digits)
        if len(hex_digits) >= 6:
            return tuple(int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
    elif value.lower().startswith(("rgb(", "rgba(")) and value.endswith(")"):
        parts = [p.strip() for p in value[value.index("(") + 1:-1].split(",")]
        if len(parts) >= 3:
            return tuple(max(0, min(255, round(float(p)))) for p in parts[:3])
    raise ValueError(f"Unsupported color value: {color!r}")


def process(image, color):
    """Replace the background of a single PIL image with a solid colour.

    The foreground mask comes from the module-level ``birefnet`` model. The
    image is then composited over an opaque background of ``color`` using
    that mask.

    Args:
        image: Input ``PIL.Image.Image``. Any mode is accepted; it is
            converted to RGB for the model and to RGBA for compositing.
        color: Background colour string, e.g. ``"#00FF00"``, ``"#0F0"`` or
            ``"rgba(0, 255, 0, 1)"``.

    Returns:
        An RGBA ``PIL.Image.Image`` the same size as ``image``.

    Raises:
        ValueError: If ``color`` cannot be parsed.
    """
    # Parse first so a bad colour fails before the (expensive) model call.
    color_rgb = _parse_color(color)
    image_size = image.size

    # The normalisation step expects exactly 3 channels, so RGBA/L/P inputs
    # must be converted before entering the model pipeline.
    input_images = transform_image(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model output is indexed with [-1]; assumed to be the final
        # prediction among several returned maps -- TODO confirm vs BiRefNet.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # presumably (1, H, W) after [0]; squeeze yields a 2-D map for ToPILImage.
    pred = preds[0].squeeze()

    # Probability map -> greyscale PIL mask, scaled back to the input size.
    mask = transforms.ToPILImage()(pred).resize(image_size)

    background = Image.new("RGBA", image_size, color_rgb + (255,))

    # Where the mask is bright keep the original pixel; elsewhere show colour.
    return Image.composite(image.convert("RGBA"), background, mask)

def fn(vid, color="#00FF00", fps=0):
    """Replace the background of every frame of a video with a solid colour.

    This is a generator. Each yielded 2-tuple is meant for a
    ``(preview image, output video)`` pair of Gradio output components.

    Args:
        vid: Path to the uploaded input video, or ``None`` if nothing was
            uploaded.
        color: Background colour string, passed through to ``process``.
        fps: Frame rate used both to sample the input and for the output.
            ``0`` (or ``None``) keeps the source video's frame rate.

    Yields:
        First, visibility updates that show the preview and hide the output.
        Then one ``(processed_frame, None)`` pair per frame. Finally,
        visibility updates that swap them back, followed by
        ``(None, path_to_mp4)`` for the file written under ``temp/``.

    Raises:
        gr.Error: If no video was given or any processing step fails. Gradio
            shows the message to the user instead of treating an error string
            as a video file path.
    """
    if not vid:
        raise gr.Error("Please upload a video first.")

    video = None
    try:
        video = mp.VideoFileClip(vid)

        # 0 / None on the slider means "inherit the source frame rate".
        if not fps:
            fps = video.fps

        audio = video.audio

        # Show the live preview and hide the final-video slot while working.
        yield gr.update(visible=True), gr.update(visible=False)

        processed_frames = []
        for frame in video.iter_frames(fps=fps):
            processed_image = process(Image.fromarray(frame), color)
            processed_frames.append(np.array(processed_image))
            yield processed_image, None

        if not processed_frames:
            raise ValueError("no frames could be read from the video")

        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        # Re-attach the original soundtrack; it is taken from ``video``, so
        # ``video`` is kept open until after the file has been written.
        processed_video = processed_video.set_audio(audio)

        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        temp_filepath = os.path.join(temp_dir, f"{uuid.uuid4()}.mp4")
        processed_video.write_videofile(temp_filepath, codec="libx264")

        yield gr.update(visible=False), gr.update(visible=True)
        yield None, temp_filepath  # Return the final video path here

    except Exception as e:
        # Top-level UI boundary: log server-side, restore the layout, and
        # surface the failure to the user as a Gradio error.
        print(f"Error: {e}")
        yield gr.update(visible=False), gr.update(visible=True)
        raise gr.Error(f"Error processing video: {e}") from e

    finally:
        # Release the ffmpeg reader(s) held by the source clip.
        if video is not None:
            video.close()

with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        "# Video Background Remover & Changer\n"
        "### You can replace the video background with any solid color."
    )

    # First row: input video, a per-frame preview (starts hidden; fn's first
    # yield makes it visible), and the final rendered video.
    with gr.Row():
        in_video = gr.Video(label="Input Video", interactive=True)
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")

    submit_button = gr.Button("Change Background", interactive=True)

    # Second row: processing options fed to fn as its fps and color args.
    with gr.Row():
        fps_slider = gr.Slider(
            label="Output FPS (0 will inherit the original fps value)",
            minimum=0,
            maximum=60,
            step=1,
            value=0,
            interactive=True,
        )
        color_picker = gr.ColorPicker(
            value="#00FF00",
            label="Background Color",
            interactive=True,
        )

    # fn is a generator, so each yielded (image, video) pair is streamed into
    # these outputs in order; inputs map positionally to (vid, color, fps).
    submit_button.click(
        fn,
        inputs=[in_video, color_picker, fps_slider],
        outputs=[stream_image, out_video],
    )

if __name__ == "__main__":
    demo.launch(show_error=True)