"""Streamlit front end for removing image backgrounds with the BriaRMBG model."""

import io

import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision.transforms.functional import normalize

from briarmbg import BriaRMBG


def convert_to_jpg(image, image_name):
    """Open an uploaded image, flattening PNG transparency onto white.

    Args:
        image: A path or binary file-like object accepted by ``Image.open``.
        image_name: Original file name; only its extension is inspected.

    Returns:
        An RGB ``PIL.Image.Image`` for ``.png`` inputs (any alpha channel is
        composited over a white background); otherwise the image exactly as
        ``Image.open`` returns it.
    """
    if not image_name.lower().endswith(".png"):
        return Image.open(image)

    img = Image.open(image)
    has_alpha = img.mode in ("RGBA", "LA") or (
        img.mode == "P" and "transparency" in img.info
    )
    if not has_alpha:
        return img.convert("RGB")

    # Convert to RGBA first so every transparent mode -- including
    # palette images, which have a single band and no separate alpha band to
    # index -- exposes its alpha as band 3.
    rgba = img.convert("RGBA")
    background = Image.new("RGB", rgba.size, (255, 255, 255))
    background.paste(rgba, mask=rgba.split()[3])
    return background


def resize_image(image, size=(1024, 1024)):
    """Return ``image`` converted to RGB and bilinearly resized to ``size``.

    Args:
        image: Source ``PIL.Image.Image``.
        size: Target ``(width, height)`` in pixels.
    """
    image = image.convert("RGB")
    image = image.resize(size, Image.BILINEAR)
    return image


def remove_background(model, image):
    """Cut the foreground out of ``image`` using ``model``'s predicted mask.

    Args:
        model: Segmentation network; called on a normalised
            ``(1, 3, H, W)`` float tensor. Its ``result[0][0]`` output is used
            as the mask (presumably a ``(1, 1, H, W)`` tensor -- confirm
            against the BriaRMBG implementation).
        image: Source ``PIL.Image.Image``.

    Returns:
        An RGBA ``PIL.Image.Image`` of the original size in which the
        predicted background is transparent.
    """
    original_size = image.size

    # Preprocess: resize, HWC uint8 -> NCHW float in [0, 1], then shift by 0.5.
    image_resized = resize_image(image)
    im_np = np.array(image_resized)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()
        model = model.cuda()

    with torch.no_grad():
        result = model(im_tensor)

        # PIL reports (width, height) but interpolate expects (height, width).
        width, height = image_resized.size
        result = torch.squeeze(
            torch.nn.functional.interpolate(
                result[0][0], size=(height, width), mode="bilinear"
            ),
            0,
        )

        # Min-max normalise the mask to [0, 1]. A constant prediction has a
        # zero range, which would divide by zero and yield NaNs.
        ma = torch.max(result)
        mi = torch.min(result)
        value_range = ma - mi
        if value_range > 0:
            result = (result - mi) / value_range
        else:
            # NOTE(review): assumes the raw output is already mask-like
            # (roughly within [0, 1]); clamping just keeps the cast defined.
            result = torch.clamp(result, 0.0, 1.0)

        im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)

    mask = Image.fromarray(np.squeeze(im_array)).resize(
        original_size, Image.BILINEAR
    )

    # Paste the original pixels onto a fully transparent canvas through the
    # mask so the background stays transparent.
    new_im = Image.new("RGBA", original_size, (0, 0, 0, 0))
    new_im.paste(image, mask=mask)
    return new_im


@st.cache_resource
def load_model():
    """Load the BriaRMBG weights from ``model.pth`` and return the network.

    The network is placed on CUDA when available (CPU otherwise) and put in
    evaluation mode. ``st.cache_resource`` caches the returned object across
    Streamlit reruns.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = BriaRMBG()
    # NOTE(review): torch.load unpickles the file; only load a trusted
    # model.pth (consider weights_only=True where the torch version allows).
    net.load_state_dict(torch.load("model.pth", map_location=device))
    net.to(device)
    net.eval()
    return net


def main():
    """Render the app: upload an image, remove its background, download it."""
    st.title("Background Removal App")

    model = load_model()

    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is None:
        return

    image = convert_to_jpg(uploaded_file, uploaded_file.name)
    st.image(image, caption="Original Image", use_column_width=True)

    # Drop a stored result as soon as a different file is uploaded so a stale
    # cut-out is never shown next to a new original.
    upload_id = getattr(
        uploaded_file, "file_id", (uploaded_file.name, uploaded_file.size)
    )
    if st.session_state.get("result_source") != upload_id:
        st.session_state["result_source"] = upload_id
        st.session_state.pop("result_png", None)

    if st.button("Remove Background"):
        result = remove_background(model, image)
        buf = io.BytesIO()
        result.save(buf, format="PNG")
        st.session_state["result_png"] = buf.getvalue()

    # st.button is True only for the rerun triggered by its own click, so the
    # result is kept in session state; otherwise clicking the download button
    # (which also triggers a rerun) would make the result vanish.
    result_png = st.session_state.get("result_png")
    if result_png is not None:
        st.image(
            io.BytesIO(result_png),
            caption="Image with Background Removed",
            use_column_width=True,
        )
        st.download_button(
            label="Download Image",
            data=result_png,
            file_name="background_removed.png",
            mime="image/png",
        )


if __name__ == "__main__":
    main()