import os import tiger import pandas as pd import streamlit as st from pathlib import Path ENTRY_METHODS = dict( manual='Manual entry of single transcript', fasta="Fasta file upload (supports multiple transcripts if they have unique ID's)" ) CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d'] selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model') def load_model(model_name): if model_name == 'Cas9': # Placeholder for Cas9 model loading # TODO: Implement Cas9 model loading logic raise NotImplementedError("Cas9 model loading not implemented yet.") elif model_name == 'Cas12': # Placeholder for Cas12 model loading # TODO: Implement Cas12 model loading logic raise NotImplementedError("Cas12 model loading not implemented yet.") elif model_name == 'Cas13d': # Assuming tiger module is for Cas13 return tiger.load_model() # Assuming there's a load_model function in tiger.py else: raise ValueError(f"Unknown model: {model_name}") @st.cache_data def convert_df(df): # IMPORTANT: Cache the conversion to prevent computation on every rerun return df.to_csv().encode('utf-8') def mode_change_callback(): if st.session_state.mode in {tiger.RUN_MODES['all'], tiger.RUN_MODES['titration']}: # TODO: support titration st.session_state.check_off_targets = False st.session_state.disable_off_target_checkbox = True else: st.session_state.disable_off_target_checkbox = False def progress_update(update_text, percent_complete): with progress.container(): st.write(update_text) st.progress(percent_complete / 100) def initiate_run(): # Placeholder for dynamic module import based on selected_model # model_module = get_model_module(selected_model) # You will need to implement get_model_module function to import the correct module (cas9, cas12, cas13) # ... rest of the initiate_run function ... # initialize state variables st.session_state.transcripts = None st.session_state.input_error = None st.session_state.on_target = None st.session_state.titration = None st.session_state.off_target = None # initialize transcript DataFrame transcripts = pd.DataFrame(columns=[tiger.ID_COL, tiger.SEQ_COL]) # manual entry if st.session_state.entry_method == ENTRY_METHODS['manual']: transcripts = pd.DataFrame({ tiger.ID_COL: ['ManualEntry'], tiger.SEQ_COL: [st.session_state.manual_entry] }).set_index(tiger.ID_COL) # fasta file upload elif st.session_state.entry_method == ENTRY_METHODS['fasta']: if st.session_state.fasta_entry is not None: fasta_path = st.session_state.fasta_entry.name with open(fasta_path, 'w') as f: f.write(st.session_state.fasta_entry.getvalue().decode('utf-8')) transcripts = tiger.load_transcripts([fasta_path], enforce_unique_ids=False) os.remove(fasta_path) # convert to upper case as used by tokenizer transcripts[tiger.SEQ_COL] = transcripts[tiger.SEQ_COL].apply(lambda s: s.upper().replace('U', 'T')) # ensure all transcripts have unique identifiers if transcripts.index.has_duplicates: st.session_state.input_error = "Duplicate transcript ID's detected in fasta file" # ensure all transcripts only contain nucleotides A, C, G, T, and wildcard N elif not all(transcripts[tiger.SEQ_COL].apply(lambda s: set(s).issubset(tiger.NUCLEOTIDE_TOKENS.keys()))): st.session_state.input_error = 'Transcript(s) must only contain upper or lower case A, C, G, and Ts or Us' # ensure all transcripts satisfy length requirements elif any(transcripts[tiger.SEQ_COL].apply(lambda s: len(s) < tiger.TARGET_LEN)): st.session_state.input_error = 'Transcript(s) must be at least {:d} bases.'.format(tiger.TARGET_LEN) # run model if we have any transcripts elif len(transcripts) > 0: st.session_state.transcripts = transcripts if __name__ == '__main__': # app initialization if 'mode' not in st.session_state: st.session_state.mode = tiger.RUN_MODES['all'] st.session_state.disable_off_target_checkbox = True if 'entry_method' not in st.session_state: st.session_state.entry_method = ENTRY_METHODS['manual'] if 'transcripts' not in st.session_state: st.session_state.transcripts = None if 'input_error' not in st.session_state: st.session_state.input_error = None if 'on_target' not in st.session_state: st.session_state.on_target = None if 'titration' not in st.session_state: st.session_state.titration = None if 'off_target' not in st.session_state: st.session_state.off_target = None # title and documentation st.markdown(Path('tiger.md').read_text(), unsafe_allow_html=True) st.divider() # mode selection col1, col2 = st.columns([0.65, 0.35]) with col1: st.radio( label='What do you want to predict?', options=tuple(tiger.RUN_MODES.values()), key='mode', on_change=mode_change_callback, disabled=st.session_state.transcripts is not None, ) with col2: st.checkbox( label='Find off-target effects (slow)', key='check_off_targets', disabled=st.session_state.disable_off_target_checkbox or st.session_state.transcripts is not None ) # transcript entry st.selectbox( label='How would you like to provide transcript(s) of interest?', options=ENTRY_METHODS.values(), key='entry_method', disabled=st.session_state.transcripts is not None ) if st.session_state.entry_method == ENTRY_METHODS['manual']: st.text_input( label='Enter a target transcript:', key='manual_entry', placeholder='Upper or lower case', disabled=st.session_state.transcripts is not None ) elif st.session_state.entry_method == ENTRY_METHODS['fasta']: st.file_uploader( label='Upload a fasta file:', key='fasta_entry', disabled=st.session_state.transcripts is not None ) # let's go! st.button(label='Get predictions!', on_click=initiate_run, disabled=st.session_state.transcripts is not None) progress = st.empty() # input error error = st.empty() if st.session_state.input_error is not None: error.error(st.session_state.input_error, icon="🚨") else: error.empty() # on-target results on_target_results = st.empty() if st.session_state.on_target is not None: with on_target_results.container(): st.write('On-target predictions:', st.session_state.on_target) st.download_button( label='Download on-target predictions', data=convert_df(st.session_state.on_target), file_name='on_target.csv', mime='text/csv' ) else: on_target_results.empty() # titration results titration_results = st.empty() if st.session_state.titration is not None: with titration_results.container(): st.write('Titration predictions:', st.session_state.titration) st.download_button( label='Download titration predictions', data=convert_df(st.session_state.titration), file_name='titration.csv', mime='text/csv' ) else: titration_results.empty() # off-target results off_target_results = st.empty() if st.session_state.off_target is not None: with off_target_results.container(): if len(st.session_state.off_target) > 0: st.write('Off-target predictions:', st.session_state.off_target) st.download_button( label='Download off-target predictions', data=convert_df(st.session_state.off_target), file_name='off_target.csv', mime='text/csv' ) else: st.write('We did not find any off-target effects!') else: off_target_results.empty() # keep trying to run model until we clear inputs (streamlit UI changes can induce race-condition reruns) if st.session_state.transcripts is not None: st.session_state.on_target, st.session_state.titration, st.session_state.off_target = tiger.tiger_exhibit( transcripts=st.session_state.transcripts, mode={v: k for k, v in tiger.RUN_MODES.items()}[st.session_state.mode], check_off_targets=st.session_state.check_off_targets, status_update_fn=progress_update ) st.session_state.transcripts = None st.experimental_rerun()