supercat666 commited on
Commit
16e89c0
1 Parent(s): 9ad0b46

fix app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -65
app.py CHANGED
@@ -4,6 +4,7 @@ import cas9att
4
  import cas9attvcf
5
  import cas9off
6
  import cas12
 
7
  import pandas as pd
8
  import streamlit as st
9
  import plotly.graph_objs as go
@@ -184,10 +185,7 @@ if selected_model == 'Cas9':
184
  if predict_button and gene_symbol:
185
  model_choice = st.radio("mutation or not:", ('normal', 'mutation'))
186
  with st.spinner('Predicting... Please wait'):
187
- if model_choice == 'cas9attvcf':
188
- predictions, gene_sequence, exons = cas9attvcf.process_gene(gene_symbol, cas9att_path)
189
- else:
190
- predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
191
 
192
  sorted_predictions = sorted(predictions)[:10]
193
  st.session_state['on_target_results'] = sorted_predictions
@@ -437,83 +435,98 @@ elif selected_model == 'Cas12':
437
 
438
  # Process predictions
439
  if predict_button and gene_symbol:
440
- # Update the current gene symbol
441
- st.session_state['current_gene_symbol'] = gene_symbol
442
-
443
- # Run the prediction process
444
  with st.spinner('Predicting... Please wait'):
445
- predictions, gene_sequence, exons = cas12.process_gene(gene_symbol,cas12_path)
446
- sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
 
447
  st.session_state['on_target_results'] = sorted_predictions
448
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
449
  st.session_state['exons'] = exons # Store exon data
 
 
450
  st.success('Prediction completed!')
 
451
 
452
- # Visualization and file generation
453
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
454
- df = pd.DataFrame(st.session_state['on_target_results'],
455
- columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"])
456
- st.dataframe(df)
457
- # Now create a Plotly plot with the sorted_predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  fig = go.Figure()
459
 
460
- # Initialize the y position for the positive and negative strands
461
- positive_strand_y = 0.1
462
- negative_strand_y = -0.1
463
-
464
- # Use an offset to spread gRNA sequences vertically
465
- offset = 0.05
466
-
467
- # Iterate over the sorted predictions to create the plot
468
- for i, prediction in enumerate(sorted_predictions, start=1):
469
- # Extract data for plotting and convert start and end to integers
470
- chrom, start, end, strand, target, gRNA, pred_score = prediction
471
- start, end = int(start), int(end)
472
- midpoint = (start + end) / 2
473
-
474
- # Set the y-value and arrow symbol based on the strand
475
- if strand == '1':
476
- y_value = positive_strand_y
477
- arrow_symbol = 'triangle-right'
478
- # Increment the y-value for the next positive strand gRNA
479
- positive_strand_y += offset
480
- else:
481
- y_value = negative_strand_y
482
- arrow_symbol = 'triangle-left'
483
- # Decrement the y-value for the next negative strand gRNA
484
- negative_strand_y -= offset
 
 
 
 
485
 
486
  fig.add_trace(go.Scatter(
487
  x=[midpoint],
488
- y=[y_value], # Use the y_value set above for the strand
489
  mode='markers+text',
490
- marker=dict(symbol=arrow_symbol, size=10),
491
- name=f"gRNA: {gRNA}",
492
- text=f"Rank: {i}", # Place text at the marker
493
  hoverinfo='text',
494
- hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == 1 else '-'}<br>Prediction Score: {pred_score:.4f}",
495
  ))
496
 
497
- # Update the layout of the plot
498
  fig.update_layout(
499
- title='Top 10 gRNA Sequences by Prediction Score',
500
  xaxis_title='Genomic Position',
501
- yaxis=dict(
502
- title='Strand',
503
- showgrid=True, # Show horizontal gridlines for clarity
504
- zeroline=True, # Show a line at y=0 to represent the axis
505
- zerolinecolor='Black',
506
- zerolinewidth=2,
507
- tickvals=[positive_strand_y, negative_strand_y],
508
- ticktext=['+ Strand', '- Strand']
509
- ),
510
- showlegend=False # Hide the legend if it's not necessary
511
  )
512
 
513
  # Display the plot
514
  st.plotly_chart(fig)
515
 
516
- # Ensure gene_sequence is not empty before generating files
517
  if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
518
  gene_symbol = st.session_state['current_gene_symbol']
519
  gene_sequence = st.session_state['gene_sequence']
@@ -522,26 +535,38 @@ elif selected_model == 'Cas12':
522
  genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
523
  bed_file_path = f"{gene_symbol}_crispr_targets.bed"
524
  csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
 
525
 
526
  # Generate files
527
- cas12.generate_genbank_file_from_data(df, gene_sequence, gene_symbol, genbank_file_path)
528
- cas12.generate_bed_file_from_data(df, bed_file_path)
529
- cas12.create_csv_from_df(df, csv_file_path)
530
 
531
  # Prepare an in-memory buffer for the ZIP file
532
  zip_buffer = io.BytesIO()
533
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
534
  # For each file, add it to the ZIP file
535
- zip_file.write(genbank_file_path, arcname=genbank_file_path.split('/')[-1])
536
- zip_file.write(bed_file_path, arcname=bed_file_path.split('/')[-1])
537
- zip_file.write(csv_file_path, arcname=csv_file_path.split('/')[-1])
538
 
539
  # Important: move the cursor to the beginning of the BytesIO buffer before reading it
540
  zip_buffer.seek(0)
541
 
 
 
 
 
 
 
 
 
 
 
 
542
  # Display the download button for the ZIP file
543
  st.download_button(
544
- label="Download genbank,.bed,csv files as ZIP",
545
  data=zip_buffer.getvalue(),
546
  file_name=f"{gene_symbol}_files.zip",
547
  mime="application/zip"
 
4
  import cas9attvcf
5
  import cas9off
6
  import cas12
7
+ import cas12lstm
8
  import pandas as pd
9
  import streamlit as st
10
  import plotly.graph_objs as go
 
185
  if predict_button and gene_symbol:
186
  model_choice = st.radio("mutation or not:", ('normal', 'mutation'))
187
  with st.spinner('Predicting... Please wait'):
188
+ predictions, gene_sequence, exons = cas9att.process_gene(gene_symbol, cas9att_path)
 
 
 
189
 
190
  sorted_predictions = sorted(predictions)[:10]
191
  st.session_state['on_target_results'] = sorted_predictions
 
435
 
436
  # Process predictions
437
  if predict_button and gene_symbol:
 
 
 
 
438
  with st.spinner('Predicting... Please wait'):
439
+ predictions, gene_sequence, exons = cas12lstm.process_gene(gene_symbol, cas9att_path)
440
+
441
+ sorted_predictions = sorted(predictions)[:10]
442
  st.session_state['on_target_results'] = sorted_predictions
443
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
444
  st.session_state['exons'] = exons # Store exon data
445
+
446
+ # Notify the user once the process is completed successfully.
447
  st.success('Prediction completed!')
448
+ st.session_state['prediction_made'] = True
449
 
 
450
  if 'on_target_results' in st.session_state and st.session_state['on_target_results']:
451
+ ensembl_id = gene_annotations.get(gene_symbol, 'Unknown') # Get Ensembl ID or default to 'Unknown'
452
+ col1, col2, col3 = st.columns(3)
453
+ with col1:
454
+ st.markdown("**Genome**")
455
+ st.markdown("Homo sapiens")
456
+ with col2:
457
+ st.markdown("**Gene**")
458
+ st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
459
+ with col3:
460
+ st.markdown("**Nuclease**")
461
+ st.markdown("SpCas9")
462
+ # Include "Target" in the DataFrame's columns
463
+ try:
464
+ df = pd.DataFrame(st.session_state['on_target_results'],
465
+ columns=["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target",
466
+ "gRNA", "Prediction"])
467
+ st.dataframe(df)
468
+ except ValueError as e:
469
+ st.error(f"DataFrame creation error: {e}")
470
+ # Optionally print or log the problematic data for debugging:
471
+ print(st.session_state['on_target_results'])
472
+
473
+ # Initialize Plotly figure
474
  fig = go.Figure()
475
 
476
+ EXON_BASE = 0 # Base position for exons and CDS on the Y axis
477
+ EXON_HEIGHT = 0.02 # How 'tall' the exon markers should appear
478
+
479
+ # Plot Exons as small markers on the X-axis
480
+ for exon in st.session_state['exons']:
481
+ exon_start, exon_end = exon['start'], exon['end']
482
+ fig.add_trace(go.Bar(
483
+ x=[(exon_start + exon_end) / 2],
484
+ y=[EXON_HEIGHT],
485
+ width=[exon_end - exon_start],
486
+ base=EXON_BASE,
487
+ marker_color='rgba(128, 0, 128, 0.5)',
488
+ name='Exon'
489
+ ))
490
+
491
+ VERTICAL_GAP = 0.2 # Gap between different ranks
492
+
493
+ # Define max and min Y values based on strand and rank
494
+ MAX_STRAND_Y = 0.1 # Maximum Y value for positive strand results
495
+ MIN_STRAND_Y = -0.1 # Minimum Y value for negative strand results
496
+
497
+ # Iterate over top 5 sorted predictions to create the plot
498
+ for i, prediction in enumerate(st.session_state['on_target_results'][:5], start=1): # Only top 5
499
+ chrom, start, end, strand, transcript, exon, target, gRNA, prediction_score = prediction
500
+ midpoint = (int(start) + int(end)) / 2
501
+
502
+ # Vertical position based on rank, modified by strand
503
+ y_value = (MAX_STRAND_Y - (i - 1) * VERTICAL_GAP) if strand == '1' or strand == '+' else (
504
+ MIN_STRAND_Y + (i - 1) * VERTICAL_GAP)
505
 
506
  fig.add_trace(go.Scatter(
507
  x=[midpoint],
508
+ y=[y_value],
509
  mode='markers+text',
510
+ marker=dict(symbol='triangle-up' if strand == '1' or strand == '+' else 'triangle-down',
511
+ size=12),
512
+ text=f"Rank: {i}", # Text label
513
  hoverinfo='text',
514
+ hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == '1' or strand == '+' else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
515
  ))
516
 
517
+ # Update layout for clarity and interaction
518
  fig.update_layout(
519
+ title='Top 5 gRNA Sequences by Prediction Score',
520
  xaxis_title='Genomic Position',
521
+ yaxis_title='Strand',
522
+ yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
523
+ showlegend=False,
524
+ hovermode='x unified',
 
 
 
 
 
 
525
  )
526
 
527
  # Display the plot
528
  st.plotly_chart(fig)
529
 
 
530
  if 'gene_sequence' in st.session_state and st.session_state['gene_sequence']:
531
  gene_symbol = st.session_state['current_gene_symbol']
532
  gene_sequence = st.session_state['gene_sequence']
 
535
  genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
536
  bed_file_path = f"{gene_symbol}_crispr_targets.bed"
537
  csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
538
+ plot_image_path = f"{gene_symbol}_gtracks_plot.png"
539
 
540
  # Generate files
541
+ cas12lstm.generate_genbank_file_from_df(df, gene_sequence, gene_symbol, genbank_file_path)
542
+ cas12lstm.create_bed_file_from_df(df, bed_file_path)
543
+ cas12lstm.create_csv_from_df(df, csv_file_path)
544
 
545
  # Prepare an in-memory buffer for the ZIP file
546
  zip_buffer = io.BytesIO()
547
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
548
  # For each file, add it to the ZIP file
549
+ zip_file.write(genbank_file_path)
550
+ zip_file.write(bed_file_path)
551
+ zip_file.write(csv_file_path)
552
 
553
  # Important: move the cursor to the beginning of the BytesIO buffer before reading it
554
  zip_buffer.seek(0)
555
 
556
+ # Specify the region you want to visualize
557
+ min_start = df['Start Pos'].min()
558
+ max_end = df['End Pos'].max()
559
+ chromosome = df['Chr'].mode()[0] # Assumes most common chromosome is the target
560
+ region = f"{chromosome}:{min_start}-{max_end}"
561
+
562
+ # Generate the pyGenomeTracks plot
563
+ gtracks_command = f"gtracks {region} {bed_file_path} {plot_image_path}"
564
+ subprocess.run(gtracks_command, shell=True)
565
+ st.image(plot_image_path)
566
+
567
  # Display the download button for the ZIP file
568
  st.download_button(
569
+ label="Download GenBank, BED, CSV files as ZIP",
570
  data=zip_buffer.getvalue(),
571
  file_name=f"{gene_symbol}_files.zip",
572
  mime="application/zip"