# Page header from the hosting site's file view: author rishi9440,
# commit 9d62619 ("Update app.py"), file size 2.58 kB.
import base64
import os
from datetime import datetime
from io import BytesIO

import streamlit as st
from PIL import Image

from src.st_style import apply_prod_style
from src.utils import change_background, matte
apply_prod_style(st)  # NOTE: production styling is active; comment this line out to disable it.
def V_SPACE(lines):
    """Pad the page vertically by writing `lines` single-space strings via st.write."""
    blank = ' '
    for _unused in range(lines):
        st.write(blank)
def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
    """Render a Streamlit download button that serves `pil_image` as a file.

    Args:
        pil_image: PIL image to encode and offer for download.
        filename: Base file name (without extension) for the downloaded file.
        fmt: Output format, either "jpg" or "png" (case sensitive).
        label: Text shown on the button.

    Returns:
        Whatever ``st.download_button`` returns.

    Raises:
        ValueError: If `fmt` is not "jpg" or "png". ValueError subclasses
            Exception, so callers catching Exception keep working.
    """
    # fmt -> (Pillow encoder name, MIME type sent with the download)
    formats = {"jpg": ("JPEG", "image/jpeg"), "png": ("PNG", "image/png")}
    if fmt not in formats:
        raise ValueError(
            f"Unknown image format {fmt!r} "
            f"(Available: {', '.join(formats)} - case sensitive)"
        )
    pil_format, mime = formats[fmt]

    if pil_format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P", "PA"):
        # JPEG has no alpha/palette support; Pillow refuses to save these modes as JPEG.
        pil_image = pil_image.convert("RGB")

    buf = BytesIO()
    pil_image.save(buf, format=pil_format)
    data = buf.getvalue()  # encode once; reused for the download and the callback
    return st.download_button(
        label=label,
        data=data,
        file_name=f"{filename}.{fmt}",
        mime=mime,
        on_click=open_in_new_tab,
        args=(data,),
    )
def open_in_new_tab(file_content):
    """Add a sidebar link that opens the given encoded image in a new browser tab.

    The image bytes are embedded as a base64 ``data:`` URL. The MIME type is chosen
    from the file signature so JPEG bytes are not labeled as PNG.

    Args:
        file_content: Encoded image bytes (e.g. the output of ``PIL.Image.save``).
    """
    # JPEG files start with FF D8 FF; everything else this app produces is PNG.
    mime = "image/jpeg" if bytes(file_content[:3]) == b"\xff\xd8\xff" else "image/png"
    b64_img = base64.b64encode(file_content).decode("ascii")
    href = (
        f'<a href="data:{mime};base64,{b64_img}" target="_blank" '
        'rel="noopener noreferrer">Open image in new tab</a>'
    )
    st.sidebar.markdown(href, unsafe_allow_html=True)
# Page body: upload a photo, run matting + background replacement, offer the result.
uploaded_file = st.file_uploader(
    label="Upload your photo here",
    accept_multiple_files=False,
    type=["png", "jpg", "jpeg"],
)

if uploaded_file is not None:
    # Output mode is fixed; the other entries of hexmap below are currently unused.
    in_mode = "Transparent (PNG)"
    in_submit = st.button("Submit")

    if in_submit:
        img_input = Image.open(uploaded_file)
        with st.spinner("AI is doing magic to your photo. Please wait..."):
            # Background colour passed to change_background for each output mode.
            hexmap = {
                "Transparent (PNG)": "#000000",
                "Black": "#000000",
                "White": "#FFFFFF",
                "Green": "#22EE22",
                "Red": "#EE2222",
                "Blue": "#2222EE",
            }
            if in_mode == "Transparent (PNG)":
                alpha = 0.0
            else:
                alpha = 1.0
            img_matte = matte(img_input)
            img_output = change_background(
                img_input,
                img_matte,
                background_alpha=alpha,
                background_hex=hexmap[in_mode],
            )

        with st.expander("Success!", expanded=True):
            st.image(img_output)
            # Reuse the uploaded file's base name (extension stripped) for the download.
            uploaded_name = os.path.splitext(uploaded_file.name)[0]
            image_download_button(
                pil_image=img_output,
                filename=uploaded_name,
                fmt="png",
            )