import os
from typing import Optional

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st

from predict import process_target_data, get_average_embedding  # project-local prediction helpers

st.set_page_config(page_title="CoV-SNN", page_icon="🧬")

# Columns each uploaded CSV must contain, as stated in the uploader prompts.
_REFERENCE_COLUMNS = ("sequence",)
_TARGET_COLUMNS = ("accession_id", "sequence")

# Columns shown in the results table underneath the plot.
_RESULT_TABLE_COLUMNS = [
    "accession_id",
    "log10(sc)",
    "log10(sp)",
    "log10(ip)",
    "rank_by_sc",
    "rank_by_sp",
    "rank_by_ip",
    "rank_by_scsp",
    "rank_by_scip",
]


def _load_instructions() -> str:
    """Return the text of INSTRUCTIONS.md, or a placeholder if it is missing."""
    try:
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open("INSTRUCTIONS.md", "r", encoding="utf-8") as instructions_file:
            return instructions_file.read()
    except FileNotFoundError:
        return "INSTRUCTIONS.md file not found."


def _read_uploaded_csv(uploaded_file, required_columns, label) -> Optional[pd.DataFrame]:
    """Read an uploaded CSV and check that it has the required columns.

    Args:
        uploaded_file: File-like object returned by ``st.file_uploader``.
        required_columns: Column names that must be present.
        label: Short human-readable name of the file, used in error messages.

    Returns:
        The parsed DataFrame, or ``None`` after showing an ``st.error``
        message when the file cannot be parsed or lacks a required column.
    """
    try:
        frame = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the {label} CSV file: {err}")
        return None

    missing = [column for column in required_columns if column not in frame.columns]
    if missing:
        st.error(f"The {label} CSV file is missing required column(s): {', '.join(missing)}")
        return None
    return frame


def _build_escape_figure(results_df):
    """Build the log10(gr) vs log10(sc) scatter plot coloured by escape potential.

    Args:
        results_df: Prediction results; must already contain the
            ``Escape Potential`` column in addition to the plotted columns.

    Returns:
        The configured Plotly figure.
    """
    fig = px.scatter(
        # Round numeric columns so hover labels stay readable; DataFrame.round
        # leaves non-numeric columns (e.g. accession_id) untouched.
        results_df.round(6),
        x="log10(gr)",
        y="log10(sc)",
        labels={"log10(gr)": "log10(gr)", "log10(sc)": "log10(sc)"},
        title="CoV-SNN Results",
        hover_name="accession_id",
        color="Escape Potential",
        color_continuous_scale=["green", "yellow", "red"],
        hover_data={
            # True shows the column in the hover box, False hides it.
            "log10(sp)": True,
            "log10(sc)": True,
            "log10(ip)": True,
            "sp": False,
            "sc": False,
            "ip": False,
            "rank_by_sc": True,
            "rank_by_sp": True,
            "rank_by_ip": True,
            "rank_by_scsp": True,
            "rank_by_scip": True,
            "Escape Potential": False,
        },
    )

    # Hide the colorbar ticks and labels; only the colour gradient is shown.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )

    # Add a rotated colorbar title via an annotation, with extra right margin
    # so that it is not clipped.
    fig.update_layout(
        margin=dict(r=110),
        annotations=[
            dict(
                text="Escape Potential",
                font_size=14,
                textangle=270,
                showarrow=False,
                xref="paper",
                yref="paper",
                x=1.14,
                y=0.5,
            )
        ],
    )
    return fig


def _run_prediction(option, reference_file, target_file):
    """Compute escape potentials for the uploaded targets and render them.

    Args:
        option: Selected reference embedding, ``"Omicron"`` or ``"Other"``.
        reference_file: Uploaded reference CSV (used only for ``"Other"``).
        target_file: Uploaded target CSV.
    """
    # Validate the inputs up front so a malformed upload is reported before
    # the (potentially slow) embedding computation starts.
    target_dataset = _read_uploaded_csv(target_file, _TARGET_COLUMNS, "target")
    if target_dataset is None:
        return

    if option == "Omicron":
        # Pre-generated average embedding shipped alongside the app.
        average_embedding = np.load("average_omicron_embedding.npy")
        print(f"Average Omicron embedding loaded from file with shape {average_embedding.shape}")
    else:
        ref_df = _read_uploaded_csv(reference_file, _REFERENCE_COLUMNS, "reference")
        if ref_df is None:
            return
        with st.spinner('Calculating average embedding...'):
            average_embedding = get_average_embedding(ref_df)

    with st.spinner('Predicting escape potentials...'):
        results_df = process_target_data(average_embedding, target_dataset)

        # Invert rank_by_scip (max + 1 - rank) so that the best-ranked
        # sequence receives the highest "Escape Potential" value.
        results_df['Escape Potential'] = results_df['rank_by_scip'].max() + 1 - results_df['rank_by_scip']

        fig = _build_escape_figure(results_df)

    # NOTE(review): `border` / `border_color` do not appear to be documented
    # st.plotly_chart parameters and are passed through as extra keyword
    # arguments - confirm they have the intended effect.
    st.plotly_chart(fig, theme="streamlit", border=True, use_container_width=True, border_color="black")

    st.dataframe(results_df[_RESULT_TABLE_COLUMNS], hide_index=True)


def main():
    """Render the CoV-SNN page: inputs, prediction results and instructions."""
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")

    readme_text = _load_instructions()

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=[
            "Use average embedding of Omicron sequences (Pre-generated)",
            "Generate average embedding of your own sequences (Takes longer)",
        ],
    )

    # The reference uploader is only needed when the user supplies their own
    # reference sequences.
    reference_file = st.file_uploader(
        "Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
        type=["csv"],
        disabled=option == "Omicron",
    )

    # The target uploader stays disabled until a reference is available.
    target_file = st.file_uploader(
        "Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
        type=["csv"],
        disabled=option == "Other" and reference_file is None,
    )

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _run_prediction(option, reference_file, target_file)

    # Always show the usage instructions at the bottom of the page.
    st.markdown(readme_text)


if __name__ == "__main__":
    main()