# Author: scholar-2001 (initial commit 91fd71a)
import streamlit as st
import torch
from PIL import Image
import io
import numpy as np
from briarmbg import BriaRMBG
from torchvision.transforms.functional import normalize
# Helper functions reused from the original command-line (CLI) script.
def convert_to_jpg(image, image_name):
    """Open an uploaded image, flattening PNG transparency onto white.

    Despite the name, nothing is encoded as JPEG. Files whose name ends in
    ``.png`` are returned as an RGB PIL image, with any transparency
    composited onto a white background. Every other file is returned exactly
    as ``Image.open`` produced it.

    Args:
        image: A path or binary file-like object accepted by ``Image.open``.
        image_name: The original file name; only its extension is inspected
            (case-insensitively).

    Returns:
        A PIL image. For ``.png`` names it is always in mode ``"RGB"``.
    """
    img = Image.open(image)
    if not image_name.lower().endswith(".png"):
        return img

    # Transparency can come from an explicit alpha band (RGBA / LA / PA) or
    # from a tRNS "transparency" entry (palette "P" images, but also L/RGB).
    # A "P" image has a single band, so its alpha cannot be read via split();
    # converting to RGBA first yields one alpha channel for every case.
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        rgba = img.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba, mask=rgba.getchannel("A"))
        return flattened

    return img.convert("RGB")
def resize_image(image, size=(1024, 1024)):
    """Return an RGB copy of *image* stretched to *size*.

    *size* is in PIL order, ``(width, height)``. Resampling is bilinear and
    the aspect ratio is not preserved.
    """
    return image.convert("RGB").resize(size, resample=Image.BILINEAR)
def remove_background(model, image, input_size=(1024, 1024)):
    """Cut the foreground out of *image* using a segmentation model.

    Args:
        model: A torch module. It is called on a ``(1, 3, H, W)`` float tensor
            and, as the indexing below assumes, returns a nested sequence whose
            ``[0][0]`` element is a 4-D mask tensor (bilinear
            ``interpolate`` requires 4-D input).
        image: The PIL image to process, in any mode PIL can convert to RGB.
        input_size: The ``(width, height)`` the image is resized to before
            inference, in PIL order. Defaults to 1024x1024, the size this app
            has always used.

    Returns:
        A new RGBA PIL image with the same size as *image*. The predicted mask,
        min-max rescaled to 0-255, is used as the paste mask, so pixels where
        the mask is 0 stay fully transparent.
    """
    original_size = image.size

    # Preprocess: RGB at the model's input size, HWC uint8 -> NCHW float.
    # Dividing by 255 maps to [0, 1]; normalize(mean=0.5, std=1.0) then
    # maps to [-0.5, 0.5].
    image_resized = resize_image(image, input_size)
    im_np = np.array(image_resized)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # Run on whatever device the model already lives on, instead of moving the
    # (shared, cached) model to CUDA on every call.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-less model: nothing to follow
        device = torch.device("cpu")
    im_tensor = im_tensor.to(device)

    with torch.no_grad():
        result = model(im_tensor)
        # F.interpolate takes the spatial size as (height, width), while PIL's
        # .size is (width, height).
        width, height = image_resized.size
        mask = torch.nn.functional.interpolate(
            result[0][0], size=(height, width), mode="bilinear"
        )
        mask = torch.squeeze(mask, 0)
        # Min-max rescale to [0, 1]. Clamping the range keeps a constant mask
        # at 0 instead of producing NaN from 0/0.
        lo = torch.min(mask)
        hi = torch.max(mask)
        mask = (mask - lo) / torch.clamp(hi - lo, min=1e-8)
        im_array = (mask * 255).cpu().numpy().astype(np.uint8)

    # Drop singleton dims (expected to leave a 2-D array -> mode "L"), then
    # bring the mask back to the original resolution.
    alpha = Image.fromarray(np.squeeze(im_array)).resize(original_size, Image.BILINEAR)

    # Compose: original pixels pasted through the mask onto a transparent canvas.
    new_im = Image.new("RGBA", original_size, (0, 0, 0, 0))
    new_im.paste(image, mask=alpha)
    return new_im
# Load the model
@st.cache_resource
def load_model():
    """Return the BriaRMBG network with ``model.pth`` weights, ready for inference.

    The target device is CUDA when ``torch.cuda.is_available()`` and the CPU
    otherwise. The checkpoint is mapped straight onto that device, and the
    network is moved there and switched to eval mode. ``st.cache_resource``
    memoises the result, so Streamlit reruns reuse the network instead of
    reloading the weights.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    net = BriaRMBG()
    # NOTE(review): torch.load unpickles the file, so model.pth must come from a
    # trusted source; consider weights_only=True where the torch version allows.
    state_dict = torch.load("model.pth", map_location=target)
    net.load_state_dict(state_dict)
    # Module.to() and Module.eval() both return the module itself.
    return net.to(target).eval()
# Streamlit app
def main():
    """Render the background-removal UI.

    Flow: upload an image, preview it, press "Remove Background" to run the
    model, then preview the result and download it as a PNG.
    """
    st.title("Background Removal App")
    model = load_model()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = convert_to_jpg(uploaded_file, uploaded_file.name)
    st.image(image, caption="Original Image", use_column_width=True)

    # st.button is True only on the run triggered by its click. Any later
    # interaction, including pressing the download button, reruns the script
    # with it False. Keep the result in session state, tagged with the file it
    # was computed from, so it survives those reruns and is not shown for a
    # different upload.
    file_key = (uploaded_file.name, hash(uploaded_file.getvalue()))
    if st.button("Remove Background"):
        result = remove_background(model, image)
        buf = io.BytesIO()
        result.save(buf, format="PNG")
        st.session_state["bg_removed"] = (file_key, result, buf.getvalue())

    cached = st.session_state.get("bg_removed")
    if cached is None or cached[0] != file_key:
        return
    _, result, byte_im = cached

    st.image(result, caption="Image with Background Removed", use_column_width=True)
    st.download_button(
        label="Download Image",
        data=byte_im,
        file_name="background_removed.png",
        mime="image/png",
    )
# Script entry point: build the UI only when run as a script, not on import.
# NOTE(review): `streamlit run` presumably executes this file as `__main__`,
# so the guard fires there too — confirm for the Streamlit version in use.
if __name__ == "__main__":
    main()