Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -80,7 +80,7 @@ def load_model(model_name: str):
|
|
80 |
model_seq,
|
81 |
new_head
|
82 |
)
|
83 |
-
weights = torch.load('
|
84 |
model.load_state_dict(weights)
|
85 |
return model, None
|
86 |
|
@@ -120,9 +120,9 @@ def analyze_dna(username, password, sequence, model_name):
|
|
120 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
121 |
return logits
|
122 |
|
123 |
-
elif model_name == 'CNN':
|
124 |
|
125 |
-
SEQUENCE_LENGTH = 8000
|
126 |
pad_char = 'N'
|
127 |
|
128 |
# Truncate sequence
|
@@ -190,8 +190,9 @@ demo = gr.Interface(
|
|
190 |
gr.Textbox(label="Password", type="password"),
|
191 |
gr.Textbox(label="DNA Sequence"),
|
192 |
gr.Dropdown(label="Model", choices=[
|
193 |
-
"
|
194 |
-
"CNN"
|
|
|
195 |
])
|
196 |
],
|
197 |
outputs=[
|
|
|
80 |
model_seq,
|
81 |
new_head
|
82 |
)
|
83 |
+
weights = torch.load('CNN_1stGEAC_m2_best.pth')
|
84 |
model.load_state_dict(weights)
|
85 |
return model, None
|
86 |
|
|
|
120 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
121 |
return logits
|
122 |
|
123 |
+
elif model_name == 'CNN-8k-context' or model_name == 'CNN-16k-context':
|
124 |
|
125 |
+
SEQUENCE_LENGTH = 8000 if '8k' in model_name else (16000 if '16k' in model_name else None)
|
126 |
pad_char = 'N'
|
127 |
|
128 |
# Truncate sequence
|
|
|
190 |
gr.Textbox(label="Password", type="password"),
|
191 |
gr.Textbox(label="DNA Sequence"),
|
192 |
gr.Dropdown(label="Model", choices=[
|
193 |
+
"GENA-Bert",
|
194 |
+
"CNN-8k-context",
|
195 |
+
"CNN-16k-context"
|
196 |
])
|
197 |
],
|
198 |
outputs=[
|