arikat committed on
Commit
94f42df
1 Parent(s): 06f6357

new pre-processing function

Browse files
Files changed (1) hide show
  1. app.py +15 -32
app.py CHANGED
@@ -201,20 +201,27 @@ def fig_to_img(fig):
201
  img = Image.open(buf)
202
  return img
203
 
204
-
205
- def process_family_sequence(protein_fasta):
206
  lines = protein_fasta.split('\n')
207
-
208
  headers = [line for line in lines if line.startswith('>')]
209
  if len(headers) > 1:
210
- return None, None, None, "Multiple fasta sequences detected. Please upload a fasta file with multiple sequences, otherwise only include one fasta sequence."
211
 
212
  protein_sequence = ''.join(line for line in lines if not line.startswith('>'))
213
-
214
  # Check for invalid characters
215
  valid_characters = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy") # the 20 standard amino acids
216
  if not set(protein_sequence).issubset(valid_characters):
217
- return None, None, None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids. Does your sequence contain gaps?"
 
 
 
 
 
 
 
 
218
 
219
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
220
  input_idsfam = encoded_input["input_ids"]
@@ -263,20 +270,7 @@ def process_family_sequence(protein_fasta):
263
 
264
 
265
  def process_single_sequence(protein_fasta): #, protein_file
266
-
267
- lines = protein_fasta.split('\n')
268
-
269
- headers = [line for line in lines if line.startswith('>')]
270
- if len(headers) > 1:
271
- return None, "Multiple fasta sequences detected. Please upload a fasta file with multiple sequences, otherwise only include one fasta sequence.", None
272
-
273
- protein_sequence = ''.join(line for line in lines if not line.startswith('>'))
274
-
275
- # Check for invalid characters
276
- valid_characters = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy") # the 20 standard amino acids
277
- if not set(protein_sequence).issubset(valid_characters):
278
- return None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids. Does your sequence contain gaps?", None
279
-
280
 
281
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
282
  input_ids = encoded_input["input_ids"]
@@ -360,18 +354,7 @@ def mask_residue(sequence, position):
360
  return sequence[:position] + 'X' + sequence[position+1:]
361
 
362
  def generate_heatmap(protein_fasta):
363
- lines = protein_fasta.strip().split('\n')
364
- header = lines[0]
365
- protein_sequence = ''.join(lines[1:])
366
-
367
- # Check if the header is valid
368
- if not header.startswith('>'):
369
- return None, "Invalid FASTA format. Header should start with '>'.", None
370
-
371
- # Check for invalid characters in the sequence
372
- valid_characters = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy")
373
- if not set(protein_sequence).issubset(valid_characters):
374
- return None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids.", None
375
 
376
  # Tokenize and predict for original sequence
377
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
 
201
  img = Image.open(buf)
202
  return img
203
 
204
def preprocess_protein_sequence(protein_fasta):
    """Validate a single-record FASTA string and extract its residue sequence.

    Lines starting with '>' are treated as headers; every other line is
    concatenated into the sequence. Input with more than one header, or with
    characters outside the 20 standard amino acids (either case), is rejected.

    Args:
        protein_fasta: Raw FASTA text (header line optional).

    Returns:
        A ``(protein_sequence, error_msg)`` pair: ``(sequence, None)`` on
        success, ``(None, message)`` when validation fails.
    """
    # splitlines() copes with \n, \r\n and \r alike; the previous split('\n')
    # left a trailing '\r' on each line of CRLF input, which then failed the
    # amino-acid check. strip() likewise drops stray leading/trailing
    # whitespace that is not part of the sequence.
    lines = [line.strip() for line in protein_fasta.splitlines()]

    headers = [line for line in lines if line.startswith('>')]
    if len(headers) > 1:
        return None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."

    protein_sequence = ''.join(line for line in lines if not line.startswith('>'))

    # Check for invalid characters
    valid_characters = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy")  # the 20 standard amino acids
    if not set(protein_sequence).issubset(valid_characters):
        return None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids. Does your sequence contain gaps?"

    return protein_sequence, None
219
+
220
+
221
+ def process_family_sequence(protein_fasta):
222
+ protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
223
+ if error_msg:
224
+ return None, None, None, error_msg
225
 
226
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
227
  input_idsfam = encoded_input["input_ids"]
 
270
 
271
 
272
  def process_single_sequence(protein_fasta): #, protein_file
273
+ protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
276
  input_ids = encoded_input["input_ids"]
 
354
  return sequence[:position] + 'X' + sequence[position+1:]
355
 
356
  def generate_heatmap(protein_fasta):
357
+ protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
 
 
 
 
 
 
 
 
 
 
 
358
 
359
  # Tokenize and predict for original sequence
360
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")