# Hugging Face Spaces app (page status shown when captured: "Sleeping").
import streamlit as st | |
import torch | |
from PIL import Image | |
import io | |
import numpy as np | |
from briarmbg import BriaRMBG | |
from torchvision.transforms.functional import normalize | |
# Helper functions reused from the command-line (CLI) version of this script.
def convert_to_jpg(image, image_name):
    """Open an uploaded image and, for PNG files, flatten it to RGB.

    Despite the name, nothing is re-encoded as JPEG. A ``.png`` file is
    converted to an RGB ``PIL.Image``, and any transparency is composited
    onto a white background. Files with any other extension are returned
    exactly as ``Image.open`` yields them.

    Args:
        image: A path or binary file-like object accepted by ``Image.open``.
        image_name: File name, used only to detect a ``.png`` extension
            (case-insensitive).

    Returns:
        A ``PIL.Image.Image``.
    """
    img = Image.open(image)
    if not image_name.lower().endswith('.png'):
        return img

    has_alpha = img.mode in ('RGBA', 'LA') or (
        img.mode == 'P' and 'transparency' in img.info
    )
    if not has_alpha:
        return img.convert("RGB")

    # Normalise every alpha-carrying mode to RGBA before taking the alpha band.
    # A palette ('P') image has only one band, so indexing img.split()[1]
    # for its "alpha" raised IndexError.
    rgba = img.convert("RGBA")
    bg = Image.new("RGB", img.size, (255, 255, 255))
    bg.paste(rgba, mask=rgba.getchannel("A"))
    return bg
def resize_image(image, size=(1024, 1024)):
    """Return an RGB copy of *image* scaled to *size* with bilinear filtering.

    Args:
        image: Source ``PIL.Image.Image`` in any mode.
        size: Target ``(width, height)`` in pixels.

    Returns:
        A new RGB ``PIL.Image.Image`` of the requested size.
    """
    rgb_copy = image.convert('RGB')
    return rgb_copy.resize(size, Image.BILINEAR)
def remove_background(model, image):
    """Cut the foreground out of *image* using a segmentation model.

    Steps:
    - The image is resized to 1024x1024 RGB (via ``resize_image``) and
      scaled to [0, 1].
    - It is shifted by a mean of 0.5 per channel and run through *model*.
    - The predicted mask is min-max rescaled to 0-255 and resized back to
      the original size.
    - That mask is used when pasting *image* onto a fully transparent
      RGBA canvas.

    Args:
        model: Callable network whose output is indexed as ``result[0][0]``.
            That element must be a 4-D ``(N, C, H, W)`` tensor, since it is
            fed to a bilinear ``interpolate``. NOTE(review): presumably the
            first BriaRMBG side output; confirm against ``briarmbg``.
        image: Source ``PIL.Image.Image``.

    Returns:
        An RGBA ``PIL.Image.Image`` with the same size as *image*.
    """
    original_size = image.size

    # Preprocess: HWC uint8 -> NCHW float in [0, 1], then subtract 0.5.
    # The std of 1.0 leaves the scale unchanged.
    image_resized = resize_image(image)
    im_np = np.array(image_resized)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()
        model = model.cuda()

    with torch.no_grad():
        result = model(im_tensor)

    # interpolate() expects (height, width), but PIL's .size is (width, height).
    # Take the spatial size from the HWC array so non-square inputs stay correct.
    mask_hw = im_np.shape[:2]
    result = torch.squeeze(
        torch.nn.functional.interpolate(result[0][0], size=mask_hw, mode='bilinear'),
        0,
    )

    # Min-max rescale to [0, 1]. A constant mask would give 0/0 = NaN here,
    # and casting NaN to uint8 is undefined, so clamp that case instead.
    ma = torch.max(result)
    mi = torch.min(result)
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        result = result.clamp(0.0, 1.0)

    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array)).resize(original_size, Image.BILINEAR)

    # Paste the original onto a fully transparent canvas through the mask.
    new_im = Image.new("RGBA", original_size, (0, 0, 0, 0))
    new_im.paste(image, mask=pil_im)
    return new_im
# Load the model
@st.cache_resource
def load_model(model_path="model.pth"):
    """Build a BriaRMBG network, load its weights and switch it to eval mode.

    Streamlit re-executes the whole script on every widget interaction, and
    ``main`` calls this function at the top. Without caching, the checkpoint
    would be re-read from disk on every upload or button click.
    ``st.cache_resource`` reuses one instance per ``model_path`` across
    reruns.

    Args:
        model_path: Path to the saved ``state_dict``. The default matches the
            previously hard-coded ``"model.pth"``.

    Returns:
        The ``BriaRMBG`` module, on CUDA when available and otherwise on CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = BriaRMBG()
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()
    return net
# Streamlit app
def main():
    """Render the background-removal page.

    Flow: upload a JPG/PNG, preview it, then on "Remove Background" show the
    cut-out and offer it as a PNG download.

    NOTE(review): st.button is only True on the rerun triggered by its own
    click. The result and download button therefore vanish on any later
    rerun (e.g. after clicking "Download Image"). Confirm whether the result
    should be persisted in st.session_state.
    """
    st.title("Background Removal App")
    model = load_model()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    source_image = convert_to_jpg(uploaded_file, uploaded_file.name)
    st.image(source_image, caption="Original Image", use_column_width=True)

    if not st.button("Remove Background"):
        return

    cutout = remove_background(model, source_image)
    st.image(cutout, caption="Image with Background Removed", use_column_width=True)

    # Encode the RGBA result as in-memory PNG bytes for the download button.
    with io.BytesIO() as png_buffer:
        cutout.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()

    st.download_button(
        label="Download Image",
        data=png_bytes,
        file_name="background_removed.png",
        mime="image/png",
    )
if __name__ == "__main__":
    # Launch the UI only when this file runs as the main script, not on import.
    main()