Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -61,7 +61,7 @@ def load_model(model_name: str):
|
|
61 |
|
62 |
return model, tokenizer
|
63 |
|
64 |
-
elif model_name == 'CNN':
|
65 |
hidden_dim = 2048
|
66 |
width = 2048
|
67 |
seq_drop_prob = 0.05
|
@@ -84,6 +84,36 @@ def load_model(model_name: str):
|
|
84 |
model.load_state_dict(weights)
|
85 |
return model, None
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
else:
|
88 |
raise ValueError("Invalid model name")
|
89 |
|
@@ -120,7 +150,7 @@ def analyze_dna(username, password, sequence, model_name):
|
|
120 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
121 |
return logits
|
122 |
|
123 |
-
elif
|
124 |
|
125 |
SEQUENCE_LENGTH = 8000 if '8k' in model_name else (16000 if '16k' in model_name else None)
|
126 |
pad_char = 'N'
|
@@ -191,8 +221,8 @@ demo = gr.Interface(
|
|
191 |
gr.Textbox(label="DNA Sequence"),
|
192 |
gr.Dropdown(label="Model", choices=[
|
193 |
"GENA-Bert",
|
194 |
-
"CNN-8k-context",
|
195 |
-
"CNN-16k-context"
|
196 |
])
|
197 |
],
|
198 |
outputs=[
|
|
|
61 |
|
62 |
return model, tokenizer
|
63 |
|
64 |
+
elif model_name == 'CNN-m2-8k-context':
|
65 |
hidden_dim = 2048
|
66 |
width = 2048
|
67 |
seq_drop_prob = 0.05
|
|
|
84 |
model.load_state_dict(weights)
|
85 |
return model, None
|
86 |
|
87 |
+
elif model_name == 'CNN-m4-16k-context':
|
88 |
+
seq_drop_prob = 0.05
|
89 |
+
hidden_dim = 2000
|
90 |
+
width = 768
|
91 |
+
train_sequence_length = 16000
|
92 |
+
weight_decay = 0.0001
|
93 |
+
num_labs = len(set(y_train))
|
94 |
+
|
95 |
+
|
96 |
+
model_seq = nn.Sequential(
|
97 |
+
nn.Conv1d(4, width, 7, padding=3),
|
98 |
+
nn.ReLU(),
|
99 |
+
nn.BatchNorm1d(width),
|
100 |
+
ResNet1d(width, [(3, width // 2, 1)] * 1, dropout=None, dilation=7),
|
101 |
+
nn.ReLU(),
|
102 |
+
Pool2BN(width),
|
103 |
+
)
|
104 |
+
|
105 |
+
new_head = torch.nn.Sequential(
|
106 |
+
torch.nn.Dropout(0.5), ## for DEEPLIFT comment out
|
107 |
+
MLP([width * 2, num_labs])
|
108 |
+
)
|
109 |
+
|
110 |
+
joined_model = torch.nn.Sequential(
|
111 |
+
model_seq,
|
112 |
+
new_head
|
113 |
+
)
|
114 |
+
|
115 |
+
return joined_model, None
|
116 |
+
|
117 |
else:
|
118 |
raise ValueError("Invalid model name")
|
119 |
|
|
|
150 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
151 |
return logits
|
152 |
|
153 |
+
elif 'CNN' in model_name:
|
154 |
|
155 |
SEQUENCE_LENGTH = 8000 if '8k' in model_name else (16000 if '16k' in model_name else None)
|
156 |
pad_char = 'N'
|
|
|
221 |
gr.Textbox(label="DNA Sequence"),
|
222 |
gr.Dropdown(label="Model", choices=[
|
223 |
"GENA-Bert",
|
224 |
+
"CNN-m2-8k-context",
|
225 |
+
"CNN-m4-16k-context"
|
226 |
])
|
227 |
],
|
228 |
outputs=[
|