andyP committed
Commit be023c1
1 Parent(s): a4015aa

Readme update

Files changed (3):
  1. config.json +8 -0
  2. configuration_vgcn.py +2 -0
  3. modeling_vcgn.py +40 -17
config.json CHANGED
@@ -1,5 +1,12 @@
 {
+  "architectures": [
+    "VCGNModelForTextClassification"
+  ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "configuration_vgcn.VGCNConfig",
+    "AutoModelForSequenceClassification": "modeling_vcgn.VCGNModelForTextClassification"
+  },
   "bert_model": "readerbench/RoBERT-base",
   "classifier_dropout": null,
   "do_lower_case": 1,
@@ -34,6 +41,7 @@
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
   "tf_threshold": 0.0,
+  "torch_dtype": "float32",
   "transformers_version": "4.31.0",
   "type_vocab_size": 2,
   "use_cache": true,
configuration_vgcn.py CHANGED
@@ -6,6 +6,7 @@ class VGCNConfig(BertConfig):
 
     def __init__(
         self,
+        bert_model: str = 'readerbench/RoBERT-base',
         gcn_adj_matrix: str = '',
         max_seq_len: int = 256,
         npmi_threshold: float = 0.2,
@@ -29,5 +30,6 @@ class VGCNConfig(BertConfig):
         self.tf_threshold = tf_threshold
         self.vocab_type = vocab_type
         self.gcn_embedding_dim = gcn_embedding_dim
+        self.bert_model = bert_model
 
         super().__init__(**kwargs)
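
Persisting bert_model on VGCNConfig means the tokenizer backbone is serialized into config.json and survives save/load round-trips, instead of living only inside the adjacency pickle. A quick sketch (the adjacency path is a placeholder):

from configuration_vgcn import VGCNConfig

# bert_model is now a regular config field, written to config.json like any other.
config = VGCNConfig(
    bert_model='readerbench/RoBERT-base',
    gcn_adj_matrix='gcn_adj_matrix.pkl',  # placeholder path
)
assert config.to_dict()['bert_model'] == 'readerbench/RoBERT-base'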
modeling_vcgn.py CHANGED
@@ -64,27 +64,51 @@ def get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj,gcn_config:VGCNConfig):
 class VCGNModelForTextClassification(PreTrainedModel):
     config_class = VGCNConfig
 
-    def __init__(self, config):
+    def __init__(self, config, load_adjacency_matrix=True):
         super().__init__(config)
-
-        self.pre_trained_model_name = ''
-        self.remove_stop_words = False
-        self.tokenizer = None
-        self.norm_gcn_vocab_adj_list = None
-        self.gcn_vocab_size = config.vocab_size
 
+        self.tokenizer = BertTokenizer.from_pretrained(config.bert_model)
 
-        self.load_adj_matrix(config.gcn_adj_matrix)
+        if load_adjacency_matrix:
+            norm_gcn_vocab_adj_list = self.load_adj_matrix(config.gcn_adj_matrix)
+        else:
+            norm_gcn_vocab_adj_list = []
+            for _ in range(2 if config.vocab_type=='all' else 1):
+                norm_gcn_vocab_adj_list.append(torch.sparse.FloatTensor(torch.LongTensor([[0],[0]]), torch.Tensor([0]), (config.vocab_size, config.vocab_size)))
 
         self.model = VGCN_Bert(
             config,
+            gcn_adj_matrix=norm_gcn_vocab_adj_list,
             gcn_adj_dim=config.vocab_size,
+            gcn_adj_num=len(norm_gcn_vocab_adj_list),
             gcn_embedding_dim=config.gcn_embedding_dim,
 
         )
 
+    @classmethod
+    def from_pretrained(cls, *model_args, reload_adjacency_matrix=False, **kwargs):
+        model = super().from_pretrained(*model_args, **kwargs, load_adjacency_matrix=False)
+
+        if reload_adjacency_matrix:
+            norm_gcn_vocab_adj_list = model.load_adj_matrix(model.config.gcn_adj_matrix)
+            model.model.embeddings.vocab_gcn.adj_matrix = torch.nn.ParameterList([torch.nn.Parameter(x) for x in norm_gcn_vocab_adj_list])
+            for p in model.model.embeddings.vocab_gcn.adj_matrix:
+                p.requires_grad = False
+
+        return model
+
+    def set_adjacency_matrix(self, adj_matrix: Union[List, np.ndarray, sp.csr_matrix, torch.Tensor]):
+        # Wrap a bare ndarray in a list; lists of tensors pass through unchanged.
+        if isinstance(adj_matrix, np.ndarray):
+            adj_matrix = [torch.from_numpy(adj_matrix)]
+        elif not isinstance(adj_matrix, list):
+            raise ValueError(f"adjacency matrix must be a list of torch.Tensor or torch.nn.Parameter, got {type(adj_matrix)}")
+
+        self.model.embeddings.vocab_gcn.adj_matrix = torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
+        for p in self.model.embeddings.vocab_gcn.adj_matrix:
+            p.requires_grad = False
+
+
     def load_adj_matrix(self, adj_matrix):
         filename = None
         if Path(adj_matrix).is_file():
@@ -98,11 +122,8 @@ class VCGNModelForTextClassification(PreTrainedModel):
 
         gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(open(filename, 'rb'))
 
-
-        self.pre_trained_model_name = adj_config['bert_model']
-        self.remove_stop_words = adj_config['remove_stop_words']
-        self.tokenizer = BertTokenizer.from_pretrained(self.pre_trained_model_name)
-        self.norm_gcn_vocab_adj_list = get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
+        self.tokenizer = BertTokenizer.from_pretrained(adj_config['bert_model'])
+        return get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
 
     def _prep_batch(self, batch: torch.Tensor):
 
@@ -207,12 +228,14 @@ class VocabGraphConvolution(nn.Module):
     """
     def __init__(self,adj_matrix,voc_dim, num_adj, hid_dim, out_dim, dropout_rate=0.2):
         super(VocabGraphConvolution, self).__init__()
-        if type(adj_matrix) is not list:
+        if isinstance(adj_matrix, (nn.Parameter, nn.ParameterList)):
            self.adj_matrix=adj_matrix
-        else:
+        elif isinstance(adj_matrix, list):
            self.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
            for p in self.adj_matrix:
                p.requires_grad=False
+        else:
+            raise ValueError(f"adjacency matrix must be a list of torch.Tensor or torch.nn.Parameter, got {type(adj_matrix)}")
 
        self.voc_dim=voc_dim
        self.num_adj=num_adj
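
Together these changes decouple checkpoint loading from the adjacency pickle: __init__ can build the graph layers over empty sparse placeholders, from_pretrained then restores the trained weights, and the real adjacency matrices can be re-attached afterwards. A usage sketch (placeholder repo id; the sparse identity passed to set_adjacency_matrix is only a stand-in for a real graph):

import torch
from transformers import AutoModelForSequenceClassification

repo_id = "andyP/model-repo"  # placeholder repo id

# Default path: weights load over the empty placeholder adjacency matrices.
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)

# Re-read the pickle referenced by config.gcn_adj_matrix, when it is available locally.
model = AutoModelForSequenceClassification.from_pretrained(
    repo_id, trust_remote_code=True, reload_adjacency_matrix=True)

# Or inject a graph computed elsewhere; it is wrapped into frozen nn.Parameters.
n = model.config.vocab_size
model.set_adjacency_matrix([torch.eye(n).to_sparse()])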