hdallatorre committed
Commit
17d1cb0
1 Parent(s): 33d57f0

Upload SegmentNT

config.json CHANGED
@@ -10,6 +10,22 @@
   },
   "emb_layer_norm_before": false,
   "esmfold_config": null,
+  "features": [
+    "protein_coding_gene",
+    "lncRNA",
+    "exon",
+    "intron",
+    "splice_donor",
+    "splice_acceptor",
+    "5UTR",
+    "3UTR",
+    "CTCF-bound",
+    "polyA_signal",
+    "enhancer_Tissue_specific",
+    "enhancer_Tissue_invariant",
+    "promoter_Tissue_specific",
+    "promoter_Tissue_invariant"
+  ],
   "hidden_dropout_prob": 0.0,
   "hidden_size": 1024,
   "initializer_range": 0.02,
@@ -20,7 +36,6 @@
   "max_position_embeddings": 2050,
   "model_type": "esm",
   "num_attention_heads": 16,
-  "num_features": 14,
   "num_hidden_layers": 29,
   "num_layers_head": 2,
   "pad_token_id": 1,
modeling_segment_nt.py CHANGED
@@ -1159,6 +1159,7 @@ class SegmentNT(EsmPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
+        self.num_features = len(config.features)
 
         self.esm = EsmModel(config, add_pooling_layer=False)
 
@@ -1171,7 +1172,7 @@ class SegmentNT(EsmPreTrainedModel):
                 embed_dim * (2**i) for i in range(num_layers)
             ),
         )
-        self.fc = nn.Linear(in_features=embed_dim, out_features=6 * 2 * config.num_features)
+        self.fc = nn.Linear(in_features=embed_dim, out_features=6 * 2 * self.num_features)
         self.activation_fn = nn.SiLU()
 
         self.init_weights()
@@ -1232,7 +1233,7 @@ class SegmentNT(EsmPreTrainedModel):
         logits = self.fc(x)
 
         # Final reshape to have logits per nucleotides, per feature
-        logits = torch.reshape(logits, (x.shape[0], x.shape[1] * 6, self.config.num_features, 2))
+        logits = torch.reshape(logits, (x.shape[0], x.shape[1] * 6, self.num_features, 2))
 
         # Add logits to the ESM outputs
         outputs["logits"] = logits
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b3b9da828b4a2058d0c66c7b6e0249acd2e54b755c8fa300a8f7a5cb0ee19ebd
+oid sha256:deed8fda3b2b9a76ee11aade732666b8303f97f4ecd6ab8aa6e3e1a9de9e1aa5
 size 2237478985
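The weights file is stored as a Git LFS pointer; per the LFS spec, `oid sha256:` is the SHA-256 digest of the actual binary, so the checkpoint bytes changed while the serialized size stayed identical. A quick integrity check against the new pointer (the local path is illustrative):

import hashlib

expected = "deed8fda3b2b9a76ee11aade732666b8303f97f4ecd6ab8aa6e3e1a9de9e1aa5"

h = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:           # path to your local copy
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

assert h.hexdigest() == expected, "file does not match the LFS pointer oid"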
segment_nt_config.py CHANGED
@@ -15,7 +15,7 @@
 """ ESM model configuration"""
 
 from dataclasses import asdict, dataclass
-from typing import Optional
+from typing import List, Optional
 
 from transformers import PretrainedConfig, logging
 
@@ -99,6 +99,7 @@ class SegmentNTConfig(PretrainedConfig):
 
     def __init__(
         self,
+        features=None,
         vocab_size=None,
         mask_token_id=None,
         pad_token_id=None,
@@ -121,7 +122,6 @@
         add_bias_fnn=True,
         rescaling_factor=None,
         num_layers_head=2,
-        num_features=14,
         **kwargs,
     ):
         super().__init__(
@@ -147,7 +147,7 @@
         self.add_bias_fnn = add_bias_fnn
         # Arguments needed for Segment NT
         self.num_layers_head = num_layers_head
-        self.num_features = num_features
+        self.features = features
         self.rescaling_factor = rescaling_factor
         if is_folding_model:
             if esmfold_config is None:
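With `features` stored on the config, consumers construct `SegmentNTConfig` with a list of feature names rather than a count, and `SegmentNT.__init__` derives `num_features` from its length. A sketch of a custom config (the two-element list is purely illustrative, and it assumes the remaining constructor defaults are valid for your setup):

from segment_nt_config import SegmentNTConfig

config = SegmentNTConfig(
    features=["exon", "intron"],   # hypothetical reduced feature set
    num_layers_head=2,
)

# The segmentation head would emit 6 * 2 * len(config.features) logits per token.
assert len(config.features) == 2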