Fabrice-TIERCELIN committed on
Commit
4cae45e
1 Parent(s): 7f6739c

Add working files

Browse files
Files changed (6) hide show
  1. README.md +7 -12
  2. app.py +110 -273
  3. briarmbg.py +455 -0
  4. foo.py +2 -0
  5. input.jpg +0 -0
  6. requirements.txt +9 -26
README.md CHANGED
@@ -1,18 +1,13 @@
1
  ---
2
- title: Text-to-Audio
3
- emoji: 🔊
4
- colorFrom: gray
5
- colorTo: gray
6
- tags:
7
- - sound generation
8
- - language models
9
- - LLMs
10
  sdk: gradio
11
- sdk_version: 4.40.0
12
  app_file: app.py
13
  pinned: false
14
- license: openrail
15
- short_description: Sound effect from description
16
  ---
17
 
18
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: BRIA RMBG 1.4
3
+ emoji: 💻
4
+ colorFrom: red
5
+ colorTo: red
 
 
 
 
6
  sdk: gradio
7
+ sdk_version: 4.16.0
8
  app_file: app.py
9
  pinned: false
10
+ license: other
 
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,278 +1,115 @@
1
- import gradio as gr
2
- import json
3
  import torch
4
- import time
5
- import random
6
- try:
7
- # Only on HuggingFace
8
- import spaces
9
- is_space_imported = True
10
- except ImportError:
11
- is_space_imported = False
12
-
13
- from tqdm import tqdm
14
- from huggingface_hub import snapshot_download
15
- from models import AudioDiffusion, DDPMScheduler
16
- from audioldm.audio.stft import TacotronSTFT
17
- from audioldm.variational_autoencoder import AutoencoderKL
18
- from pydub import AudioSegment
19
-
20
- max_64_bit_int = 2**63 - 1
21
-
22
- # Automatic device detection
23
  if torch.cuda.is_available():
24
- device_type = "cuda"
25
- device_selection = "cuda:0"
 
 
 
 
 
26
  else:
27
- device_type = "cpu"
28
- device_selection = "cpu"
29
-
30
- class Tango:
31
- def __init__(self, name = "declare-lab/tango2", device = device_selection):
32
-
33
- path = snapshot_download(repo_id = name)
34
-
35
- vae_config = json.load(open("{}/vae_config.json".format(path)))
36
- stft_config = json.load(open("{}/stft_config.json".format(path)))
37
- main_config = json.load(open("{}/main_config.json".format(path)))
38
-
39
- self.vae = AutoencoderKL(**vae_config).to(device)
40
- self.stft = TacotronSTFT(**stft_config).to(device)
41
- self.model = AudioDiffusion(**main_config).to(device)
42
-
43
- vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location = device)
44
- stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location = device)
45
- main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location = device)
46
-
47
- self.vae.load_state_dict(vae_weights)
48
- self.stft.load_state_dict(stft_weights)
49
- self.model.load_state_dict(main_weights)
50
 
51
- print ("Successfully loaded checkpoint from:", name)
52
-
53
- self.vae.eval()
54
- self.stft.eval()
55
- self.model.eval()
56
-
57
- self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder = "scheduler")
58
-
59
- def chunks(self, lst, n):
60
- # Yield successive n-sized chunks from a list
61
- for i in range(0, len(lst), n):
62
- yield lst[i:i + n]
63
-
64
- def generate(self, prompt, steps = 100, guidance = 3, samples = 1, disable_progress = True):
65
- # Generate audio for a single prompt string
66
- with torch.no_grad():
67
- latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
68
- mel = self.vae.decode_first_stage(latents)
69
- wave = self.vae.decode_to_waveform(mel)
70
- return wave
71
 
72
- def generate_for_batch(self, prompts, steps = 200, guidance = 3, samples = 1, batch_size = 8, disable_progress = True):
73
- # Generate audio for a list of prompt strings
74
- outputs = []
75
- for k in tqdm(range(0, len(prompts), batch_size)):
76
- batch = prompts[k: k + batch_size]
77
- with torch.no_grad():
78
- latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
79
- mel = self.vae.decode_first_stage(latents)
80
- wave = self.vae.decode_to_waveform(mel)
81
- outputs += [item for item in wave]
82
- if samples == 1:
83
- return outputs
84
- return list(self.chunks(outputs, samples))
85
-
86
- # Initialize TANGO
87
-
88
- tango = Tango(device = "cpu")
89
- tango.vae.to(device_type)
90
- tango.stft.to(device_type)
91
- tango.model.to(device_type)
92
-
93
- def update_seed(is_randomize_seed, seed):
94
- if is_randomize_seed:
95
- return random.randint(0, max_64_bit_int)
96
- return seed
97
-
98
- def check(
99
- prompt,
100
- output_number,
101
- steps,
102
- guidance,
103
- is_randomize_seed,
104
- seed
105
- ):
106
- if prompt is None or prompt == "":
107
- raise gr.Error("Please provide a prompt input.")
108
- if not output_number in [1, 2, 3]:
109
- raise gr.Error("Please ask for 1, 2 or 3 output files.")
110
-
111
- def update_output(output_format, output_number):
112
- return [
113
- gr.update(format = output_format),
114
- gr.update(format = output_format, visible = (2 <= output_number)),
115
- gr.update(format = output_format, visible = (output_number == 3)),
116
- gr.update(visible = False)
117
- ]
118
-
119
- def text2audio(
120
- prompt,
121
- output_number,
122
- steps,
123
- guidance,
124
- is_randomize_seed,
125
- seed
126
- ):
127
- start = time.time()
128
-
129
- if seed is None:
130
- seed = random.randint(0, max_64_bit_int)
131
-
132
- random.seed(seed)
133
- torch.manual_seed(seed)
134
-
135
- output_wave = tango.generate(prompt, steps, guidance, output_number)
136
-
137
- output_wave_1 = gr.make_waveform((16000, output_wave[0]))
138
- output_wave_2 = gr.make_waveform((16000, output_wave[1])) if (2 <= output_number) else None
139
- output_wave_3 = gr.make_waveform((16000, output_wave[2])) if (output_number == 3) else None
140
-
141
- end = time.time()
142
- secondes = int(end - start)
143
- minutes = secondes // 60
144
- secondes = secondes - (minutes * 60)
145
- hours = minutes // 60
146
- minutes = minutes - (hours * 60)
147
- return [
148
- output_wave_1,
149
- output_wave_2,
150
- output_wave_3,
151
- gr.update(visible = True, value = "Start again to get a different result. The output have been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec.")
152
- ]
153
-
154
- if is_space_imported:
155
- text2audio = spaces.GPU(text2audio, duration = 420)
156
-
157
- # Gradio interface
158
- with gr.Blocks() as interface:
159
- gr.Markdown("""
160
- <p style="text-align: center;">
161
- <b><big><big><big>Text-to-Audio</big></big></big></b>
162
- <br/>Generates 10 seconds of sound effects from description, freely, without account, without watermark
163
- </p>
164
- <br/>
165
- <br/>
166
- ✨ Powered by <i>Tango 2</i> AI.
167
- <br/>
168
- <ul>
169
- <li>If you need <b>47 seconds</b> of audio, I recommend to use <i>Stable Audio</i>,</li>
170
- <li>If you need to generate <b>music</b>, I recommend to use <i>MusicGen</i>,</li>
171
- </ul>
172
- <br/>
173
- """ + ("🏃‍♀️ Estimated time: few minutes. Current device: GPU." if torch.cuda.is_available() else "🐌 Slow process... ~5 min. Current device: CPU.") + """
174
- Your computer must <b><u>not</u></b> enter into standby mode.<br/>You can duplicate this space on a free account, it's designed to work on CPU, GPU and ZeroGPU.<br/>
175
- <a href='https://huggingface.co/spaces/Fabrice-TIERCELIN/Text-to-Audio?duplicate=true&hidden=public&hidden=public'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14'></a>
176
- <br/>
177
- ⚖️ You can use, modify and share the generated sounds but not for commercial uses.
178
- """
179
- )
180
- input_text = gr.Textbox(label = "Prompt", value = "Snort of a horse", lines = 2, autofocus = True)
181
- with gr.Accordion("Advanced options", open = False):
182
- output_format = gr.Radio(label = "Output format", info = "The file you can dowload", choices = ["mp3", "wav"], value = "wav")
183
- output_number = gr.Slider(label = "Number of generations", info = "1, 2 or 3 output files", minimum = 1, maximum = 3, value = 1, step = 1, interactive = True)
184
- denoising_steps = gr.Slider(label = "Steps", info = "lower=faster & variant, higher=audio quality & similar", minimum = 10, maximum = 200, value = 10, step = 1, interactive = True)
185
- guidance_scale = gr.Slider(label = "Guidance Scale", info = "lower=audio quality, higher=follow the prompt", minimum = 1, maximum = 10, value = 3, step = 0.1, interactive = True)
186
- randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different")
187
- seed = gr.Slider(minimum = 0, maximum = max_64_bit_int, step = 1, randomize = True, label = "Seed")
188
-
189
- submit = gr.Button("🚀 Generate", variant = "primary")
190
-
191
- output_audio_1 = gr.Audio(label = "Generated Audio #1/3", format = "wav", type="numpy", autoplay = True)
192
- output_audio_2 = gr.Audio(label = "Generated Audio #2/3", format = "wav", type="numpy")
193
- output_audio_3 = gr.Audio(label = "Generated Audio #3/3", format = "wav", type="numpy")
194
- information = gr.Label(label = "Information")
195
-
196
- submit.click(fn = update_seed, inputs = [
197
- randomize_seed,
198
- seed
199
- ], outputs = [
200
- seed
201
- ], queue = False, show_progress = False).then(fn = check, inputs = [
202
- input_text,
203
- output_number,
204
- denoising_steps,
205
- guidance_scale,
206
- randomize_seed,
207
- seed
208
- ], outputs = [], queue = False, show_progress = False).success(fn = update_output, inputs = [
209
- output_format,
210
- output_number
211
- ], outputs = [
212
- output_audio_1,
213
- output_audio_2,
214
- output_audio_3,
215
- information
216
- ], queue = False, show_progress = False).success(fn = text2audio, inputs = [
217
- input_text,
218
- output_number,
219
- denoising_steps,
220
- guidance_scale,
221
- randomize_seed,
222
- seed
223
- ], outputs = [
224
- output_audio_1,
225
- output_audio_2,
226
- output_audio_3,
227
- information
228
- ], scroll_to_output = True)
229
-
230
- gr.Examples(
231
- fn = text2audio,
232
- inputs = [
233
- input_text,
234
- output_number,
235
- denoising_steps,
236
- guidance_scale,
237
- randomize_seed,
238
- seed
239
- ],
240
- outputs = [
241
- output_audio_1,
242
- output_audio_2,
243
- output_audio_3,
244
- information
245
- ],
246
- examples = [
247
- ["A hammer is hitting a wooden surface", 3, 100, 3, False, 123],
248
- ["Peaceful and calming ambient music with singing bowl and other instruments.", 3, 100, 3, False, 123],
249
- ["A man is speaking in a small room.", 2, 100, 3, False, 123],
250
- ["A female is speaking followed by footstep sound", 1, 100, 3, False, 123],
251
- ["Wooden table tapping sound followed by water pouring sound.", 3, 200, 3, False, 123],
252
- ],
253
- cache_examples = "lazy" if is_space_imported else False,
254
- )
255
-
256
- gr.Markdown(
257
- """
258
- ## How to prompt your sound
259
- You can use round brackets to increase the importance of a part:
260
- ```
261
- Peaceful and (calming) ambient music with singing bowl and other instruments
262
- ```
263
- You can use several levels of round brackets to even more increase the importance of a part:
264
- ```
265
- (Peaceful) and ((calming)) ambient music with singing bowl and other instruments
266
- ```
267
- You can use number instead of several round brackets:
268
- ```
269
- (Peaceful:1.5) and ((calming)) ambient music with singing bowl and other instruments
270
- ```
271
- You can do the same thing with square brackets to decrease the importance of a part:
272
- ```
273
- (Peaceful:1.5) and ((calming)) ambient music with [singing:2] bowl and other instruments
274
- """
275
- )
276
-
277
- if __name__ == "__main__":
278
- interface.launch(share = False)
 
1
+ import numpy as np
 
2
  import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ from huggingface_hub import hf_hub_download
6
+ import gradio as gr
7
+ from gradio_imageslider import ImageSlider
8
+ from briarmbg import BriaRMBG
9
+ import PIL
10
+ from PIL import Image
11
+ from typing import Tuple
12
+
13
# Build the network and fetch its checkpoint once, at import time.
net = BriaRMBG()
# NOTE(review): an earlier, commented-out variant downloaded the same file name
# from "briaai/RMBG-1.4"; "cocktailpeanut/gbmr" is presumably a mirror of those
# weights - confirm before relying on it.
model_path = hf_hub_download("cocktailpeanut/gbmr", 'model.pth')

# Pick the best available backend. `torch.backends.mps` does not exist on older
# torch builds, so it is looked up defensively instead of being dereferenced.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

# Always deserialize onto the CPU and move the module afterwards: the load then
# does not depend on which device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the downloaded file; only point this at a
# repository you trust.
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net = net.to(device)
net.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
def resize_image(image):
    """Return an RGB copy of *image* stretched to the 1024x1024 model input size.

    The image is converted to RGB first and then resized with bilinear
    resampling; the aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)
36
+
37
+
38
def process(image):
    """Cut the foreground out of *image* and return it on a transparent canvas.

    Args:
        image: pixel array accepted by ``Image.fromarray``, or ``None`` when
            no picture was supplied.

    Returns:
        An RGBA ``PIL.Image`` with the same size as the input, in which the
        original pixels are pasted through the predicted mask.

    Raises:
        gr.Error: if *image* is ``None``.
    """
    if image is None:
        # Fail with a readable message instead of an opaque error from PIL.
        raise gr.Error("Please provide an image.")

    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    model_input = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(model_input, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    # inference: no gradients are needed, so do not build the autograd graph
    with torch.no_grad():
        result = net(im_tensor)
        # result[0][0] is the first side output; bring it back to the
        # original resolution (PIL sizes are (w, h), tensors are (h, w)).
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
        )

    # post process: stretch the mask to the full [0, 1] range
    ma = torch.max(result)
    mi = torch.min(result)
    if ma > mi:
        result = (result - mi) / (ma - mi)
    # else: the mask is flat; keep it as is rather than dividing by zero

    # image to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    # Index the leading channel axis explicitly: np.squeeze would also drop
    # the height or width axis of an image that is a single pixel wide/tall.
    pil_im = Image.fromarray(im_array[0])

    # paste the original image through the mask onto a transparent canvas
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im
70
+ # return [new_orig_image, new_im]
71
+
72
+
73
+ # block = gr.Blocks().queue()
74
+
75
+ # with block:
76
+ # gr.Markdown("## BRIA RMBG 1.4")
77
+ # gr.HTML('''
78
+ # <p style="margin-bottom: 10px; font-size: 94%">
79
+ # This is a demo for BRIA RMBG 1.4 that using
80
+ # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
81
+ # </p>
82
+ # ''')
83
+ # with gr.Row():
84
+ # with gr.Column():
85
+ # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
86
+ # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
87
+ # run_button = gr.Button(value="Run")
88
+
89
+ # with gr.Column():
90
+ # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
91
+ # ips = [input_image]
92
+ # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
93
+
94
+ # block.launch(debug = True)
95
+
96
+ # block = gr.Blocks().queue()
97
+
98
# The gr.Markdown / gr.HTML components that used to be created here were built
# outside of any gr.Blocks context, so they were never attached to the page;
# the visible copy of the app is the title and description given to gr.Interface.
title = "Background Removal"
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
examples = [['./input.jpg'],]
demo = gr.Interface(
    fn=process,
    inputs="image",
    outputs="image",
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
briarmbg.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dirate: dilation rate of the convolution; the padding is set to the
            same value.
        stride: stride of the convolution.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=3,
            stride=stride,
            padding=dirate,
            dilation=dirate,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to *x*, in that order."""
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block with seven encoder convolutions and five poolings.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels used inside the block.
        out_ch: channels of the output tensor.
        img_size: accepted for interface compatibility; not read by the block.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # input projection; its output is also the residual added at the end
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder: each convolution but the last two is followed by a 2x pooling
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage consumes a concatenation, hence 2 * mid_ch inputs
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the U-shaped encoder/decoder and add the projected input."""
        # The values are unused; the unpacking requires x to have four axes.
        _batch, _chan, _rows, _cols = x.shape

        residual = self.rebnconvin(x)

        # encoder
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        enc6 = self.rebnconv6(self.pool5(enc5))
        bottom = self.rebnconv7(enc6)

        # decoder: upsample to the skip connection's size, concatenate, convolve
        dec = self.rebnconv6d(torch.cat((bottom, enc6), 1))
        dec = self.rebnconv5d(torch.cat((_upsample_like(dec, enc5), enc5), 1))
        dec = self.rebnconv4d(torch.cat((_upsample_like(dec, enc4), enc4), 1))
        dec = self.rebnconv3d(torch.cat((_upsample_like(dec, enc3), enc3), 1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), 1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), 1))

        return dec + residual
109
+
110
+
111
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block with six encoder convolutions and four poolings.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels used inside the block.
        out_ch: channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # input projection; its output is also the residual added at the end
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage consumes a concatenation, hence 2 * mid_ch inputs
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the U-shaped encoder/decoder and add the projected input."""
        residual = self.rebnconvin(x)

        # encoder
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        bottom = self.rebnconv6(enc5)

        # decoder: upsample to the skip connection's size, concatenate, convolve
        dec = self.rebnconv5d(torch.cat((bottom, enc5), 1))
        dec = self.rebnconv4d(torch.cat((_upsample_like(dec, enc4), enc4), 1))
        dec = self.rebnconv3d(torch.cat((_upsample_like(dec, enc3), enc3), 1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), 1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), 1))

        return dec + residual
179
+
180
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block with five encoder convolutions and three poolings.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels used inside the block.
        out_ch: channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # input projection; its output is also the residual added at the end
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage consumes a concatenation, hence 2 * mid_ch inputs
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the U-shaped encoder/decoder and add the projected input."""
        residual = self.rebnconvin(x)

        # encoder
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        bottom = self.rebnconv5(enc4)

        # decoder: upsample to the skip connection's size, concatenate, convolve
        dec = self.rebnconv4d(torch.cat((bottom, enc4), 1))
        dec = self.rebnconv3d(torch.cat((_upsample_like(dec, enc3), enc3), 1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), 1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), 1))

        return dec + residual
237
+
238
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block with four encoder convolutions and two poolings.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels used inside the block.
        out_ch: channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # input projection; its output is also the residual added at the end
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated convolution at the lowest resolution
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder: every stage consumes a concatenation, hence 2 * mid_ch inputs
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the U-shaped encoder/decoder and add the projected input."""
        residual = self.rebnconvin(x)

        # encoder
        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        bottom = self.rebnconv4(enc3)

        # decoder: upsample to the skip connection's size, concatenate, convolve
        dec = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), 1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), 1))

        return dec + residual
285
+
286
### RSU-4F ###
class RSU4F(nn.Module):
    """Residual block of four convolutions with dilation 1, 2, 4 and 8.

    Unlike the other RSU blocks in this file it contains no pooling and no
    upsampling; the decoder mirrors the encoder's dilation rates instead.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels used inside the block.
        out_ch: channels of the output tensor.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # input projection; its output is also the residual added at the end
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder with growing dilation
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # decoder with shrinking dilation; inputs are concatenations
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the dilated encoder/decoder and add the projected input."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)
        bottom = self.rebnconv4(enc3)

        dec = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        dec = self.rebnconv2d(torch.cat((dec, enc2), 1))
        dec = self.rebnconv1d(torch.cat((dec, enc1), 1))

        return dec + residual
321
+
322
+
323
class myrebnconv(nn.Module):
    """Configurable convolution followed by batch normalisation and ReLU.

    All arguments after *out_ch* are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1):
        super().__init__()

        conv_options = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
        }
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to *x*, in that order."""
        features = self.conv(x)
        features = self.bn(features)
        return self.rl(features)
345
+
346
+
347
class BriaRMBG(nn.Module):
    """Six-stage encoder / five-stage decoder network built from RSU blocks.

    Args:
        in_ch: channels of the input image tensor.
        out_ch: channels of every side output.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # stem: a stride-2 convolution halves the resolution before stage 1
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # registered but not called in forward()
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # encoder
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder: inputs are a concatenation of two feature maps
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # one 3x3 head per decoder stage (plus stage 6) producing the side outputs
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Compute the side outputs and the feature maps they were derived from.

        Returns:
            A pair ``(masks, features)``. ``masks`` is a list of six
            sigmoid-activated side outputs, each resized to the spatial size
            of *x*. ``features`` is the list of the five decoder feature maps
            followed by the stage-6 feature map.
        """
        stem = self.conv_in(x)

        # -------------------- encoder --------------------
        enc1 = self.stage1(stem)
        enc2 = self.stage2(self.pool12(enc1))
        enc3 = self.stage3(self.pool23(enc2))
        enc4 = self.stage4(self.pool34(enc3))
        enc5 = self.stage5(self.pool45(enc4))
        enc6 = self.stage6(self.pool56(enc5))

        # -------------------- decoder --------------------
        # each stage: upsample the deeper map to the skip's size, concatenate, decode
        dec5 = self.stage5d(torch.cat((_upsample_like(enc6, enc5), enc5), 1))
        dec4 = self.stage4d(torch.cat((_upsample_like(dec5, enc4), enc4), 1))
        dec3 = self.stage3d(torch.cat((_upsample_like(dec4, enc3), enc3), 1))
        dec2 = self.stage2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.stage1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        # -------------------- side outputs --------------------
        features = [dec1, dec2, dec3, dec4, dec5, enc6]
        heads = (
            self.side1,
            self.side2,
            self.side3,
            self.side4,
            self.side5,
            self.side6,
        )
        masks = [
            F.sigmoid(_upsample_like(head(feature), x))
            for head, feature in zip(heads, features)
        ]

        return masks, features
455
+
foo.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
def hello():
    """Print the greeting ``hello world`` to standard output."""
    greeting = "hello world"
    print(greeting)
input.jpg ADDED
requirements.txt CHANGED
@@ -1,26 +1,9 @@
1
- torch==2.4.0
2
- torchaudio==2.4.0
3
- torchvision==0.19.0
4
- transformers==4.31.0
5
- accelerate==0.21.0
6
- datasets==2.1.0
7
- einops==0.8.0
8
- huggingface_hub==0.19.4
9
- importlib_metadata==6.3.0
10
- librosa==0.9.2
11
- matplotlib==3.9.0
12
- numpy==1.23.0
13
- omegaconf==2.3.0
14
- packaging==24.1
15
- progressbar33==2.4
16
- protobuf==3.20.*
17
- safetensors==0.4.4
18
- sentencepiece==0.1.99
19
- scipy==1.8.0
20
- soundfile==0.12.1
21
- torchlibrosa==0.1.0
22
- tqdm==4.63.1
23
- wandb==0.12.14
24
- ipython==8.12.0
25
- gradio==4.3.0
26
- wavio==0.0.7
 
1
+ gradio==4.16.0
2
+ gradio_imageslider
3
+ #torch
4
+ #torchvision
5
+ pillow
6
+ numpy
7
+ typing
8
+ gitpython
9
+ huggingface_hub