CRISPRTool / app.py
supercat666's picture
fix
d51aeae
raw
history blame
31.2 kB
import os
import tiger
import cas9on
import cas9off
import cas12
import pandas as pd
import streamlit as st
import plotly.graph_objs as go
import numpy as np
from pathlib import Path
import zipfile
import io
import gtracks
import subprocess
# Title and documentation rendered at the top of the page. The markdown file is
# read explicitly as UTF-8 so rendering does not depend on the host's locale.
st.markdown(Path('crisprTool.md').read_text(encoding='utf-8'), unsafe_allow_html=True)
st.divider()

# Models offered in the selector; each value has a matching branch further down.
CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')

# Weight files passed to cas9on.process_gene / cas12.process_gene.
cas9on_path = 'cas9_model/on-cla.h5'
cas12_path = 'cas12_model/Seq_deepCpf1_weights.h5'
# plot functions
def generate_coolbox_plot(bigwig_path, region, output_image_path):
    """Render a BigWig track for ``region`` with CoolBox and save it as an image.

    Args:
        bigwig_path: Path to the BigWig file added as the only track.
        region: Region passed unchanged to ``frame.plot``.
        output_image_path: File the plot is saved to (``savefig=``).

    NOTE(review): ``CoolBox`` and ``BigWig`` are not imported in the visible
    import block, so calling this raises NameError as written (presumably they
    come from ``coolbox.api`` — TODO confirm). Nothing in the visible code calls
    this function.
    """
    frame = CoolBox()
    frame += BigWig(bigwig_path)
    frame.plot(region, savefig=output_image_path)
def generate_pygenometracks_plot(bigwig_file_path, region, output_image_path):
    """Plot a BigWig signal track for ``region`` with pyGenomeTracks.

    Args:
        bigwig_file_path: Path to the BigWig file; interpolated into the
            ``file =`` line of the generated track configuration.
        region: Region string parsed as ``"chrom:start-end"``; start and end
            must be plain integers (they go through ``int()``).
        output_image_path: Destination passed to ``plot_tracks`` as
            ``out_file_name``.

    Side effects:
        Writes (overwrites) ``pygenometracks.ini`` in the current working
        directory.

    NOTE(review): ``plot_tracks`` is not imported in the visible import block,
    so this raises NameError as written; nothing in the visible code calls this
    function. The fixed config file name is also shared by concurrent sessions.
    """
    # Define the configuration for pyGenomeTracks: one BigWig track, y-range 0-10
    tracks = """
[bigwig]
file = {}
height = 4
color = blue
min_value = 0
max_value = 10
""".format(bigwig_file_path)
    # Write the configuration to an INI file in the working directory
    config_file_path = "pygenometracks.ini"
    with open(config_file_path, 'w') as configfile:
        configfile.write(tracks)
    # Split "chrom:start-end" into chromosome and integer coordinates
    region_dict = {'chrom': region.split(':')[0],
                   'start': int(region.split(':')[1].split('-')[0]),
                   'end': int(region.split(':')[1].split('-')[1])}
    # Generate the plot
    plot_tracks(tracks_file=config_file_path,
                region=region_dict,
                out_file_name=output_image_path)
@st.cache_data
def convert_df(df):
    """Return ``df`` serialised as CSV text encoded to UTF-8 bytes.

    Wrapped in ``st.cache_data`` so the conversion is not recomputed on every
    Streamlit rerun; the bytes are handed to ``st.download_button``.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
def mode_change_callback():
    """Sync the off-target checkbox with the currently selected TIGER run mode.

    For the 'all' and 'titration' modes the checkbox is unticked and disabled;
    for any other mode it is enabled again.
    """
    # TODO: support titration
    modes_without_off_target = {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}
    if st.session_state.mode not in modes_without_off_target:
        st.session_state.disable_off_target_checkbox = False
        return
    st.session_state.check_off_targets = False
    st.session_state.disable_off_target_checkbox = True
def progress_update(update_text, percent_complete):
    """Show a status line and a progress bar inside the ``progress`` placeholder.

    ``progress`` is the module-level ``st.empty()`` placeholder created later in
    the script. ``percent_complete`` is on a 0-100 scale and is divided by 100
    before being passed to ``st.progress``.
    """
    panel = progress.container()
    with panel:
        st.write(update_text)
        st.progress(percent_complete / 100)
def initiate_run():
    """Validate the Cas13d (TIGER) input and queue it for prediction.

    Used as the ``on_click`` callback of the 'Get predictions!' button. Resets
    the result slots in ``st.session_state`` and builds a transcript DataFrame
    from either the manual text entry or the uploaded fasta file. It then either
    stores that DataFrame in ``st.session_state.transcripts``, which the main
    script passes to ``tiger.tiger_exhibit``, or records a message in
    ``st.session_state.input_error``.
    """
    import tempfile

    # initialize state variables
    st.session_state.transcripts = None
    st.session_state.input_error = None
    st.session_state.on_target = None
    st.session_state.titration = None
    st.session_state.off_target = None

    # initialize transcript DataFrame
    transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL])

    # manual entry
    if st.session_state.entry_method == ENTRY_METHODS['manual']:
        transcripts = pd.DataFrame({
            tiger.ID_COL: ['ManualEntry'],
            tiger.SEQ_COL: [st.session_state.manual_entry]
        }).set_index(tiger.ID_COL)

    # fasta file upload
    elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
        if st.session_state.fasta_entry is not None:
            try:
                fasta_text = st.session_state.fasta_entry.getvalue().decode('utf-8')
            except UnicodeDecodeError:
                st.session_state.input_error = 'Fasta file must be UTF-8 encoded text'
                return
            # Write the upload to a private temporary file instead of the working
            # directory under the client-supplied name: that name could overwrite
            # app files (e.g. 'crisprTool.md') or collide between sessions.
            # The original extension is kept in case the loader relies on it.
            suffix = Path(st.session_state.fasta_entry.name).suffix
            fd, fasta_path = tempfile.mkstemp(suffix=suffix)
            try:
                with os.fdopen(fd, 'w', encoding='utf-8') as f:
                    f.write(fasta_text)
                transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False)
            finally:
                # remove the temp file even if loading fails
                os.remove(fasta_path)

    # convert to upper case as used by tokenizer
    transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T'))

    # ensure all transcripts have unique identifiers
    if transcripts.index.has_duplicates:
        st.session_state.input_error = "Duplicate transcript ID's detected in fasta file"

    # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N
    elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))):
        st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us'

    # ensure all transcripts satisfy length requirements
    elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)):
        st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN)

    # run model if we have any transcripts
    elif len(transcripts) > 0:
        st.session_state.transcripts = transcripts
def parse_gene_annotations(file_path):
    """Map approved gene symbols to Ensembl gene IDs from a tab-delimited file.

    The first line must be a header that contains the columns 'Approved symbol'
    and 'Ensembl gene ID'. Their positions are looked up by name.

    Args:
        file_path: Path to the tab-delimited annotation file (read as UTF-8).

    Returns:
        dict mapping each approved symbol to its Ensembl gene ID. The ID may be
        an empty string when that cell is empty. Rows too short to contain both
        columns, or with an empty symbol, are skipped. A later row with the same
        symbol overwrites an earlier one.

    Raises:
        ValueError: if either required column is missing from the header.
    """
    gene_dict = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        # Split on tabs *before* trimming. Stripping the whole line first would
        # drop leading/trailing empty fields and shift every column index.
        headers = [cell.strip() for cell in file.readline().rstrip('\r\n').split('\t')]
        symbol_idx = headers.index('Approved symbol')
        ensembl_idx = headers.index('Ensembl gene ID')
        last_needed_idx = max(symbol_idx, ensembl_idx)
        for line in file:
            values = [cell.strip() for cell in line.rstrip('\r\n').split('\t')]
            # An empty symbol would clash with the '' placeholder option of the
            # gene selectbox, so such rows are skipped.
            if len(values) > last_needed_idx and values[symbol_idx]:
                gene_dict[values[symbol_idx]] = values[ensembl_idx]
    return gene_dict
# Symbol -> Ensembl ID table loaded from the tab-delimited HUGO annotation export.
gene_annotations = parse_gene_annotations('Human_genes_HUGO_02242024_annotation.txt')
# Gene symbols offered by the autocomplete-style selectboxes below.
gene_symbol_list = [*gene_annotations]
# Check if the selected model is Cas9
if selected_model == 'Cas9':
    # Radio button so only one of on-target / off-target can be selected at a time
    target_selection = st.radio(
        "Select either on-target or off-target:",
        ('on-target', 'off-target'),
        key='target_selection'
    )
    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""

    def clean_up_old_files(gene_symbol):
        """Delete the files previously generated for ``gene_symbol``, if present.

        Covers the GenBank/BED/CSV exports and the gtracks plot image written by
        the on-target section below.
        """
        for suffix in ("_crispr_targets.gb", "_crispr_targets.bed",
                       "_crispr_predictions.csv", "_gtracks_plot.png"):
            path = f"{gene_symbol}{suffix}"
            if os.path.exists(path):
                os.remove(path)

    # Gene symbol entry with autocomplete-like feature
    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")

    # Handle a gene symbol change: remove the previous gene's files and results
    if gene_symbol != st.session_state['current_gene_symbol'] and gene_symbol:
        if st.session_state['current_gene_symbol']:
            # Clean up files only if a previous symbol exists
            clean_up_old_files(st.session_state['current_gene_symbol'])
        # Stored results belong to the previous gene. Without clearing them they
        # would be shown, and exported to files, under the new gene's name.
        for stale_key in ('on_target_results', 'gene_sequence', 'exons'):
            st.session_state.pop(stale_key, None)
        # Update the session state with the new gene symbol
        st.session_state['current_gene_symbol'] = gene_symbol

    if target_selection == 'on-target':
        # Prediction button
        predict_button = st.button('Predict on-target')
        if 'exons' not in st.session_state:
            st.session_state['exons'] = []

        # Process predictions
        if predict_button and gene_symbol:
            with st.spinner('Predicting... Please wait'):
                predictions, gene_sequence, exons = cas9on.process_gene(gene_symbol, cas9on_path)
                # Keep the 10 highest-scoring rows (the score is the last field of each row)
                sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
                st.session_state['on_target_results'] = sorted_predictions
                st.session_state['gene_sequence'] = gene_sequence  # Save gene sequence in session state
                st.session_state['exons'] = exons  # Store exon data
            # Notify the user once the process is completed successfully.
            st.success('Prediction completed!')
            st.session_state['prediction_made'] = True

        if st.session_state.get('on_target_results'):
            ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')  # Get Ensembl ID or default to 'Unknown'
            col1, col2, col3 = st.columns(3)
            with col1:
                st.markdown("**Genome**")
                st.markdown("Homo sapiens")
            with col2:
                st.markdown("**Gene**")
                st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
            with col3:
                st.markdown("**Nuclease**")
                st.markdown("SpCas9")

            # df stays None if the rows do not fit the expected columns. The file
            # export below needs it, so the export is skipped instead of raising
            # NameError.
            df = None
            try:
                df = pd.DataFrame(st.session_state['on_target_results'],
                                  columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon",
                                           "Target", "gRNA", "Prediction"])
                st.dataframe(df)
            except ValueError as e:
                st.error(f"DataFrame creation error: {e}")
                # Print the problematic data for debugging
                print(st.session_state['on_target_results'])

            # Initialize Plotly figure
            fig = go.Figure()
            EXON_BASE = 0  # Base position for exons on the Y axis
            EXON_HEIGHT = 0.02  # How 'tall' the exon markers should appear
            # Plot exons as thin bars along the X axis
            for exon in st.session_state['exons']:
                exon_start, exon_end = exon['start'], exon['end']
                fig.add_trace(go.Bar(
                    x=[(exon_start + exon_end) / 2],
                    y=[EXON_HEIGHT],
                    width=[exon_end - exon_start],
                    base=EXON_BASE,
                    marker_color='rgba(128, 0, 128, 0.5)',
                    name='Exon'
                ))

            VERTICAL_GAP = 0.2  # Gap between consecutive ranks
            MAX_STRAND_Y = 0.1  # Y position of the top-ranked '+' strand guide
            MIN_STRAND_Y = -0.1  # Y position of the top-ranked '-' strand guide
            # Plot the top 5 predictions. Lower ranks move away from the axis (up
            # for '+', down for '-'), so a '+' guide never lands in the '-' half.
            for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1):
                chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
                midpoint = (int(start) + int(end)) / 2
                is_plus_strand = str(strand) in ('1', '+')
                if is_plus_strand:
                    y_value = MAX_STRAND_Y + (i - 1) * VERTICAL_GAP
                else:
                    y_value = MIN_STRAND_Y - (i - 1) * VERTICAL_GAP
                fig.add_trace(go.Scatter(
                    x=[midpoint],
                    y=[y_value],
                    mode='markers+text',
                    marker=dict(symbol='triangle-up' if is_plus_strand else 'triangle-down',
                                size=12),
                    text=f"Rank: {i}",  # Text label
                    hoverinfo='text',
                    hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if is_plus_strand else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
                ))

            # Update layout for clarity and interaction
            fig.update_layout(
                title='Top 5 gRNA Sequences by Prediction Score',
                xaxis_title='Genomic Position',
                yaxis_title='Strand',
                yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
                showlegend=False,
                hovermode='x unified',
            )
            # Display the plot
            st.plotly_chart(fig)

            if df is not None and st.session_state.get('gene_sequence'):
                gene_symbol = st.session_state['current_gene_symbol']
                gene_sequence = st.session_state['gene_sequence']

                # Define file paths
                genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
                bed_file_path = f"{gene_symbol}_crispr_targets.bed"
                csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
                plot_image_path = f"{gene_symbol}_gtracks_plot.png"

                # Generate files
                cas9on.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
                cas9on.create_bed_file_from_df(df, bed_file_path)
                cas9on.create_csv_from_df(df, csv_file_path)

                # Bundle the three files into an in-memory ZIP archive
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                    zip_file.write(genbank_file_path)
                    zip_file.write(bed_file_path)
                    zip_file.write(csv_file_path)
                # Rewind the buffer before reading it
                zip_buffer.seek(0)

                # Region spanning all listed guides. to_numeric guards against the
                # positions arriving as strings (the plot code above calls int() on
                # them), where min()/max() would compare lexicographically.
                min_start = int(pd.to_numeric(df['Start Pos']).min())
                max_end = int(pd.to_numeric(df['End Pos']).max())
                chromosome = df['Chr'].mode()[0]  # Assumes most common chromosome is the target
                region = f"{chromosome}:{min_start}-{max_end}"

                # Render the BED track with gtracks. Arguments are passed as a list
                # (no shell), so the interpolated values are never parsed by a shell.
                try:
                    gtracks_run = subprocess.run(
                        ["gtracks", region, bed_file_path, plot_image_path],
                        capture_output=True, text=True,
                    )
                    plot_ok = gtracks_run.returncode == 0 and os.path.exists(plot_image_path)
                except FileNotFoundError:
                    # gtracks executable not installed / not on PATH
                    plot_ok = False
                if plot_ok:
                    st.image(plot_image_path)
                else:
                    st.warning(f"Could not generate the gtracks plot for region {region}.")

                # Display the download button for the ZIP file
                st.download_button(
                    label="Download GenBank, BED, CSV files as ZIP",
                    data=zip_buffer.getvalue(),
                    file_name=f"{gene_symbol}_files.zip",
                    mime="application/zip"
                )

    elif target_selection == 'off-target':
        ENTRY_METHODS = dict(
            manual='Manual entry of target sequence',
            txt="txt file upload"
        )
        if __name__ == '__main__':
            # app initialization for Cas9 off-target
            if 'target_sequence' not in st.session_state:
                st.session_state.target_sequence = None
            if 'input_error' not in st.session_state:
                st.session_state.input_error = None
            if 'off_target_results' not in st.session_state:
                st.session_state.off_target_results = None
            # 'entry_method' is also the widget key on the Cas13d page, whose options
            # differ. Reset a missing or foreign value so the selectbox accepts it.
            if st.session_state.get('entry_method') not in ENTRY_METHODS.values():
                st.session_state.entry_method = ENTRY_METHODS['manual']

            # target sequence entry
            st.selectbox(
                label='How would you like to provide target sequences?',
                options=ENTRY_METHODS.values(),
                key='entry_method',
                disabled=st.session_state.target_sequence is not None
            )
            if st.session_state.entry_method == ENTRY_METHODS['manual']:
                st.text_input(
                    label='Enter on/off sequences:',
                    key='manual_entry',
                    placeholder='Enter on/off sequences like:GGGTGGGGGGAGTTTGCTCCAGG,AGGTGGGGTGA_TTTGCTCCAGG',
                    disabled=st.session_state.target_sequence is not None
                )
            elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                st.file_uploader(
                    label='Upload a txt file:',
                    key='txt_entry',
                    disabled=st.session_state.target_sequence is not None
                )

            # prediction button
            if st.button('Predict off-target'):
                # Stays None when no sequence/file was provided (this used to raise
                # NameError when assigning the results below).
                predictions = None
                if st.session_state.entry_method == ENTRY_METHODS['manual']:
                    user_input = st.session_state.manual_entry
                    if user_input:  # Check if user_input is not empty
                        predictions = cas9off.process_input_and_predict(user_input, input_type='manual')
                elif st.session_state.entry_method == ENTRY_METHODS['txt']:
                    uploaded_file = st.session_state.txt_entry
                    if uploaded_file is not None:
                        # Read the uploaded file content
                        file_content = uploaded_file.getvalue().decode("utf-8")
                        # NOTE(review): the decoded file text is passed with
                        # input_type='manual'; confirm this is intended in cas9off.
                        predictions = cas9off.process_input_and_predict(file_content, input_type='manual')
                if predictions is None:
                    st.warning('Please enter sequences or upload a txt file before predicting.')
                st.session_state.off_target_results = predictions
            else:
                predictions = None

            progress = st.empty()

            # input error display
            error = st.empty()
            if st.session_state.input_error is not None:
                error.error(st.session_state.input_error, icon="🚨")
            else:
                error.empty()

            # off-target results display
            off_target_results = st.empty()
            if st.session_state.off_target_results is not None:
                with off_target_results.container():
                    if len(st.session_state.off_target_results) > 0:
                        st.write('Off-target predictions:', st.session_state.off_target_results)
                        st.download_button(
                            label='Download off-target predictions',
                            data=convert_df(st.session_state.off_target_results),
                            file_name='off_target_results.csv',
                            mime='text/csv'
                        )
                    else:
                        st.write('No significant off-target effects detected!')
            else:
                off_target_results.empty()

            # running the CRISPR-Net model for off-target predictions
            # NOTE(review): nothing in the visible code sets target_sequence to a
            # non-None value, so this branch appears unreachable.
            if st.session_state.target_sequence is not None:
                st.session_state.off_target_results = cas9off.predict_off_targets(
                    target_sequence=st.session_state.target_sequence,
                    status_update_fn=progress_update
                )
                st.session_state.target_sequence = None
                st.experimental_rerun()
elif selected_model == 'Cas12':
# Gene symbol entry with autocomplete-like feature
gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
format_func=lambda x: x if x else "")
# Initialize the current_gene_symbol in the session state if it doesn't exist
if 'current_gene_symbol' not in st.session_state:
st.session_state['current_gene_symbol'] = ""
# Prediction button
predict_button = st.button('Predict on-target')
# Function to clean up old files
def clean_up_old_files(gene_symbol):
genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
bed_file_path = f"{gene_symbol}_crispr_targets.bed"
csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
for path in [genbank_file_path, bed_file_path, csv_file_path]:
if os.path.exists(path):
os.remove(path)
# Clean up files if a new gene symbol is entered
if st.session_state['current_gene_symbol'] and gene_symbol != st.session_state['current_gene_symbol']:
clean_up_old_files(st.session_state['current_gene_symbol'])
# Process predictions
if predict_button and gene_symbol:
# Update the current gene symbol
st.session_state['current_gene_symbol'] = gene_symbol
# Run the prediction process
with st.spinner('Predicting... Please wait'):
predictions, gene_sequence = cas12.process_gene(gene_symbol,cas12_path)
sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
st.session_state['on_target_results'] = sorted_predictions
st.success('Prediction completed!')
# Visualization and file generation
if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
df = pd.DataFrame(st.session_state['on_target_results'],
columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])
st.dataframe(df)
# Now create a Plotly plot with the sorted_predictions
fig = go.Figure()
# Initialize the y position for the positive and negative strands
positive_strand_y = 0.1
negative_strand_y = -0.1
# Use an offset to spread gRNA sequences vertically
offset = 0.05
# Iterate over the sorted predictions to create the plot
for i, prediction in enumerate(sorted_predictions, start=1):
# Extract data for plotting and convert start and end to integers
chrom, start, end, strand, target, gRNA, pred_score = prediction
start, end = int(start), int(end)
midpoint = (start + end) / 2
# Set the y-value and arrow symbol based on the strand
if strand == '1':
y_value = positive_strand_y
arrow_symbol = 'triangle-right'
# Increment the y-value for the next positive strand gRNA
positive_strand_y += offset
else:
y_value = negative_strand_y
arrow_symbol = 'triangle-left'
# Decrement the y-value for the next negative strand gRNA
negative_strand_y -= offset
fig.add_trace(go.Scatter(
x=[midpoint],
y=[y_value], # Use the y_value set above for the strand
mode='markers+text',
marker=dict(symbol=arrow_symbol, size=10),
name=f"gRNA: {gRNA}",
text=f"Rank: {i}", # Place text at the marker
hoverinfo='text',
hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == 1 else '-'}<br>Prediction Score: {pred_score:.4f}",
))
# Update the layout of the plot
fig.update_layout(
title='Top 10 gRNA Sequences by Prediction Score',
xaxis_title='Genomic Position',
yaxis=dict(
title='Strand',
showgrid=True, # Show horizontal gridlines for clarity
zeroline=True, # Show a line at y=0 to represent the axis
zerolinecolor='Black',
zerolinewidth=2,
tickvals=[positive_strand_y, negative_strand_y],
ticktext=['+ Strand', '- Strand']
),
showlegend=False # Hide the legend if it's not necessary
)
# Display the plot
st.plotly_chart(fig)
# Ensure gene_sequence is not empty before generating files
if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
gene_symbol = st.session_state['current_gene_symbol']
gene_sequence = st.session_state['gene_sequence']
# Define file paths
genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
bed_file_path = f"{gene_symbol}_crispr_targets.bed"
csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
# Generate files
cas12.generate_genbank_file_from_data(df, gene_sequence, gene_symbol, genbank_file_path)
cas12.generate_bed_file_from_data(df, bed_file_path)
cas12.create_csv_from_df(df, csv_file_path)
# Prepare an in-memory buffer for the ZIP file
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
# For each file, add it to the ZIP file
zip_file.write(genbank_file_path, arcname=genbank_file_path.split('/')[-1])
zip_file.write(bed_file_path, arcname=bed_file_path.split('/')[-1])
zip_file.write(csv_file_path, arcname=csv_file_path.split('/')[-1])
# Important: move the cursor to the beginning of the BytesIO buffer before reading it
zip_buffer.seek(0)
# Display the download button for the ZIP file
st.download_button(
label="Download genbank,.bed,csv files as ZIP",
data=zip_buffer.getvalue(),
file_name=f"{gene_symbol}_files.zip",
mime="application/zip"
)
elif selected_model == 'Cas13d':
ENTRY_METHODS = dict(
manual='Manual entry of single transcript',
fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)"
)
if __name__ == '__main__':
# app initialization
if 'mode' not in st.session_state:
st.session_state.mode = tiger.RUN_MODES['all']
st.session_state.disable_off_target_checkbox = True
if 'entry_method' not in st.session_state:
st.session_state.entry_method = ENTRY_METHODS['manual']
if 'transcripts' not in st.session_state:
st.session_state.transcripts = None
if 'input_error' not in st.session_state:
st.session_state.input_error = None
if 'on_target' not in st.session_state:
st.session_state.on_target = None
if 'titration' not in st.session_state:
st.session_state.titration = None
if 'off_target' not in st.session_state:
st.session_state.off_target = None
# mode selection
col1, col2 = st.columns([0.65, 0.35])
with col1:
st.radio(
label='What do you want to predict?',
options=tuple(tiger.RUN_MODES.values()),
key='mode',
on_change=mode_change_callback,
disabled=st.session_state.transcripts is not None,
)
with col2:
st.checkbox(
label='Find off-target effects (slow)',
key='check_off_targets',
disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None
)
# transcript entry
st.selectbox(
label='How would you like to provide transcript(s) of interest?',
options=ENTRY_METHODS.values(),
key='entry_method',
disabled=st.session_state.transcripts is not None
)
if st.session_state.entry_method == ENTRY_METHODS['manual']:
st.text_input(
label='Enter a target transcript:',
key='manual_entry',
placeholder='Upper or lower case',
disabled=st.session_state.transcripts is not None
)
elif st.session_state.entry_method == ENTRY_METHODS['fasta']:
st.file_uploader(
label='Upload a fasta file:',
key='fasta_entry',
disabled=st.session_state.transcripts is not None
)
# let's go!
st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None)
progress = st.empty()
# input error
error = st.empty()
if st.session_state.input_error is not None:
error.error(st.session_state.input_error, icon="🚨")
else:
error.empty()
# on-target results
on_target_results = st.empty()
if st.session_state.on_target is not None:
with on_target_results.container():
st.write('On-target predictions:', st.session_state.on_target)
st.download_button(
label='Download on-target predictions',
data=convert_df(st.session_state.on_target),
file_name='on_target.csv',
mime='text/csv'
)
else:
on_target_results.empty()
# titration results
titration_results = st.empty()
if st.session_state.titration is not None:
with titration_results.container():
st.write('Titration predictions:', st.session_state.titration)
st.download_button(
label='Download titration predictions',
data=convert_df(st.session_state.titration),
file_name='titration.csv',
mime='text/csv'
)
else:
titration_results.empty()
# off-target results
off_target_results = st.empty()
if st.session_state.off_target is not None:
with off_target_results.container():
if len(st.session_state.off_target) > 0:
st.write('Off-target predictions:', st.session_state.off_target)
st.download_button(
label='Download off-target predictions',
data=convert_df(st.session_state.off_target),
file_name='off_target.csv',
mime='text/csv'
)
else:
st.write('We did not find any off-target effects!')
else:
off_target_results.empty()
# keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns)
if st.session_state.transcripts is not None:
st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit(
transcripts=st.session_state.transcripts,
mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode],
check_off_targets=st.session_state.check_off_targets,
status_update_fn=progress_update
)
st.session_state.transcripts = None
st.experimental_rerun()