File size: 2,073 Bytes
542c815
3f8e328
542c815
c862667
542c815
 
42cca0b
d6e753e
 
 
42cca0b
4189f11
 
42cca0b
018621a
3267028
b98efed
42cca0b
0ce70fb
 
 
 
 
 
4189f11
 
 
542c815
988f91c
 
542c815
 
4189f11
42cca0b
4189f11
 
 
 
 
 
 
8a70686
4189f11
8a70686
4189f11
 
42cca0b
4189f11
8a70686
 
4189f11
42cca0b
4189f11
8a70686
42cca0b
4189f11
42cca0b
4189f11
8a70686
032f16e
4189f11
 
58f9d23
 
4189f11
d909bca
4189f11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import torch
import torch.nn.functional as F
import functools
from torchvision.transforms.functional import normalize
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
import requests
from io import BytesIO

# Load the pretrained BRIA RMBG-1.4 network once at import time; every
# request handled by `process` reuses these weights.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
# Inference only: switch to eval mode explicitly so layers such as BatchNorm
# or Dropout (if the architecture has them — see briarmbg.py) use inference
# behaviour regardless of what from_pretrained leaves set.
net.eval()

@functools.lru_cache(maxsize=128)
def _fetch_url_bytes(url):
    """Download *url* and return the response body as ``bytes`` (memoized).

    Raises ``requests.HTTPError`` on a 4xx/5xx status instead of handing an
    error page to the image decoder; since ``lru_cache`` does not store calls
    that raise, failed downloads are retried on the next request.
    """
    user_agent = {'User-agent': 'gradio-app'}
    # Without a timeout a stalled server would block the worker indefinitely.
    response = requests.get(url, headers=user_agent, timeout=30)
    response.raise_for_status()
    return response.content

def get_url_im(url):
    """Return a new ``BytesIO`` over the (cached) bytes downloaded from *url*.

    Only the immutable bytes are cached; each call gets its own stream, so
    repeated or concurrent readers never share a file position.

    Args:
        url: Image URL; surrounding whitespace (e.g. from a pasted textbox
            value) is ignored.

    Raises:
        requests.RequestException: on connection errors, timeouts, or an
            HTTP error status.
    """
    return BytesIO(_fetch_url_bytes(url.strip()))

def resize_image(image_url, model_input_size=(1024, 1024)):
    """Fetch the image at *image_url* as RGB, resized to the model input size.

    Args:
        image_url: URL of the source image; downloaded through ``get_url_im``.
        model_input_size: ``(width, height)`` target passed to
            ``Image.resize``. Defaults to 1024x1024; the aspect ratio is not
            preserved.

    Returns:
        A new RGB ``PIL.Image.Image`` of exactly ``model_input_size``.
    """
    # convert() fully decodes into a new image, so the source can be closed
    # straight away.
    with Image.open(get_url_im(image_url)) as src:
        rgb_image = src.convert('RGB')
    return rgb_image.resize(model_input_size, Image.BILINEAR)

def process(image_url):
    """Cut the background out of the image at *image_url*.

    The image is resized to the model's input resolution for inference; the
    predicted mask is then resized back to the source image's own size and
    used as the alpha channel, so the cutout keeps the original resolution
    and aspect ratio.

    Args:
        image_url: URL of the input image.

    Returns:
        RGBA ``PIL.Image.Image`` the size of the source image, transparent
        where the mask is 0.
    """
    # Full-resolution source image. get_url_im caches downloads per URL, so
    # the fetch inside resize_image below does not download again.
    with Image.open(get_url_im(image_url)) as src:
        orig_image = src.convert('RGB')
    w, h = orig_image.size

    # prepare input: model-sized RGB uint8 (H, W, 3) -> float (1, 3, H, W),
    # scaled to [0, 1] then shifted by the 0.5 mean (std 1.0) to [-0.5, 0.5].
    model_image = resize_image(image_url)
    im_tensor = torch.tensor(np.array(model_image), dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Follow the module-level `device` the network was moved to.
    im_tensor = im_tensor.to(device)

    # No autograd graph is needed for inference; skipping it saves memory.
    with torch.no_grad():
        result = net(im_tensor)
        # post process: result[0][0] is a 4-D (N, C, H, W) tensor (required by
        # bilinear interpolate), presumably the primary single-channel mask —
        # TODO confirm in briarmbg.py. Scale it to the source image size.
        mask = F.interpolate(result[0][0], size=(h, w), mode='bilinear')
        mask = torch.squeeze(mask, 0)
        ma = torch.max(mask)
        mi = torch.min(mask)
        if ma > mi:
            # Min-max stretch so the mask spans the full [0, 1] range.
            mask = (mask - mi) / (ma - mi)
        else:
            # Constant prediction: min-max scaling would divide by zero and
            # produce NaN, so just clamp into the valid alpha range.
            mask = torch.clamp(mask, 0.0, 1.0)

    # image to pil: float [0, 1] -> uint8 [0, 255], 2-D single-channel image
    im_array = (mask * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the mask on the original image: source pixels are composited onto
    # a fully transparent canvas, weighted by the mask.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im

# Minimal web UI: one textbox for the image URL in, the RGBA cutout out.
url_input = gr.Textbox(label="Text or Image URL")
cutout_output = gr.Image(type="pil", label="Output Image")
iface = gr.Interface(fn=process, inputs=url_input, outputs=cutout_output)

# Start the Gradio app.
iface.launch()