# cov-snn-app / app.py
# smtnkc
# Using rank_by_scip
# 6d42f96
import streamlit as st
import os
import pandas as pd
import plotly.express as px
import numpy as np
from predict import process_target_data, get_average_embedding # Import your function
# Configure the Streamlit page title and icon for the CoV-SNN app.
st.set_page_config(page_title="CoV-SNN", page_icon="🧬")
def _missing_columns(frame, required):
    """Return the names in ``required`` that are not columns of ``frame``.

    The order of ``required`` is preserved so the resulting error message
    is stable.
    """
    return [name for name in required if name not in frame.columns]


def _build_results_figure(results_df):
    """Build the log10(gr) vs. log10(sc) scatter plot coloured by escape potential.

    Args:
        results_df: DataFrame containing the columns referenced below
            (``accession_id``, the ``log10(...)`` columns, the ``rank_by_*``
            columns and ``Escape Potential``).

    Returns:
        The configured Plotly figure.
    """
    fig = px.scatter(
        # Round numeric columns for display only; DataFrame.round leaves
        # non-numeric columns (e.g. accession_id) untouched and does not
        # modify results_df itself.
        results_df.round(6),
        x="log10(gr)",
        y="log10(sc)",
        labels={"log10(gr)": "log10(gr)", "log10(sc)": "log10(sc)"},
        title="CoV-SNN Results",
        hover_name="accession_id",
        color="Escape Potential",
        color_continuous_scale=["green", "yellow", "red"],
        hover_data={
            "log10(sp)": True,
            "log10(sc)": True,
            "log10(ip)": True,
            # Raw (non-log) values are hidden from the hover box.
            "sp": False,
            "sc": False,
            "ip": False,
            "rank_by_sc": True,
            "rank_by_sp": True,
            "rank_by_ip": True,
            "rank_by_scsp": True,
            "rank_by_scip": True,
            # Already conveyed by the marker colour.
            "Escape Potential": False,
        },
    )

    # Hide the colorbar title, ticks and labels; the title is re-added
    # below as a rotated annotation.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )

    # Widen the right margin and place the rotated "Escape Potential" label
    # next to the colorbar.
    fig.update_layout(
        margin=dict(r=110),
        annotations=[
            dict(
                text="Escape Potential",
                font_size=14,
                textangle=270,
                showarrow=False,
                xref="paper",
                yref="paper",
                x=1.14,
                y=0.5,
            )
        ],
    )
    return fig


def _render_predictions(option, reference_file, target_file):
    """Compute escape-potential results for the uploads and render them.

    Args:
        option: The selected reference embedding, ``"Omicron"`` or ``"Other"``.
        reference_file: Uploaded reference CSV; only read when ``option`` is
            not ``"Omicron"``.
        target_file: Uploaded target CSV.

    Shows an ``st.error`` message and returns early when an uploaded CSV
    lacks the columns the uploader labels ask for.
    """
    if option == "Omicron":
        # Pre-generated average embedding shipped alongside the app.
        average_embedding = np.load("average_omicron_embedding.npy")
        print(f"Average Omicron embedding loaded from file with shape {average_embedding.shape}")
    else:
        with st.spinner('Calculating average embedding...'):
            ref_df = pd.read_csv(reference_file)
            missing = _missing_columns(ref_df, ["sequence"])
            if missing:
                st.error(
                    "Reference CSV is missing required column(s): "
                    + ", ".join(missing)
                )
                return
            average_embedding = get_average_embedding(ref_df)

    with st.spinner('Predicting escape potentials...'):
        target_dataset = pd.read_csv(target_file)
        missing = _missing_columns(target_dataset, ["accession_id", "sequence"])
        if missing:
            st.error(
                "Target CSV is missing required column(s): "
                + ", ".join(missing)
            )
            return

        # NOTE(review): process_target_data is assumed to return the
        # log10(...) and rank_by_* columns used below -- confirm in predict.py.
        results_df = process_target_data(average_embedding, target_dataset)

        # Invert rank_by_scip (max + 1 - rank) so that the smallest rank
        # value maps to the largest "Escape Potential" on the colour scale.
        results_df['Escape Potential'] = results_df['rank_by_scip'].max() + 1 - results_df['rank_by_scip']

    fig = _build_results_figure(results_df)

    # NOTE(review): `border` and `border_color` are not listed among the
    # documented st.plotly_chart parameters -- confirm they are accepted by
    # the installed Streamlit version.
    st.plotly_chart(fig, theme="streamlit", border=True, use_container_width=True, border_color="black")

    # Tabular view of the same results.
    st.dataframe(results_df[["accession_id", "log10(sc)", "log10(sp)", "log10(ip)",
                             "rank_by_sc", "rank_by_sp", "rank_by_ip", "rank_by_scsp", "rank_by_scip"
                             ]], hide_index=True)


def main():
    """Render the CoV-SNN page: inputs, prediction results and instructions.

    The user picks a reference embedding (pre-generated Omicron or one
    computed from an uploaded CSV) and uploads target sequences. Once the
    required inputs are present the results are plotted and tabulated. The
    contents of INSTRUCTIONS.md are always shown at the bottom of the page.
    """
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")

    # Read the instructions up front; an explicit encoding keeps the read
    # independent of the platform's default locale.
    try:
        with open("INSTRUCTIONS.md", "r", encoding="utf-8") as readme_file:
            readme_text = readme_file.read()
    except FileNotFoundError:
        readme_text = "INSTRUCTIONS.md file not found."

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=["Use average embedding of Omicron sequences (Pre-generated)", "Generate average embedding of your own sequences (Takes longer)"],)

    # Reference upload is only needed when not using the Omicron embedding.
    reference_file = st.file_uploader("Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
                                      type=["csv"],
                                      disabled=option == "Omicron")

    # Target upload stays disabled until a reference is available.
    target_file = st.file_uploader("Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
                                   type=["csv"],
                                   disabled=option == "Other" and reference_file is None)

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _render_predictions(option, reference_file, target_file)

    # Display the instructions below the inputs and any results.
    st.markdown(readme_text)
# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()