Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import pandas as pd | |
import plotly.express as px | |
import numpy as np | |
from predict import process_target_data, get_average_embedding # Import your function | |
# Page-level Streamlit configuration: title and icon for the app page.
st.set_page_config(
    page_title="CoV-SNN",
    page_icon="🧬",
)
def _load_instructions():
    """Return the text of INSTRUCTIONS.md, or a fallback notice if it is missing."""
    try:
        # Explicit encoding so the Markdown decodes the same way on every
        # platform instead of depending on the process locale.
        with open("INSTRUCTIONS.md", "r", encoding="utf-8") as instructions_file:
            return instructions_file.read()
    except FileNotFoundError:
        return "INSTRUCTIONS.md file not found."


def _missing_columns(df, required):
    """Return the names in *required* that are not columns of *df*, in order."""
    return [name for name in required if name not in df.columns]


def _build_escape_figure(results_df):
    """Build the log10(gr) vs. log10(sc) scatter plot coloured by 'Escape Potential'.

    Args:
        results_df: DataFrame holding the plotted columns (``log10(gr)``,
            ``log10(sc)``, ``accession_id``, ``Escape Potential``) plus every
            column referenced in ``hover_data`` below.

    Returns:
        The configured Plotly figure.
    """
    fig = px.scatter(
        # Round numeric columns for compact hover text. DataFrame.round is
        # vectorized, leaves non-numeric columns untouched, and avoids the
        # deprecated elementwise DataFrame.applymap.
        results_df.round(6),
        x="log10(gr)",
        y="log10(sc)",
        labels={"log10(gr)": "log10(gr)", "log10(sc)": "log10(sc)"},
        title="CoV-SNN Results",
        hover_name="accession_id",
        color="Escape Potential",
        color_continuous_scale=["green", "yellow", "red"],
        # True = show the column in the hover box, False = hide it.
        hover_data={
            "log10(sp)": True,
            "log10(sc)": True,
            "log10(ip)": True,
            "sp": False,
            "sc": False,
            "ip": False,
            "rank_by_sc": True,
            "rank_by_sp": True,
            "rank_by_ip": True,
            "rank_by_scsp": True,
            "rank_by_scip": True,
            "Escape Potential": False,
        },
    )

    # Hide the colorbar title, ticks and tick labels; the rotated annotation
    # added below serves as the colorbar title instead.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )

    # Widen the right margin and place a rotated "Escape Potential" label
    # there, positioned in paper coordinates.
    fig.update_layout(
        margin=dict(r=110),
        annotations=[
            dict(
                text="Escape Potential",
                font_size=14,
                textangle=270,
                showarrow=False,
                xref="paper",
                yref="paper",
                x=1.14,
                y=0.5,
            )
        ],
    )
    return fig


def _render_results(option, reference_file, target_file):
    """Validate the uploads, run the prediction, and render the plot and table.

    Shows an ``st.error`` message and returns early when an uploaded CSV lacks
    a column named as required in the uploader prompts, so the caller can
    still render the rest of the page.

    Args:
        option: Selected reference mode, ``"Omicron"`` or ``"Other"``.
        reference_file: Value returned by the reference ``st.file_uploader``;
            only read when *option* is not ``"Omicron"``.
        target_file: Value returned by the target ``st.file_uploader``.
    """
    # Validate both uploads before any (potentially slow) embedding work.
    ref_df = None
    if option != "Omicron":
        ref_df = pd.read_csv(reference_file)
        missing = _missing_columns(ref_df, ["sequence"])
        if missing:
            st.error(f"Reference CSV is missing required column(s): {', '.join(missing)}")
            return

    target_dataset = pd.read_csv(target_file)
    missing = _missing_columns(target_dataset, ["accession_id", "sequence"])
    if missing:
        st.error(f"Target CSV is missing required column(s): {', '.join(missing)}")
        return

    if option == "Omicron":
        # Pre-generated average embedding shipped alongside the app.
        average_embedding = np.load("average_omicron_embedding.npy")
        print(f"Average Omicron embedding loaded from file with shape {average_embedding.shape}")
    else:
        with st.spinner('Calculating average embedding...'):
            average_embedding = get_average_embedding(ref_df)

    with st.spinner('Predicting escape potentials...'):
        results_df = process_target_data(average_embedding, target_dataset)

    # Invert rank_by_scip (max + 1 - rank) so the smallest rank value maps to
    # the largest 'Escape Potential' value on the colour scale.
    scip_rank = results_df['rank_by_scip']
    results_df['Escape Potential'] = scip_rank.max() + 1 - scip_rank

    fig = _build_escape_figure(results_df)

    # NOTE(review): `border` / `border_color` do not appear among the documented
    # st.plotly_chart parameters - confirm they have the intended effect on the
    # deployed Streamlit version.
    st.plotly_chart(fig, theme="streamlit", border=True, use_container_width=True, border_color="black")

    # Tabular view of the per-sequence scores and ranks.
    st.dataframe(
        results_df[[
            "accession_id", "log10(sc)", "log10(sp)", "log10(ip)",
            "rank_by_sc", "rank_by_sp", "rank_by_ip", "rank_by_scsp", "rank_by_scip",
        ]],
        hide_index=True,
    )


def main():
    """Render the CoV-SNN Streamlit page.

    Lets the user pick a reference embedding (the pre-generated Omicron
    average, or one computed from an uploaded reference CSV), upload a target
    CSV, and then displays the prediction scatter plot and results table.
    The contents of INSTRUCTIONS.md are always rendered at the bottom.
    """
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")

    # Loaded up front, rendered at the very end of the page.
    readme_text = _load_instructions()

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=[
            "Use average embedding of Omicron sequences (Pre-generated)",
            "Generate average embedding of your own sequences (Takes longer)",
        ],
    )

    # Reference upload is only relevant when the user supplies their own sequences.
    reference_file = st.file_uploader(
        "Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
        type=["csv"],
        disabled=option == "Omicron",
    )

    # Target upload stays disabled until a reference is available for "Other".
    target_file = st.file_uploader(
        "Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
        type=["csv"],
        disabled=option == "Other" and reference_file is None,
    )

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _render_results(option, reference_file, target_file)

    # Display the instructions below the inputs/results.
    st.markdown(readme_text)
# Run the app only when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()