"""Streamlit front end for a salient object detection service.

The page lets the user upload an image (or pick a bundled sample), choose a
mask type, and sends the image to the inference server whose address is read
from the JSON config file. The server's response bytes are shown as an image
and offered as a PNG download.
"""

import time
import sys
import streamlit as st
import string
import os
from io import StringIO
import pdb
import json
import torch
import requests
import socket
from streamlit_image_select import image_select

use_case = {
    "1": "Image background removal - (upload any picture and remove background)",
    "2": "Masking foreground for downstream inpainting task",
}

# Maps the label shown in the select box to the mask value sent to the server.
mask_types = {
    "rgba - makes background white": "rgba",
    "green - makes the background green": "green",
    "blur - blurs background": "blur",
    "map - makes the foreground white and rest black ": "map",
}

APP_NAME = "hf/salient_object_detection"
INFO_URL = "https://www.taskswithcode.com/stats/"
TMP_DIR = "tmp_dir"
TMP_SEED = 1

# Upper bound for the view-counter requests so a slow or unreachable stats
# endpoint cannot hang the page.
STATS_TIMEOUT_SECS = 5

# NOTE(review): several strings below are rendered with unsafe_allow_html=True
# yet contain no markup as given; they are reproduced as-is. Confirm whether
# HTML tags were lost before this file was captured.


def get_views(action: str) -> str:
    """Return the view count as a comma-grouped string and report *action*.

    On the first call of a session the count is fetched from INFO_URL and
    cached in ``st.session_state["view_count"]``; 0 is used if that fails.
    On later calls the cached value is returned, and any action other than
    "init" is additionally posted to INFO_URL on a best-effort basis.
    """
    hostname = socket.gethostname()
    try:
        ip_address = socket.gethostbyname(hostname)
    except OSError:
        # Hostname may not resolve (e.g. inside a container); the counter is
        # not worth failing the page for.
        ip_address = ""
    app_info = {"name": APP_NAME, "action": action, "host": hostname, "ip": ip_address}

    if "view_count" not in st.session_state:
        try:
            res = requests.post(INFO_URL, json=app_info, timeout=STATS_TIMEOUT_SECS).json()
            print(res)
            data = res["count"]
        except (requests.RequestException, ValueError, KeyError, TypeError):
            # Network failure, non-JSON body, or unexpected payload shape.
            data = 0
        st.session_state["view_count"] = data
        ret_val = data
    else:
        ret_val = st.session_state["view_count"]
        if action != "init":
            try:
                requests.post(INFO_URL, json=app_info, timeout=STATS_TIMEOUT_SECS)
            except requests.RequestException as err:
                # Reporting is best-effort; never let it break a finished run.
                print("Stats update failed:", err)
    return "{:,}".format(ret_val)


def construct_model_info_for_display(model_names):
    """Build the model name list and the footer text describing the models.

    Args:
        model_names: iterable of dicts read from the models JSON file. Each
            entry is indexed with "name" and "mark"; entries whose "mark" is
            the string "True" are also indexed with "orig_author" and
            "sota_info" -> "task", and optionally "Note".

    Returns:
        A tuple ``(options_arr, markdown_str)``: every model name in order,
        and the text block to render below the form.
    """
    options_arr = []
    parts = ["\n\nModel evaluated\n", "\n\n"]
    for node in model_names:
        options_arr.append(node["name"])
        if node["mark"] == "True":
            parts.append(
                f"\n • Model: {node['name']}"
                f"\n    Code released by: {node['orig_author']}"
                f"\n    Model info: {node['sota_info']['task']}\n"
            )
            if "Note" in node:
                parts.append(f"\n    {node['Note']}link\n")
            parts.append("\n\n")
    parts.append(
        "\nNote:\n• Uploaded files are loaded into non-persistent memory for the"
        " duration of the computation. They are not cached\n"
    )
    parts.append("\n\nGithub code for this app\n")
    return options_arr, "".join(parts)


def init_page():
    """Configure the Streamlit page and draw the logo banner."""
    st.set_page_config(
        page_title='TWC - State-of-the-art model salient object detection (visually dominant objects in an image)',
        page_icon="logo.jpg",
        layout='centered',
        initial_sidebar_state='auto',
        menu_items={
            'About': 'This app was created by taskswithcode. http://taskswithcode.com'
        },
    )
    col, pad = st.columns([85, 15])
    with col:
        st.image("long_form_logo_with_icon.png")


def run_test(config, input_file_name, display_area, uploaded_file, mask_type):
    """Send an image to the inference server and return its response.

    Args:
        config: dict holding "SERVER_ADDRESS", the URL to post to.
        input_file_name: path of the image to send when nothing was uploaded.
        display_area: Streamlit placeholder used for the progress message.
        uploaded_file: object returned by ``st.file_uploader``, or None.
        mask_type: mask value sent in the "mask" form field.

    Returns:
        ``{"response": <bytes>, "size": <comma-grouped byte count>}`` on HTTP
        200, otherwise ``{"error": <message>}``. On an exception the error is
        shown and the script run is stopped via ``st.stop()``.
    """
    global TMP_SEED
    display_area.text("Processing request...")
    tmp_path = None
    try:
        if uploaded_file is None:
            source_path = input_file_name
        else:
            # Uploaded content is spooled to a uniquely named temp file first.
            os.makedirs(TMP_DIR, exist_ok=True)
            # Keep only the final path component so the client-supplied name
            # cannot point outside TMP_DIR.
            safe_name = os.path.basename(uploaded_file.name)
            tmp_path = f"{TMP_DIR}/{TMP_SEED}_{str(time.time()).replace('.', '_')}_{safe_name}"
            TMP_SEED += 1
            with open(tmp_path, "wb") as fp:
                fp.write(uploaded_file.read())
            source_path = tmp_path

        with open(source_path, "rb") as file_data:
            r = requests.post(
                config["SERVER_ADDRESS"],
                data={"mask": mask_type},
                files={"test": file_data},
            )
        print("Servers response:", r.status_code, len(r.content))
        if r.status_code == 200:
            size = "{:,}".format(len(r.content))
            return {"response": r.content, "size": size}
        return {"error": f"API request failed {r.status_code}"}
    except Exception as e:
        st.error("Some error occurred during prediction: " + str(e))
        st.stop()
        # Only reached if st.stop() does not interrupt execution.
        return {"error": f"Exception in performing salient object detection: {str(e)}"}
    finally:
        # Remove the temp copy whether or not the request succeeded.
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError as err:
                print("Could not remove temporary file:", tmp_path, err)


def display_results(results, response_info, mask):
    """Render the timing line and the result image, and arm the download.

    Stores the response bytes in ``st.session_state["download_ready"]`` and
    reports a "submit" action to the view counter.
    """
    main_sent = f"\n{response_info}\n\n" + "\n"
    st.markdown(main_sent, unsafe_allow_html=True)
    st.image(results["response"], caption=f'Output of salient object detection with mask: {mask}')
    st.session_state["download_ready"] = results["response"]
    get_views("submit")


def init_session():
    """Set up the page and reset the per-run session state defaults."""
    print("Init session")
    init_page()
    st.session_state["download_ready"] = None
    st.session_state["model_name"] = "ss_test"
    st.session_state["file_name"] = "default"
    st.session_state["mask_type"] = "rgba"


def app_main(app_mode, example_files, model_name_files, config_file):
    """Draw the whole page and handle a form submission.

    Args:
        app_mode: key into ``use_case``; a KeyError is raised if unknown.
        example_files: path to a JSON file of example file names.
        model_name_files: path to a JSON file describing the models.
        config_file: path to a JSON file with the server configuration.
    """
    init_session()
    with open(example_files) as fp:
        example_file_names = json.load(fp)
    with open(model_name_files) as fp:
        model_names = json.load(fp)
    with open(config_file) as fp:
        config = json.load(fp)

    # Not used further below; kept because the lookup rejects an unknown mode.
    curr_use_case = use_case[app_mode].split(".")[0]

    st.markdown("\nState-of-the-art model for salient object detection\n", unsafe_allow_html=True)
    st.markdown(
        f"\nUse cases for salient object detection\n   •  {use_case['1']}\n   •  {use_case['2']}\n",
        unsafe_allow_html=True,
    )
    st.markdown(f"\nviews: {get_views('init')}\n", unsafe_allow_html=True)

    try:
        with st.form('twc_form'):
            step1_line = "Upload an image or choose an example image below"
            uploaded_file = st.file_uploader(step1_line, type=["png", "jpg", "jpeg"])
            selected_file_name = image_select(
                "Select image",
                [
                    "twc_samples/sample1.jpg",
                    "twc_samples/sample2.jpg",
                    "twc_samples/sample3.jpg",
                    "twc_samples/sample4.jpg",
                ],
            )
            st.write("")
            mask_type = st.selectbox(
                label='Select type of masking',
                options=list(mask_types),
                index=0,
                key="twc_mask_types",
            )
            mask_type = mask_types[mask_type]
            st.write("")
            submit_button = st.form_submit_button('Run')

            options_arr, markdown_str = construct_model_info_for_display(model_names)

            input_status_area = st.empty()
            display_area = st.empty()
            if submit_button:
                start = time.time()
                # An upload takes precedence over the selected sample image.
                if uploaded_file is not None:
                    st.session_state["file_name"] = uploaded_file.name
                else:
                    st.session_state["file_name"] = selected_file_name
                st.session_state["mask_type"] = mask_type
                display_area.empty()
                results = run_test(
                    config,
                    st.session_state["file_name"],
                    display_area,
                    uploaded_file,
                    mask_type,
                )
                with display_area.container():
                    if "error" in results:
                        st.error(results["error"])
                    else:
                        device = 'GPU' if torch.cuda.is_available() else 'CPU'
                        response_info = (
                            f"Computation time on {device}: {time.time() - start:.2f} secs"
                            f" for image size: {results['size']} bytes"
                        )
                        display_results(results, response_info, mask_type)

        # Rendered outside the form; it stays disabled until a result exists.
        download_payload = st.session_state["download_ready"]
        has_download = download_payload is not None
        download_name = (
            st.session_state["model_name"]
            + "_"
            + st.session_state["mask_type"]
            + "_"
            + '_'.join(st.session_state["file_name"].split(".")[:-1])
            + ".png"
        ).replace("/", "_")
        st.download_button(
            label="Download results as png",
            data=download_payload if has_download else "",
            disabled=not has_download,
            file_name=download_name,
            mime='image/png',
            key="download",
        )
    except Exception as e:
        st.error("Some error occurred during loading: " + str(e))
        st.stop()

    st.markdown(markdown_str, unsafe_allow_html=True)


if __name__ == "__main__":
    app_main("1", "sod_app_examples.json", "sod_app_models.json", "config.json")