File size: 3,463 Bytes
91fd71a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
import torch
from PIL import Image
import io
import numpy as np
from briarmbg import BriaRMBG
from torchvision.transforms.functional import normalize

# Reuse the functions from your CLI script
def convert_to_jpg(image, image_name):
    """Open an uploaded image and flatten PNG transparency onto white.

    Despite the name, nothing is encoded as JPEG: a PIL image is returned.
    When *image_name* ends with ``.png`` (case-insensitive), the image is
    returned in RGB mode and any transparency is composited onto an opaque
    white background. Other files are returned exactly as ``Image.open``
    produces them.

    Args:
        image: A path or binary file-like object accepted by ``Image.open``.
        image_name: File name; only its extension is inspected.

    Returns:
        A ``PIL.Image.Image``.
    """
    img = Image.open(image)
    if not image_name.lower().endswith('.png'):
        return img

    # Alpha can be a real band (RGBA, LA, PA) or, for P/L/RGB images, a
    # "transparency" entry in img.info (PNG tRNS chunk).
    has_alpha = img.mode in ('RGBA', 'LA', 'PA') or 'transparency' in img.info
    if not has_alpha:
        return img.convert("RGB")

    # Normalise to RGBA first. A palette image has a single band, so its alpha
    # cannot be read from img.split(); convert() applies the transparency info.
    rgba = img.convert("RGBA")
    bg = Image.new("RGB", rgba.size, (255, 255, 255))
    bg.paste(rgba, mask=rgba.getchannel("A"))
    return bg

def resize_image(image, size=(1024, 1024)):
    """Return an RGB copy of *image* resampled to *size* with bilinear filtering.

    Args:
        image: Any PIL image; it is converted to RGB before resampling.
        size: Target ``(width, height)`` in pixels, as PIL's ``resize`` expects.

    Returns:
        A new RGB ``PIL.Image.Image`` of exactly *size*.
    """
    rgb_copy = image.convert('RGB')
    return rgb_copy.resize(size, Image.BILINEAR)

def remove_background(model, image):
    """Return an RGBA copy of *image* whose alpha channel is *model*'s foreground mask.

    Steps:
    1. Resize to 1024x1024 RGB (``resize_image``).
    2. Scale pixels to [0, 1] and subtract 0.5 per channel.
    3. Run *model*.
    4. Resize the first output map back to the original resolution.
    5. Min-max normalise it to [0, 255] and install it as the alpha channel.

    Args:
        model: The segmentation network (a ``BriaRMBG`` as built by
            ``load_model``). The input tensor is sent to whichever device the
            model's parameters are on; the model itself is not moved.
        image: PIL image to process; it is not modified.

    Returns:
        A new ``PIL.Image.Image`` in RGBA mode with the same size as *image*.
    """
    width, height = image.size  # PIL reports (width, height)

    # Preprocess into a (1, 3, 1024, 1024) float tensor.
    image_resized = resize_image(image)
    im_tensor = torch.tensor(np.array(image_resized), dtype=torch.float32)
    im_tensor = im_tensor.permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # Follow the model's device instead of moving the model. Calling .cuda()
    # here would silently relocate the shared, cached instance from load_model.
    device = next(model.parameters()).device
    im_tensor = im_tensor.to(device)

    with torch.no_grad():
        result = model(im_tensor)
        # result[0][0] must be 4-D (N, C, H, W) for bilinear interpolate.
        # Resize straight to the original size in torch's (height, width) order,
        # then keep the single mask (assumes batch size 1 and one channel).
        mask = torch.nn.functional.interpolate(
            result[0][0], size=(height, width), mode='bilinear'
        )[0, 0]

    lo, hi = torch.min(mask), torch.max(mask)
    if hi > lo:
        mask = (mask - lo) / (hi - lo)
    else:
        # A constant map would make min-max normalisation divide by zero (NaN);
        # keep the raw values clipped into [0, 1] instead.
        mask = mask.clamp(0.0, 1.0)
    mask_np = (mask * 255).detach().cpu().numpy().astype(np.uint8)
    alpha = Image.fromarray(mask_np)  # 2-D uint8 array -> mode "L", size (width, height)

    # putalpha keeps the RGB values intact. Pasting through the mask onto a
    # transparent black canvas would also blend colours toward black wherever
    # the mask is partial, leaving dark fringes around the cut-out.
    new_im = image.convert("RGBA")
    new_im.putalpha(alpha)
    return new_im

# Load the model
@st.cache_resource
def load_model(model_path="model.pth"):
    """Build ``BriaRMBG``, load its saved weights and prepare it for inference.

    Decorated with ``st.cache_resource``, so Streamlit returns the same network
    object on every rerun (one instance per distinct *model_path*) instead of
    reloading the weights each time.

    Args:
        model_path: Path to the saved state dict. It defaults to
            ``"model.pth"``; a relative path is resolved against the current
            working directory.

    Returns:
        The network in eval mode. It is on CUDA when available, otherwise on CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = BriaRMBG()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs (supported since PyTorch 1.13).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net

# Streamlit app
def main():
    """Render the Streamlit background-removal page.

    Flow:
    1. Load the cached model.
    2. Accept a JPG/JPEG/PNG upload and show it.
    3. On "Remove Background", run the model and show the cut-out.
    4. Offer the cut-out as a PNG download.

    Streamlit reruns the whole script on every widget interaction, including
    a click on the download button. On that rerun ``st.button`` returns False.
    The processed result is therefore kept in ``st.session_state`` so it stays
    visible until a different file is uploaded.
    """
    result_key = "bg_removal_result"  # session_state slot: (upload_id, image, png_bytes)

    st.title("Background Removal App")

    # Load model
    model = load_model()

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        # Nothing uploaded (or the file was removed): forget any stale result.
        st.session_state.pop(result_key, None)
        return

    # Display original image
    try:
        image = convert_to_jpg(uploaded_file, uploaded_file.name)
    except OSError as err:
        # PIL's UnidentifiedImageError (unreadable or corrupt upload) subclasses OSError.
        st.error(f"Could not open the uploaded file as an image: {err}")
        return
    st.image(image, caption="Original Image", use_column_width=True)

    # Identifies the current upload so a result computed for another file is never shown.
    upload_id = (uploaded_file.name, uploaded_file.size)

    # Process button
    if st.button("Remove Background"):
        result = remove_background(model, image)
        buf = io.BytesIO()
        result.save(buf, format="PNG")
        st.session_state[result_key] = (upload_id, result, buf.getvalue())

    stored = st.session_state.get(result_key)
    if stored is None or stored[0] != upload_id:
        return
    _, result, byte_im = stored

    # Display result and offer it for download
    st.image(result, caption="Image with Background Removed", use_column_width=True)
    st.download_button(
        label="Download Image",
        data=byte_im,
        file_name="background_removed.png",
        mime="image/png",
    )

if __name__ == "__main__":
    # Run the app only when this file is executed as the main script, not on import.
    main()