supercat666 committed on
Commit
ce4236e
1 Parent(s): a5afc1a

fixed cas9on

Browse files
Files changed (2) hide show
  1. app.py +37 -2
  2. cas9on.py +97 -14
app.py CHANGED
@@ -13,7 +13,7 @@ st.divider()
13
  CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
14
 
15
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
16
-
17
 
18
  @st.cache_data
19
  def convert_df(df):
@@ -92,8 +92,43 @@ if selected_model == 'Cas9':
92
 
93
  # Actions based on the selected enzyme
94
  if target_selection == 'on-target':
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  elif target_selection == 'off-target':
99
  ENTRY_METHODS = dict(
 
13
  CRISPR_MODELS = ['Cas9', 'Cas12', 'Cas13d']
14
 
15
  selected_model = st.selectbox('Select CRISPR model:', CRISPR_MODELS, key='selected_model')
16
+ cas9on_path = '/cas9_model/on-cla.h5'
17
 
18
  @st.cache_data
19
  def convert_df(df):
 
92
 
93
  # Actions based on the selected enzyme
94
  if target_selection == 'on-target':
95
# Cas9 on-target workflow.
#
# Streamlit re-executes this script on every widget interaction, and
# st.button() is True only for the single run triggered by the click.
# Anything that must outlive that run (the gene symbol and the predictions)
# is therefore kept in st.session_state rather than in local variables.
if 'gene_symbol' not in st.session_state:
    st.session_state.gene_symbol = None
if 'on_target_results' not in st.session_state:
    st.session_state.on_target_results = None
if 'on_target_full_results' not in st.session_state:
    st.session_state.on_target_full_results = None

# Gene symbol entry
st.text_input(
    label='Enter a Gene Symbol:',
    key='gene_symbol_entry',
    placeholder='e.g., BRCA1'
)

# Prediction button: run the pipeline once and cache both result views.
if st.button('Predict on-target'):
    gene_symbol = (st.session_state.gene_symbol_entry or '').strip()
    if gene_symbol:  # ignore empty / whitespace-only input
        predictions = cas9on.process_gene(gene_symbol, cas9on_path)
        st.session_state.gene_symbol = gene_symbol
        # Full list backs the download; the first 10 rows back the preview.
        st.session_state.on_target_full_results = predictions
        st.session_state.on_target_results = predictions[:10]

# On-target results display
on_target_results = st.empty()
if st.session_state.on_target_results is not None:
    with on_target_results.container():
        if len(st.session_state.on_target_results) > 0:
            st.write('On-target predictions:', st.session_state.on_target_results)
            # Reuse the cached predictions instead of calling process_gene
            # again: the local gene_symbol does not exist on reruns where the
            # button was not pressed, and recomputing repeats every request
            # and model call.
            # NOTE(review): cas9on.convert_df is not visible in the reviewed
            # part of cas9on.py - confirm it exists there.
            st.download_button(
                label='Download on-target predictions',
                data=cas9on.convert_df(st.session_state.on_target_full_results),
                file_name='on_target_results.csv',
                mime='text/csv'
            )
        else:
            st.write('No significant on-target effects detected!')
else:
    on_target_results.empty()
132
 
133
  elif target_selection == 'off-target':
134
  ENTRY_METHODS = dict(
cas9on.py CHANGED
@@ -1,8 +1,11 @@
 
1
  import tensorflow as tf
2
  import pandas as pd
3
  import numpy as np
4
  from operator import add
5
  from functools import reduce
 
 
6
 
7
  # configure GPUs
8
  for gpu in tf.config.list_physical_devices('GPU'):
@@ -18,7 +21,6 @@ ntmap = {'A': (1, 0, 0, 0),
18
  }
19
  epimap = {'A': 1, 'N': 0}
20
 
21
-
22
  def get_seqcode(seq):
23
  return np.array(reduce(add, map(lambda c: ntmap[c], seq.upper()))).reshape(
24
  (1, len(seq), -1))
@@ -54,13 +56,9 @@ class Episgt:
54
  return x
55
 
56
  from keras.models import load_model
57
-
58
  class DCModelOntar:
59
  def __init__(self, ontar_model_dir, is_reg=False):
60
- if is_reg:
61
- self.model = load_model(ontar_model_dir)
62
- else:
63
- self.model = load_model(ontar_model_dir)
64
 
65
  def ontar_predict(self, x, channel_first=True):
66
  if channel_first:
@@ -68,11 +66,96 @@ class DCModelOntar:
68
  yp = self.model.predict(x)
69
  return yp.ravel()
70
 
71
- def predict():
72
- file_path = 'eg_cls_on_target.episgt'
73
- input_data = Episgt(file_path, num_epi_features=4, with_y=True)
74
- x, y = input_data.get_dataset()
75
- x = np.expand_dims(x, axis=2) # shape(x) = [100, 8, 1, 23]
76
- dcModel = DCModelOntar('on-cla.h5')
77
- predicted_on_target = dcModel.ontar_predict(x)
78
- return predicted_on_target
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
  import tensorflow as tf
3
  import pandas as pd
4
  import numpy as np
5
  from operator import add
6
  from functools import reduce
7
+ from keras.models import load_model
8
+ import random
9
 
10
  # configure GPUs
11
  for gpu in tf.config.list_physical_devices('GPU'):
 
21
  }
22
  epimap = {'A': 1, 'N': 0}
23
 
 
24
def get_seqcode(seq):
    """Encode ``seq`` with ``ntmap`` and return an array of shape (1, len(seq), -1).

    Each upper-cased character is looked up in ``ntmap``; the resulting
    tuples are concatenated in order and reshaped so that axis 1 indexes
    sequence position.
    """
    per_base_codes = (ntmap[base] for base in seq.upper())
    flat_code = reduce(add, per_base_codes)  # tuple concatenation
    return np.array(flat_code).reshape((1, len(seq), -1))
 
56
  return x
57
 
58
  from keras.models import load_model
 
59
  class DCModelOntar:
60
  def __init__(self, ontar_model_dir, is_reg=False):
61
+ self.model = load_model(ontar_model_dir)
 
 
 
62
 
63
  def ontar_predict(self, x, channel_first=True):
64
  if channel_first:
 
66
  yp = self.model.predict(x)
67
  return yp.ravel()
68
 
69
+ # Function to generate random epigenetic data
70
def generate_random_epigenetic_data(length):
    """Return a random string of ``length`` characters drawn from 'A' and 'N'.

    One ``random.choice`` call is made per character, so the output for a
    given ``random`` seed is reproducible.
    """
    symbols = []
    for _ in range(length):
        symbols.append(random.choice('AN'))
    return ''.join(symbols)
72
+
73
+ # Function to predict on-target efficiency and format output
74
def format_prediction_output(gRNA_sites, gene_id, model_path):
    """Score every site in ``gRNA_sites`` and return one result row per site.

    A ``DCModelOntar`` is built from ``model_path`` once per call. For each
    site the sequence encoding is stacked with four randomly generated
    placeholder epigenetic tracks (CTCF, DNase, H3K4me3, RRBS) and passed to
    ``ontar_predict``.

    Returns:
        A list of rows ``[gene_id, "start_pos", "end_pos", "strand", site,
        ctcf, dnase, h3k4me3, rrbs, score]``. The three position/strand
        entries are literal placeholder strings, and the four track entries
        are the encoded arrays themselves.
    """
    def _placeholder_track(site_length):
        # Random 'A'/'N' string, encoded and reshaped into a single channel.
        random_marks = generate_random_epigenetic_data(site_length)
        return get_epicode(random_marks).reshape(-1, 1, 1, 23)

    predictor = DCModelOntar(model_path)
    rows = []

    for site in gRNA_sites:
        # NOTE(review): the hard-coded 23 in the reshapes presumably requires
        # 23-character sites (20 nt + 3 nt PAM) - confirm against callers.
        # NOTE(review): get_seqcode yields (1, len, channels); reshaping to
        # (-1, 4, 1, 23) without a transpose may interleave positions and
        # channels - verify this matches the layout the model was trained on.
        seq_channels = get_seqcode(site).reshape(-1, 4, 1, 23)

        # Tracks are drawn in this fixed order so random-state consumption
        # is stable.
        ctcf = _placeholder_track(len(site))
        dnase = _placeholder_track(len(site))
        h3k4me3 = _placeholder_track(len(site))
        rrbs = _placeholder_track(len(site))

        model_input = np.concatenate(
            (seq_channels, ctcf, dnase, h3k4me3, rrbs), axis=1)
        scores = predictor.ontar_predict(model_input)

        rows.append([gene_id, "start_pos", "end_pos", "strand", site,
                     ctcf, dnase, h3k4me3, rrbs, scores[0]])

    return rows
97
+
98
def fetch_ensembl_transcripts(gene_symbol):
    """Look up ``gene_symbol`` for homo_sapiens on the Ensembl REST API.

    Args:
        gene_symbol: Gene symbol to look up (e.g. as typed by a user).

    Returns:
        The value of the ``'Transcript'`` key of the JSON response, or
        ``None`` when the request fails, the status is not 200, or the key
        is absent. Failures are reported with ``print``.
    """
    from urllib.parse import quote

    # Percent-encode the symbol so free-form user input cannot change the
    # URL path or inject query parameters.
    safe_symbol = quote(str(gene_symbol), safe="")
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{safe_symbol}?expand=1;content-type=application/json"

    try:
        # Without a timeout, requests can block indefinitely on a stalled
        # connection.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' in gene_data:
        return gene_data['Transcript']

    print("No transcripts found for gene:", gene_symbol)
    return None
111
+
112
def fetch_ensembl_sequence(transcript_id):
    """Fetch the sequence for ``transcript_id`` from the Ensembl REST API.

    Args:
        transcript_id: Identifier placed in the ``/sequence/id/`` URL.

    Returns:
        The value of the ``'seq'`` key of the JSON response, or ``None``
        when the request fails, the status is not 200, or the key is
        absent. Failures are reported with ``print``.
    """
    from urllib.parse import quote

    # Percent-encode the identifier so it cannot change the URL structure.
    safe_id = quote(str(transcript_id), safe="")
    url = f"https://rest.ensembl.org/sequence/id/{safe_id}?content-type=application/json"

    try:
        # Without a timeout, requests can block indefinitely on a stalled
        # connection.
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' in sequence_data:
        return sequence_data['seq']

    print("No sequence found for transcript:", transcript_id)
    return None
125
+
126
def find_crispr_targets(sequence, pam="NGG", target_length=20):
    """Collect protospacer + PAM windows from ``sequence``.

    The first PAM position is treated as a wildcard (the ``N`` of ``NGG``);
    the remaining PAM characters, ``pam[1:]``, are compared literally and
    case-sensitively. A site is reported only when at least
    ``target_length`` characters precede the PAM.

    Args:
        sequence: Nucleotide string to scan.
        pam: PAM motif; its length sets the PAM window size.
        target_length: Number of characters kept upstream of the PAM.

    Returns:
        A list of substrings, each ``target_length + len(pam)`` characters
        long, in order of occurrence.
    """
    pam_length = len(pam)
    pam_suffix = pam[1:]
    targets = []

    # Start at target_length so every hit has a full protospacer upstream.
    first_pam_start = max(target_length, 0)
    for i in range(first_pam_start, len(sequence) - pam_length + 1):
        # Use the actual PAM length rather than a fixed 3 so PAMs of other
        # lengths are matched and sliced correctly.
        if sequence[i + 1:i + pam_length] == pam_suffix:
            targets.append(sequence[i - target_length:i + pam_length])

    return targets
137
+
138
+
139
def process_gene(gene_symbol, model_path):
    """Run the on-target pipeline for every transcript of ``gene_symbol``.

    For each transcript returned by ``fetch_ensembl_transcripts`` the
    sequence is fetched, candidate sites are found with
    ``find_crispr_targets``, and rows from ``format_prediction_output``
    (labelled with the transcript id) are accumulated.

    Returns:
        A single list of all rows; empty when no transcripts, sequences or
        sites were found.
    """
    collected_rows = []

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        return collected_rows

    for entry in transcripts:
        transcript_id = entry['id']

        transcript_seq = fetch_ensembl_sequence(transcript_id)
        if not transcript_seq:
            continue

        candidate_sites = find_crispr_targets(transcript_seq)
        if not candidate_sites:
            continue

        collected_rows.extend(
            format_prediction_output(candidate_sites, transcript_id, model_path))

    return collected_rows
154
+
155
+
156
+ # Function to save results as CSV
157
def save_to_csv(data, filename="crispr_results.csv"):
    """Write ``data`` rows to ``filename`` as CSV with a fixed header.

    Args:
        data: Row-wise records with ten fields each, in the column order
            listed below.
        filename: Destination path for the CSV file.
    """
    column_names = [
        "Gene ID", "Start Pos", "End Pos", "Strand", "gRNA",
        "CTCF", "Dnase", "H3K4me3", "RRBS", "Prediction",
    ]
    # index=False keeps the pandas row index out of the file.
    pd.DataFrame(data, columns=column_names).to_csv(filename, index=False)