andyP committed
Commit 26a4923
Parent(s): aff5ec5

Initial commit

adjacency_matrix/graph_extended_comments.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0508863549ac3faea223be7d93bef5ec24b70af65124223fec485e1021b0f3e
+size 829003020
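The new .pkl is tracked through Git LFS, so the repository stores only the pointer above; the ~829 MB file is fetched on checkout. According to load_adj_matrix in modeling_vcgn.py (see the diff below), it unpickles into two sparse vocabulary adjacency matrices plus a small config dict. A minimal sketch of inspecting the file from a local clone (standard pickle usage, not part of this commit):

```python
import pickle as pkl

# Per load_adj_matrix in modeling_vcgn.py, the pickle holds a 3-tuple:
# two adjacency structures and a config dict (keys used in the diff:
# 'bert_model', 'remove_stop_words').
with open("adjacency_matrix/graph_extended_comments.pkl", "rb") as f:
    gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(f)

print(adj_config["bert_model"])   # e.g. "readerbench/RoBERT-base"
print(type(gcn_vocab_adj_tf), type(gcn_vocab_adj))
```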
config.json CHANGED
@@ -1,9 +1,10 @@
 {
   "attention_probs_dropout_prob": 0.1,
+  "bert_model": "readerbench/RoBERT-base",
   "classifier_dropout": null,
   "do_lower_case": 1,
   "do_remove_accents": 0,
-  "gcn_adj_matrix": "adjacency_matrix/graph_dataset_comments.pkl",
+  "gcn_adj_matrix": "adjacency_matrix/graph_extended_comments.pkl",
   "gcn_embedding_dim": 32,
   "gradient_checkpointing": false,
   "hidden_act": "gelu",
@@ -33,7 +34,7 @@
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
   "tf_threshold": 0.0,
-  "transformers_version": "4.30.2",
+  "transformers_version": "4.31.0",
   "type_vocab_size": 2,
   "use_cache": true,
   "vocab_size": 37788,
modeling_vcgn.py CHANGED
@@ -1,4 +1,6 @@
+from typing import List, Union
 import torch
+import torch.nn.functional as F
 from transformers import PreTrainedModel, BertTokenizer
 from transformers.utils import is_remote_url, download_url
 from pathlib import Path
@@ -49,6 +51,9 @@ def get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj,gcn_config:VGCNConfig):
         adj = gcn_vocab_adj_list[i]
         adj = normalize_adj(adj)
         norm_gcn_vocab_adj_list.append(sparse_scipy2torch(adj.tocoo()))
+
+    for t in norm_gcn_vocab_adj_list:
+        t.requires_grad = False
 
     del gcn_vocab_adj_list
 
@@ -66,7 +71,8 @@ class VCGNModelForTextClassification(PreTrainedModel):
         self.remove_stop_words = False
         self.tokenizer = None
         self.norm_gcn_vocab_adj_list = None
-
+        self.gcn_vocab_size = config.vocab_size
+
 
         self.load_adj_matrix(config.gcn_adj_matrix)
 
@@ -80,26 +86,97 @@ class VCGNModelForTextClassification(PreTrainedModel):
         )
 
     def load_adj_matrix(self, adj_matrix):
+        filename = None
         if Path(adj_matrix).is_file():
-            #load file
-            gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(open(adj_matrix, 'rb'))
-        if is_remote_url(adj_matrix):
-            resolved_archive_file = download_url(adj_matrix)
+            filename = Path(adj_matrix)
+            #load file
+        elif (Path(__file__).parent / Path(adj_matrix)).is_file():
+            filename = Path(__file__).parent / Path(adj_matrix)
+        elif is_remote_url(adj_matrix):
+            filename = download_url(adj_matrix)
+
+
+        gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(open(filename, 'rb'))
+
 
         self.pre_trained_model_name = adj_config['bert_model']
         self.remove_stop_words = adj_config['remove_stop_words']
         self.tokenizer = BertTokenizer.from_pretrained(self.pre_trained_model_name)
-        self.norm_gcn_vocab_adj_list = get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
-
+        self.norm_gcn_vocab_adj_list = get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
+
+    def _prep_batch(self, batch: torch.Tensor):
+
+        vocab_size = self.tokenizer.vocab_size
+
+        batch_gcn_swop_eye = F.one_hot(batch, vocab_size).float().to(self.device)  # shape (batch_size, seq_len, vocab_size)
+        batch_gcn_swop_eye = batch_gcn_swop_eye.transpose(1,2)  # shape (batch_size, vocab_size, seq_len)
+        # set all [PAD] tokens to 0
+        batch_gcn_swop_eye[:, self.tokenizer.pad_token_id, :] = 0
+        batch_gcn_swop_eye[:, self.tokenizer.cls_token_id, :] = 0
+        batch_gcn_swop_eye[:, self.tokenizer.sep_token_id, :] = 0
+
+        batch_gcn_swop_eye = F.pad(batch_gcn_swop_eye,(0,self.config.gcn_embedding_dim,0,0,0,0),value=0)
+
+        batch = F.pad(batch, (0, self.config.gcn_embedding_dim), 'constant', 0)
+
+        #fill gcn tokens with [SEP]
+        mask = torch.zeros(batch.shape[0], batch.shape[1] + 1, dtype=batch.dtype, device=self.device)
+        mask2 = torch.zeros(batch.shape[0], batch.shape[1] + 1, dtype=batch.dtype, device=self.device)
+
+        pos_start = (batch==self.tokenizer.pad_token_id).int().argmax(1)
+
+        mask[(torch.arange(batch.shape[0]), pos_start)] = 1
+        mask2[(torch.arange(batch.shape[0]), pos_start+self.config.gcn_embedding_dim)] = 1
+
+        mask = mask.cumsum(1)[:, :-1].bool()
+        mask2 = mask2.cumsum(1)[:, :-1].bool()
 
-    def forward(self, tensor, labels=None):
-        logits = self.model(tensor)
+        mask = mask & ~mask2
+
+        batch.masked_fill_(mask, self.tokenizer.sep_token_id)
+
+        return batch, batch_gcn_swop_eye
+
+    def text_to_batch(self, text: Union[List[str], str]):
+        if isinstance(text, str):
+            text = [text]
+        encoded = self.tokenizer.batch_encode_plus(text, padding=True, truncation=True, return_tensors='pt', max_length=self.config.max_seq_len-self.config.gcn_embedding_dim)
+        return encoded['input_ids'].to(self.device)
+
+    def forward(self, input: Union[torch.Tensor, List[str], str], labels=None):
+
+        if not isinstance(input, torch.Tensor):
+            input = self.text_to_batch(input)
+
+        input, batch_gcn_swop_eye = self._prep_batch(input)
+
+        segment_ids = torch.zeros_like(input).int().to(self.device)
+        input_mask = (input>0).int().to(self.device)
+
+
+        logits = self.model(batch_gcn_swop_eye, input, segment_ids, input_mask)
         if labels is not None:
             loss = torch.nn.cross_entropy(logits, labels)
             return {"loss": loss, "logits": logits}
         return {"logits": logits}
 
-
+    def predict(self, text: Union[List[str], str], as_dict=True):
+        with torch.no_grad():
+            logits = self.forward(text)['logits']
+        if as_dict:
+            label_id = torch.argmax(logits, dim=1).cpu().numpy()
+            label = [self.config.id2label[l] for l in label_id]
+            return {
+                "logits": logits,
+                "label_id": label_id,
+                "label": label,
            }
+        else:
+            return torch.argmax(logits, dim=1).cpu().numpy()
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
 
 import torch
 import torch.nn as nn
@@ -130,7 +207,13 @@ class VocabGraphConvolution(nn.Module):
     """
     def __init__(self,adj_matrix,voc_dim, num_adj, hid_dim, out_dim, dropout_rate=0.2):
         super(VocabGraphConvolution, self).__init__()
-        self.adj_matrix=adj_matrix
+        if type(adj_matrix) is not list:
+            self.adj_matrix=adj_matrix
+        else:
+            self.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
+            for p in self.adj_matrix:
+                p.requires_grad=False
+
         self.voc_dim=voc_dim
         self.num_adj=num_adj
         self.hid_dim=hid_dim
@@ -147,7 +230,7 @@ class VocabGraphConvolution(nn.Module):
 
     def reset_parameters(self):
         for n,p in self.named_parameters():
-            if n.startswith('W') or n.startswith('a') or n in ('W','a','dense'):
+            if n.startswith('W'):
                 init.kaiming_uniform_(p, a=math.sqrt(5))
 
     def forward(self, X_dv, add_linear_mapping_term=False):
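With this commit, forward accepts raw strings as well as token tensors, and the new predict method wraps it for inference. A minimal usage sketch, assuming the class is exposed through Transformers' trust_remote_code mechanism; the repo id placeholder and the example texts are illustrative, only the forward/predict signatures come from the diff:

```python
import torch
from transformers import AutoModelForSequenceClassification

# "<this-repo-id>" is a placeholder for the Hub id of this repository.
# trust_remote_code lets Transformers import modeling_vcgn.py; whether the
# repo registers the class via auto_map is an assumption, not shown here.
model = AutoModelForSequenceClassification.from_pretrained(
    "<this-repo-id>", trust_remote_code=True
)
model.eval()

texts = ["Exemplu de comentariu.", "Another comment to classify."]

out = model.predict(texts)   # as_dict=True by default
print(out["label_id"])       # argmax class ids (numpy array)
print(out["label"])          # ids mapped through config.id2label

with torch.no_grad():
    logits = model(texts)["logits"]  # forward() also accepts raw strings
```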
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5b23581bedc6271217c0910a5676cfbb76a36b8b707a8f8f4171986cc6e5d8dd
-size 479695719
+oid sha256:2dd4760540bf1667e77b45ab271e0a87376a97ecb0ea7ab669391e45a5606820
+size 481615461