mawairon commited on
Commit
b1eabde
·
verified ·
1 Parent(s): 7ad39e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -7
app.py CHANGED
@@ -81,25 +81,44 @@ def load_model():
81
  model, tokenizer = load_model()
82
 
83
  def analyze_dna(username, password, sequence):
84
- valid_usernames = os.getenv('USERNAME')
 
85
  env_password = os.getenv('PASSWORD')
86
 
87
  if username not in valid_usernames or password != env_password:
88
  return {"error": "Invalid username or password"}, ""
89
 
90
  try:
91
-
92
- sequence = sequence.replace(" ", "")
 
93
  if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
94
  return {"error": "Sequence contains invalid characters"}, ""
95
 
96
  if len(sequence) < 300:
97
  return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""
98
 
99
- inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
100
- logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
101
-
102
- probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
104
  top_5_probs = [probabilities[i] for i in top_5_indices]
105
  top_5_labels = [int_to_label[i] for i in top_5_indices]
 
81
  model, tokenizer = load_model()
82
 
83
  def analyze_dna(username, password, sequence):
84
+
85
+ valid_usernames = os.getenv('USERNAME').split(',')
86
  env_password = os.getenv('PASSWORD')
87
 
88
  if username not in valid_usernames or password != env_password:
89
  return {"error": "Invalid username or password"}, ""
90
 
91
  try:
92
+ # Remove all whitespace characters
93
+ sequence = sequence.replace(" ", "").replace("\n", "").replace("\t", "").replace("\r", "")
94
+
95
  if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
96
  return {"error": "Sequence contains invalid characters"}, ""
97
 
98
  if len(sequence) < 300:
99
  return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""
100
 
101
+ def get_logits(seq):
102
+ inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
103
+ with torch.no_grad():
104
+ logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
105
+ return logits
106
+
107
+ if len(sequence) > 3000:
108
+ num_shifts = len(sequence) // 1000
109
+ logits_sum = None
110
+ for i in range(num_shifts):
111
+ shifted_sequence = sequence[i*1000:] + sequence[:i*1000]
112
+ logits = get_logits(shifted_sequence)
113
+ if logits_sum is None:
114
+ logits_sum = logits
115
+ else:
116
+ logits_sum += logits
117
+ logits_avg = logits_sum / num_shifts
118
+ else:
119
+ logits_avg = get_logits(sequence)
120
+
121
+ probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
122
  top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
123
  top_5_probs = [probabilities[i] for i in top_5_indices]
124
  top_5_labels = [int_to_label[i] for i in top_5_indices]