mawairon commited on
Commit
16b0032
·
verified ·
1 Parent(s): f7fbeb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -80,7 +80,7 @@ def load_model(model_name: str):
80
  model_seq,
81
  new_head
82
  )
83
- weights = torch.load('/CNN_1stGEAC_m2_best.pth')
84
  model.load_state_dict(weights)
85
  return model, None
86
 
@@ -120,9 +120,9 @@ def analyze_dna(username, password, sequence, model_name):
120
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
121
  return logits
122
 
123
- elif model_name == 'CNN':
124
 
125
- SEQUENCE_LENGTH = 8000
126
  pad_char = 'N'
127
 
128
  # Truncate sequence
@@ -190,8 +190,9 @@ demo = gr.Interface(
190
  gr.Textbox(label="Password", type="password"),
191
  gr.Textbox(label="DNA Sequence"),
192
  gr.Dropdown(label="Model", choices=[
193
- "gena-bert",
194
- "CNN"
 
195
  ])
196
  ],
197
  outputs=[
 
80
  model_seq,
81
  new_head
82
  )
83
+ weights = torch.load('CNN_1stGEAC_m2_best.pth')
84
  model.load_state_dict(weights)
85
  return model, None
86
 
 
120
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
121
  return logits
122
 
123
+ elif model_name == 'CNN-8k-context' or model_name == 'CNN-16k-context':
124
 
125
+ SEQUENCE_LENGTH = 8000 if '8k' in model_name else (16000 if '16k' in model_name else None)
126
  pad_char = 'N'
127
 
128
  # Truncate sequence
 
190
  gr.Textbox(label="Password", type="password"),
191
  gr.Textbox(label="DNA Sequence"),
192
  gr.Dropdown(label="Model", choices=[
193
+ "GENA-Bert",
194
+ "CNN-8k-context",
195
+ "CNN-16k-context"
196
  ])
197
  ],
198
  outputs=[