supercat666 committed on
Commit
242350b
1 Parent(s): 2b3514d
Files changed (2) hide show
  1. app.py +47 -22
  2. cas9on.py +30 -2
app.py CHANGED
@@ -144,13 +144,20 @@ if selected_model == 'Cas9':
144
  # Prediction button
145
  predict_button = st.button('Predict on-target')
146
 
 
 
 
 
 
147
  # Process predictions
148
  if predict_button and gene_symbol:
149
  with st.spinner('Predicting... Please wait'):
150
- predictions, gene_sequence = cas9on.process_gene(gene_symbol, cas9on_path)
151
  sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
152
  st.session_state['on_target_results'] = sorted_predictions
153
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
 
 
154
 
155
  # Notify the user once the process is completed successfully.
156
  st.success('Prediction completed!')
@@ -162,44 +169,64 @@ if selected_model == 'Cas9':
162
  df = pd.DataFrame(st.session_state['on_target_results'],
163
  columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])
164
  st.dataframe(df)
165
- # Now create a Plotly plot with the sorted_predictions
 
166
  fig = go.Figure()
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # Initialize the y position for the positive and negative strands
169
  positive_strand_y = 0.1
170
  negative_strand_y = -0.1
171
-
172
- # Use an offset to spread gRNA sequences vertically
173
- offset = 0.05
174
 
175
  # Iterate over the sorted predictions to create the plot
176
- for i, prediction in enumerate(sorted_predictions, start=1):
177
- # Extract data for plotting and convert start and end to integers
178
  chrom, start, end, strand, target, gRNA, pred_score = prediction
179
  start, end = int(start), int(end)
180
  midpoint = (start + end) / 2
181
 
182
- # Set the y-value and arrow symbol based on the strand
183
- if strand == '1':
184
  y_value = positive_strand_y
185
  arrow_symbol = 'triangle-right'
186
- # Increment the y-value for the next positive strand gRNA
187
  positive_strand_y += offset
188
- else:
189
  y_value = negative_strand_y
190
  arrow_symbol = 'triangle-left'
191
- # Decrement the y-value for the next negative strand gRNA
192
  negative_strand_y -= offset
193
 
194
  fig.add_trace(go.Scatter(
195
  x=[midpoint],
196
- y=[y_value], # Use the y_value set above for the strand
197
  mode='markers+text',
198
  marker=dict(symbol=arrow_symbol, size=10),
199
  name=f"gRNA: {gRNA}",
200
- text=f"Rank: {i}", # Place text at the marker
201
  hoverinfo='text',
202
- hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == 1 else '-'}<br>Prediction Score: {pred_score:.4f}",
203
  ))
204
 
205
  # Update the layout of the plot
@@ -208,14 +235,12 @@ if selected_model == 'Cas9':
208
  xaxis_title='Genomic Position',
209
  yaxis=dict(
210
  title='Strand',
211
- showgrid=True, # Show horizontal gridlines for clarity
212
- zeroline=True, # Show a line at y=0 to represent the axis
213
- zerolinecolor='Black',
214
- zerolinewidth=2,
215
- tickvals=[positive_strand_y, negative_strand_y],
216
- ticktext=['+ Strand', '- Strand']
217
  ),
218
- showlegend=False # Hide the legend if it's not necessary
219
  )
220
 
221
  # Display the plot
 
144
  # Prediction button
145
  predict_button = st.button('Predict on-target')
146
 
147
+ if 'exons' not in st.session_state:
148
+ st.session_state['exons'] = []
149
+ if 'cds' not in st.session_state:
150
+ st.session_state['cds'] = []
151
+
152
  # Process predictions
153
  if predict_button and gene_symbol:
154
  with st.spinner('Predicting... Please wait'):
155
+ predictions, gene_sequence, exons, cds = cas9on.process_gene(gene_symbol, cas9on_path)
156
  sorted_predictions = sorted(predictions, key=lambda x: x[-1], reverse=True)[:10]
157
  st.session_state['on_target_results'] = sorted_predictions
158
  st.session_state['gene_sequence'] = gene_sequence # Save gene sequence in session state
159
+ st.session_state['exons'] = exons # Store exon data
160
+ st.session_state['cds'] = cds # Store CDS data
161
 
162
  # Notify the user once the process is completed successfully.
163
  st.success('Prediction completed!')
 
169
  df = pd.DataFrame(st.session_state['on_target_results'],
170
  columns=["Gene ID", "Start Pos", "End Pos", "Strand", "Target", "gRNA", "Prediction"])
171
  st.dataframe(df)
172
+ # Now create a Plotly plot with the sorted predictions
173
+ # Initialize Plotly figure
174
  fig = go.Figure()
175
 
176
+ # Plot Exons as horizontal lines or rectangles
177
+ exon_y = 0.2 # Adjust this as needed
178
+ for exon in st.session_state['exons']:
179
+ exon_start, exon_end = int(exon['start']), int(exon['end'])
180
+ fig.add_trace(go.Scatter(
181
+ x=[exon_start, exon_end],
182
+ y=[exon_y, exon_y],
183
+ mode='lines',
184
+ line=dict(color='purple', width=10), # Adjust styling as needed
185
+ name='Exon'
186
+ ))
187
+
188
+ # Plot CDS as horizontal lines or rectangles
189
+ cds_y = 0.3 # Adjust this as needed
190
+ for cds in st.session_state['cds']:
191
+ cds_start, cds_end = int(cds['start']), int(cds['end'])
192
+ fig.add_trace(go.Scatter(
193
+ x=[cds_start, cds_end],
194
+ y=[cds_y, cds_y],
195
+ mode='lines',
196
+ line=dict(color='blue', width=10), # Adjust styling as needed
197
+ name='CDS'
198
+ ))
199
+
200
+ # Plot gRNAs using triangles to indicate direction
201
  # Initialize the y position for the positive and negative strands
202
  positive_strand_y = 0.1
203
  negative_strand_y = -0.1
204
+ offset = 0.05 # Use an offset to spread gRNA sequences vertically
 
 
205
 
206
  # Iterate over the sorted predictions to create the plot
207
+ for i, prediction in enumerate(st.session_state['on_target_results'], start=1):
 
208
  chrom, start, end, strand, target, gRNA, pred_score = prediction
209
  start, end = int(start), int(end)
210
  midpoint = (start + end) / 2
211
 
212
+ if strand == '1': # Positive strand
 
213
  y_value = positive_strand_y
214
  arrow_symbol = 'triangle-right'
 
215
  positive_strand_y += offset
216
+ else: # Negative strand
217
  y_value = negative_strand_y
218
  arrow_symbol = 'triangle-left'
 
219
  negative_strand_y -= offset
220
 
221
  fig.add_trace(go.Scatter(
222
  x=[midpoint],
223
+ y=[y_value],
224
  mode='markers+text',
225
  marker=dict(symbol=arrow_symbol, size=10),
226
  name=f"gRNA: {gRNA}",
227
+ text=f"Rank: {i}",
228
  hoverinfo='text',
229
+ hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if strand == '1' else '-'}<br>Prediction Score: {pred_score:.4f}",
230
  ))
231
 
232
  # Update the layout of the plot
 
235
  xaxis_title='Genomic Position',
236
  yaxis=dict(
237
  title='Strand',
238
+ showgrid=True,
239
+ zeroline=False,
240
+ tickvals=[positive_strand_y, negative_strand_y, exon_y, cds_y],
241
+ ticktext=['+ Strand gRNAs', '- Strand gRNAs', 'Exons', 'CDS']
 
 
242
  ),
243
+ showlegend=True
244
  )
245
 
246
  # Display the plot
cas9on.py CHANGED
@@ -104,6 +104,7 @@ def find_crispr_targets(sequence, chr, start, strand, pam="NGG", target_length=2
104
 
105
  return targets
106
 
 
107
  def process_gene(gene_symbol, model_path):
108
  transcripts = fetch_ensembl_transcripts(gene_symbol)
109
  all_data = []
@@ -118,14 +119,41 @@ def process_gene(gene_symbol, model_path):
118
  # Fetch the sequence here and concatenate if multiple transcripts
119
  gene_sequence += fetch_ensembl_sequence(transcript_id) or ''
120
 
 
 
 
 
 
 
 
121
  if gene_sequence:
122
  gRNA_sites = find_crispr_targets(gene_sequence, chr, start, strand)
123
  if gRNA_sites:
124
  formatted_data = format_prediction_output(gRNA_sites, model_path)
125
  all_data.extend(formatted_data)
126
 
127
- # Return both the data and the fetched sequence
128
- return all_data, gene_sequence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  def create_genbank_features(formatted_data):
131
  features = []
 
104
 
105
  return targets
106
 
107
+
108
  def process_gene(gene_symbol, model_path):
109
  transcripts = fetch_ensembl_transcripts(gene_symbol)
110
  all_data = []
 
119
  # Fetch the sequence here and concatenate if multiple transcripts
120
  gene_sequence += fetch_ensembl_sequence(transcript_id) or ''
121
 
122
+ # Fetch exon and CDS information
123
+ exons = fetch_ensembl_exons(transcript_id)
124
+ cds_list = fetch_ensembl_cds(transcript_id)
125
+
126
+ # You might want to do something specific with exons and CDS information here
127
+ # For example, store them, print them, or include them in your analysis
128
+
129
  if gene_sequence:
130
  gRNA_sites = find_crispr_targets(gene_sequence, chr, start, strand)
131
  if gRNA_sites:
132
  formatted_data = format_prediction_output(gRNA_sites, model_path)
133
  all_data.extend(formatted_data)
134
 
135
+ # Return the prediction data, the fetched sequence, and the exon/CDS data
136
+ return all_data, gene_sequence, exons, cds_list
137
+
138
def fetch_ensembl_exons(transcript_id):
    """Fetch exon features for a transcript from the Ensembl REST API.

    Queries the ``overlap/id`` endpoint with ``feature=exon`` and asks for a
    JSON response.

    NOTE(review): ``overlap/id`` is documented as returning features that
    overlap the region of the given identifier, so the payload may include
    exons belonging to other transcripts -- confirm whether callers should
    filter on the record's parent transcript.

    Args:
        transcript_id: Ensembl transcript identifier, interpolated into the
            request URL.

    Returns:
        The decoded JSON payload on HTTP 200 (expected to be a list of exon
        records). An empty list is returned when the request fails or the
        server answers with any other status, so callers can always iterate
        the result.
    """
    url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=exon;content-type=application/json"
    try:
        # A timeout keeps a stalled connection from blocking the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching exon data from Ensembl for transcript {transcript_id}: {exc}")
        return []

    if response.status_code == 200:
        return response.json()  # Returns a list of exons for the transcript

    print(f"Error fetching exon data from Ensembl for transcript {transcript_id}: {response.text}")
    # Return an empty list rather than None: the caller iterates the result.
    return []
147
+
148
def fetch_ensembl_cds(transcript_id):
    """Fetch coding sequence (CDS) features for a transcript from Ensembl.

    Queries the ``overlap/id`` endpoint of the Ensembl REST API with
    ``feature=cds`` and asks for a JSON response.

    NOTE(review): ``overlap/id`` is documented as returning features that
    overlap the region of the given identifier, so the payload may include
    CDS segments belonging to other transcripts -- confirm whether callers
    should filter on the record's parent transcript.

    Args:
        transcript_id: Ensembl transcript identifier, interpolated into the
            request URL.

    Returns:
        The decoded JSON payload on HTTP 200 (expected to be a list of CDS
        records). An empty list is returned when the request fails or the
        server answers with any other status, so callers can always iterate
        the result.
    """
    url = f"https://rest.ensembl.org/overlap/id/{transcript_id}?feature=cds;content-type=application/json"
    try:
        # A timeout keeps a stalled connection from blocking the caller forever.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching CDS data from Ensembl for transcript {transcript_id}: {exc}")
        return []

    if response.status_code == 200:
        return response.json()  # Returns a list of CDS regions for the transcript

    print(f"Error fetching CDS data from Ensembl for transcript {transcript_id}: {response.text}")
    # Return an empty list rather than None: the caller iterates the result.
    return []
157
 
158
  def create_genbank_features(formatted_data):
159
  features = []