svjack committed (verified) · Commit abe5686 · Parent: 5c7858a

Upload 8 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+pexels-cottonbro-5319934.mp4 filter=lfs diff=lfs merge=lfs -text
300_A_car_is_running_on_the_road.mp4 ADDED
Binary file (186 kB).
 
A_Terracotta_Warrior_is_skateboarding_9033688.mp4 ADDED
Binary file (138 kB).
 
app.py ADDED
@@ -0,0 +1,151 @@
+import os
+import gradio as gr
+from gradio_imageslider import ImageSlider
+from loadimg import load_img
+import spaces
+from transformers import AutoModelForImageSegmentation
+import torch
+from torchvision import transforms
+from PIL import Image, ImageChops
+from moviepy.editor import VideoFileClip, ImageSequenceClip
+import numpy as np
+from tqdm import tqdm
+from uuid import uuid1
+
+# Check CUDA availability
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+# Load the model
+birefnet = AutoModelForImageSegmentation.from_pretrained(
+    "briaai/RMBG-2.0", trust_remote_code=True
+)
+birefnet.to(device)
+transform_image = transforms.Compose(
+    [
+        transforms.Resize((1024, 1024)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+    ]
+)
+
+output_folder = 'output_images'
+if not os.path.exists(output_folder):
+    os.makedirs(output_folder)
+
+def fn(image):
+    im = load_img(image, output_type="pil")
+    im = im.convert("RGB")
+    origin = im.copy()
+    image = process(im)
+    image_path = os.path.join(output_folder, "no_bg_image.png")
+    image.save(image_path)
+    return (image, origin), image_path
+
+@spaces.GPU
+def process(image):
+    image_size = image.size
+    input_images = transform_image(image).unsqueeze(0).to(device)
+    # Prediction
+    with torch.no_grad():
+        preds = birefnet(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+    pred_pil = transforms.ToPILImage()(pred)
+    mask = pred_pil.resize(image_size)
+    image.putalpha(mask)
+    return image
+
+def process_file(f):
+    name_path = f.rsplit(".", 1)[0] + ".png"
+    im = load_img(f, output_type="pil")
+    im = im.convert("RGB")
+    transparent = process(im)
+    transparent.save(name_path)
+    return name_path
+
+def remove_background(image):
+    """Remove background from a single image."""
+    input_images = transform_image(image).unsqueeze(0).to(device)
+
+    # Prediction
+    with torch.no_grad():
+        preds = birefnet(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+
+    # Convert the prediction to a mask
+    mask = (pred * 255).byte()  # Convert to 0-255 range
+    mask_pil = transforms.ToPILImage()(mask).convert("L")
+    mask_resized = mask_pil.resize(image.size, Image.LANCZOS)
+
+    # Apply the mask to the image
+    image.putalpha(mask_resized)
+
+    return image, mask_resized
+
+def process_video(input_video_path):
+    """Process a video to remove the background from each frame."""
+    # Load the video
+    video_clip = VideoFileClip(input_video_path)
+
+    # Process each frame
+    frames = []
+    for frame in tqdm(video_clip.iter_frames()):
+        frame_pil = Image.fromarray(frame)
+        frame_no_bg, mask_resized = remove_background(frame_pil)
+        path = "{}.png".format(uuid1())
+        frame_no_bg.save(path)
+        frame_no_bg = Image.open(path).convert("RGBA")
+        os.remove(path)
+
+        # Convert mask_resized to RGBA mode
+        mask_resized_rgba = mask_resized.convert("RGBA")
+
+        # Apply the mask using ImageChops.multiply
+        output = ImageChops.multiply(frame_no_bg, mask_resized_rgba)
+        output_np = np.array(output)
+        frames.append(output_np)
+
+    # Save the processed frames as a new video
+    output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
+    processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
+    processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
+
+    return output_video_path
+
+# Gradio components
+slider1 = ImageSlider(label="RMBG-2.0", type="pil")
+slider2 = ImageSlider(label="RMBG-2.0", type="pil")
+image = gr.Image(label="Upload an image")
+image2 = gr.Image(label="Upload an image", type="filepath")
+text = gr.Textbox(label="Paste an image URL")
+png_file = gr.File(label="output png file")
+video_input = gr.Video(label="Upload a video")
+video_output = gr.Video(label="Processed video")
+
+# Example videos
+example_videos = [
+    "pexels-cottonbro-5319934.mp4",
+    "300_A_car_is_running_on_the_road.mp4",
+    "A_Terracotta_Warrior_is_skateboarding_9033688.mp4"
+]
+
+# Gradio interfaces
+tab1 = gr.Interface(
+    fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[load_img("giraffe.jpg", output_type="pil")], api_name="image"
+)
+
+tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
+#tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
+tab4 = gr.Interface(process_video, inputs=video_input, outputs=video_output, examples=example_videos, api_name="video")
+
+# Gradio tabbed interface
+demo = gr.TabbedInterface(
+    [tab4, tab1, tab2], ["input video", "input image", "input url"], title="RMBG-2.0 for background removal"
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True, show_error=True)
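
Note: because app.py registers the video tab with api_name="video", the processing function should also be reachable programmatically once the demo is running. Below is a minimal sketch using gradio_client, which is not part of this commit's requirements.txt and would need to be installed separately; the client URL is a placeholder for wherever the demo is actually served.

# Hedged sketch: call the "/video" endpoint exposed by app.py via gradio_client.
# Assumes `pip install gradio_client` and a running instance at the placeholder URL below.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")  # placeholder; a Hub Space id such as "user/space" also works

result = client.predict(
    handle_file("300_A_car_is_running_on_the_road.mp4"),  # one of the example clips in this commit
    api_name="/video",
)
print(result)  # local path of the no_bg_video.mp4 produced by process_video

Older gradio_client releases accept a plain file path instead of handle_file(); the endpoint name follows from api_name="video" in app.py.
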
giraffe.jpg ADDED
image_app.py ADDED
@@ -0,0 +1,89 @@
+import os
+import gradio as gr
+from gradio_imageslider import ImageSlider
+from loadimg import load_img
+import spaces
+from transformers import AutoModelForImageSegmentation
+import torch
+from torchvision import transforms
+
+# Check whether CUDA is available
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+birefnet = AutoModelForImageSegmentation.from_pretrained(
+    "briaai/RMBG-2.0", trust_remote_code=True
+)
+birefnet.to(device)
+transform_image = transforms.Compose(
+    [
+        transforms.Resize((1024, 1024)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+    ]
+)
+
+output_folder = 'output_images'
+if not os.path.exists(output_folder):
+    os.makedirs(output_folder)
+
+def fn(image):
+    im = load_img(image, output_type="pil")
+    im = im.convert("RGB")
+    origin = im.copy()
+    image = process(im)
+    image_path = os.path.join(output_folder, "no_bg_image.png")
+    image.save(image_path)
+    return (image, origin), image_path
+
+@spaces.GPU
+def process(image):
+    image_size = image.size
+    input_images = transform_image(image).unsqueeze(0).to(device)
+    # Prediction
+    with torch.no_grad():
+        preds = birefnet(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+    pred_pil = transforms.ToPILImage()(pred)
+    mask = pred_pil.resize(image_size)
+    image.putalpha(mask)
+    return image
+
+def process_file(f):
+    name_path = f.rsplit(".", 1)[0] + ".png"
+    im = load_img(f, output_type="pil")
+    im = im.convert("RGB")
+    transparent = process(im)
+    transparent.save(name_path)
+    return name_path
+
+slider1 = ImageSlider(label="RMBG-2.0", type="pil")
+slider2 = ImageSlider(label="RMBG-2.0", type="pil")
+image = gr.Image(label="Upload an image")
+image2 = gr.Image(label="Upload an image", type="filepath")
+text = gr.Textbox(label="Paste an image URL")
+png_file = gr.File(label="output png file")
+
+
+chameleon = load_img("giraffe.jpg", output_type="pil")
+
+url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"
+
+tab1 = gr.Interface(
+    fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[chameleon], api_name="image"
+)
+
+tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=[url], api_name="text")
+tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
+
+
+demo = gr.TabbedInterface(
+    [tab1, tab2], ["input image", "input url"], title="RMBG-2.0 for background removal"
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True, show_error=True)
pexels-cottonbro-5319934.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:647dd65864fef5aee16a8033740771a1e0f0e49d51596a8d9d452dc0767cfd54
+size 22477729
requirements.txt ADDED
@@ -0,0 +1,19 @@
+torch
+accelerate
+opencv-python
+spaces
+pillow
+numpy
+timm
+kornia
+prettytable
+typing
+scikit-image
+#huggingface_hub
+transformers>=4.39.1
+gradio
+gradio_imageslider
+loadimg>=0.1.1
+httpx[socks]
+huggingface_hub==0.25.0
+moviepy
video_script.py ADDED
@@ -0,0 +1,78 @@
+from PIL import Image, ImageChops
+import torch
+from torchvision import transforms
+from transformers import AutoModelForImageSegmentation
+from moviepy.editor import VideoFileClip, ImageSequenceClip
+import numpy as np
+from tqdm import tqdm
+from uuid import uuid1
+import os
+
+# Load the model
+model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
+torch.set_float32_matmul_precision('high')  # Set precision
+model.to('cuda')
+model.eval()
+
+# Data settings
+image_size = (1024, 1024)
+transform_image = transforms.Compose([
+    transforms.Resize(image_size),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+def remove_background(image):
+    """Remove background from a single image."""
+    input_images = transform_image(image).unsqueeze(0).to('cuda')
+
+    # Prediction
+    with torch.no_grad():
+        preds = model(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+
+    # Convert the prediction to a mask
+    mask = (pred * 255).byte()  # Convert to 0-255 range
+    mask_pil = transforms.ToPILImage()(mask).convert("L")
+    mask_resized = mask_pil.resize(image.size, Image.LANCZOS)
+
+    # Apply the mask to the image
+    image.putalpha(mask_resized)
+
+    return image, mask_resized
+
+def process_video(input_video_path, output_video_path):
+    """Process a video to remove the background from each frame."""
+    # Load the video
+    video_clip = VideoFileClip(input_video_path)
+
+    # Process each frame
+    frames = []
+    for frame in tqdm(video_clip.iter_frames()):
+        frame_pil = Image.fromarray(frame)
+        frame_no_bg, mask_resized = remove_background(frame_pil)
+        path = "{}.png".format(uuid1())
+        frame_no_bg.save(path)
+        frame_no_bg = Image.open(path).convert("RGBA")
+        os.remove(path)
+
+        # Convert mask_resized to RGBA mode
+        mask_resized_rgba = mask_resized.convert("RGBA")
+
+        # Apply the mask using ImageChops.multiply
+        output = ImageChops.multiply(frame_no_bg, mask_resized_rgba)
+        output_np = np.array(output)
+        frames.append(output_np)
+
+    # Save the processed frames as a new video
+    processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
+    processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
+
+if __name__ == "__main__":
+    from IPython import display
+    # Example usage
+    input_video_path = "300_A_car_is_running_on_the_road.mp4"  # Replace with your video path
+    output_video_path = "300_A_car_is_running_on_the_road_no_bg.mp4"
+    process_video(input_video_path, output_video_path)
+    display.Video("300_A_car_is_running_on_the_road_no_bg.mp4")
+    pass
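
Note: the __main__ block above exercises the whole-video path; the same remove_background step can be checked on a single still image first, which is a quick way to judge mask quality before processing a full clip. A minimal sketch, assuming video_script.py has already been imported (so model and transform_image are loaded) and using a placeholder input file name:

# Hedged sketch: apply the per-frame step from video_script.py to one image.
# "example_frame.jpg" is a placeholder path, not a file in this commit.
from PIL import Image

frame = Image.open("example_frame.jpg").convert("RGB")
frame_no_bg, mask = remove_background(frame)  # RGBA image plus grayscale mask
frame_no_bg.save("example_frame_no_bg.png")   # PNG keeps the alpha channel
mask.save("example_frame_mask.png")
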