mawairon commited on
Commit
7f58142
1 Parent(s): 16b0032

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -4
app.py CHANGED
@@ -61,7 +61,7 @@ def load_model(model_name: str):
61
 
62
  return model, tokenizer
63
 
64
- elif model_name == 'CNN':
65
  hidden_dim = 2048
66
  width = 2048
67
  seq_drop_prob = 0.05
@@ -84,6 +84,36 @@ def load_model(model_name: str):
84
  model.load_state_dict(weights)
85
  return model, None
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  else:
88
  raise ValueError("Invalid model name")
89
 
@@ -120,7 +150,7 @@ def analyze_dna(username, password, sequence, model_name):
120
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
121
  return logits
122
 
123
- elif model_name == 'CNN-8k-context' or model_name == 'CNN-16k-context':
124
 
125
  SEQUENCE_LENGTH = 8000 if '8k' in model_name else (16000 if '16k' in model_name else None)
126
  pad_char = 'N'
@@ -191,8 +221,8 @@ demo = gr.Interface(
191
  gr.Textbox(label="DNA Sequence"),
192
  gr.Dropdown(label="Model", choices=[
193
  "GENA-Bert",
194
- "CNN-8k-context",
195
- "CNN-16k-context"
196
  ])
197
  ],
198
  outputs=[
 
61
 
62
  return model, tokenizer
63
 
64
+ elif model_name == 'CNN-m2-8k-context':
65
  hidden_dim = 2048
66
  width = 2048
67
  seq_drop_prob = 0.05
 
84
  model.load_state_dict(weights)
85
  return model, None
86
 
87
+ elif model_name == 'CNN-m4-16k-context':
88
+ seq_drop_prob = 0.05
89
+ hidden_dim = 2000
90
+ width = 768
91
+ train_sequence_length = 16000
92
+ weight_decay = 0.0001
93
+ num_labs = len(set(y_train))
94
+
95
+
96
+ model_seq = nn.Sequential(
97
+ nn.Conv1d(4, width, 7, padding=3),
98
+ nn.ReLU(),
99
+ nn.BatchNorm1d(width),
100
+ ResNet1d(width, [(3, width // 2, 1)] * 1, dropout=None, dilation=7),
101
+ nn.ReLU(),
102
+ Pool2BN(width),
103
+ )
104
+
105
+ new_head = torch.nn.Sequential(
106
+ torch.nn.Dropout(0.5), ## for DEEPLIFT comment out
107
+ MLP([width * 2, num_labs])
108
+ )
109
+
110
+ joined_model = torch.nn.Sequential(
111
+ model_seq,
112
+ new_head
113
+ )
114
+
115
+ return joined_model, None
116
+
117
  else:
118
  raise ValueError("Invalid model name")
119
 
 
150
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
151
  return logits
152
 
153
+ elif 'CNN' in model_name:
154
 
155
  SEQUENCE_LENGTH = 8000 if '8k' in model_name else (16000 if '16k' in model_name else None)
156
  pad_char = 'N'
 
221
  gr.Textbox(label="DNA Sequence"),
222
  gr.Dropdown(label="Model", choices=[
223
  "GENA-Bert",
224
+ "CNN-m2-8k-context",
225
+ "CNN-m4-16k-context"
226
  ])
227
  ],
228
  outputs=[