meghanaraok committed
Commit c142205
1 parent: 3f55ad4

Delete HiLATmain/models/modeling - Copy1.py

Files changed (1)
  1. HiLATmain/models/modeling - Copy1.py +0 -337
HiLATmain/models/modeling - Copy1.py DELETED
@@ -1,337 +0,0 @@
- import collections
- import logging
-
- import torch
- from torch.nn import BCEWithLogitsLoss, Dropout, Linear
- from transformers import AutoModel, XLNetModel, LongformerModel, LongformerConfig
- from transformers.models.longformer.modeling_longformer import LongformerEncoder, LongformerClassificationHead, LongformerLayer
-
- from hilat.models.utils import initial_code_title_vectors
-
- logger = logging.getLogger("lwat")
-
-
- class CodingModelConfig:
-     def __init__(self,
-                  transformer_model_name_or_path,
-                  transformer_tokenizer_name,
-                  transformer_layer_update_strategy,
-                  num_chunks,
-                  max_seq_length,
-                  dropout,
-                  dropout_att,
-                  d_model,
-                  label_dictionary,
-                  num_labels,
-                  use_code_representation,
-                  code_max_seq_length,
-                  code_batch_size,
-                  multi_head_att,
-                  chunk_att,
-                  linear_init_mean,
-                  linear_init_std,
-                  document_pooling_strategy,
-                  multi_head_chunk_attention):
-         super(CodingModelConfig, self).__init__()
-         self.transformer_model_name_or_path = transformer_model_name_or_path
-         self.transformer_tokenizer_name = transformer_tokenizer_name
-         self.transformer_layer_update_strategy = transformer_layer_update_strategy
-         self.num_chunks = num_chunks
-         self.max_seq_length = max_seq_length
-         self.dropout = dropout
-         self.dropout_att = dropout_att
-         self.d_model = d_model
-         # label_dictionary is a dataframe with columns: icd9_code, long_title
-         self.label_dictionary = label_dictionary
-         self.num_labels = num_labels
-         self.use_code_representation = use_code_representation
-         self.code_max_seq_length = code_max_seq_length
-         self.code_batch_size = code_batch_size
-         self.multi_head_att = multi_head_att
-         self.chunk_att = chunk_att
-         self.linear_init_mean = linear_init_mean
-         self.linear_init_std = linear_init_std
-         self.document_pooling_strategy = document_pooling_strategy
-         self.multi_head_chunk_attention = multi_head_chunk_attention
-
-
- class LableWiseAttentionLayer(torch.nn.Module):
-     def __init__(self, coding_model_config, args):
-         super(LableWiseAttentionLayer, self).__init__()
-
-         self.config = coding_model_config
-         self.args = args
-
-         # layers
-         self.l1_linear = torch.nn.Linear(self.config.d_model,
-                                          self.config.d_model, bias=False)
-         self.tanh = torch.nn.Tanh()
-         self.l2_linear = torch.nn.Linear(self.config.d_model, self.config.num_labels, bias=False)
-         self.softmax = torch.nn.Softmax(dim=1)
-
-         # Mean-pooled last hidden states of the code titles from the transformer model are used as the initial code vectors
-         self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
-
-     def _init_linear_weights(self, mean, std):
-         # initialize the l1 weights
-         torch.nn.init.normal_(self.l1_linear.weight, mean, std)
-         if self.l1_linear.bias is not None:
-             self.l1_linear.bias.data.fill_(0)
-         # initialize the l2 weights
-         if self.config.use_code_representation:
-             code_vectors = initial_code_title_vectors(self.config.label_dictionary,
-                                                       self.config.transformer_model_name_or_path,
-                                                       self.config.transformer_tokenizer_name
-                                                       if self.config.transformer_tokenizer_name
-                                                       else self.config.transformer_model_name_or_path,
-                                                       self.config.code_max_seq_length,
-                                                       self.config.code_batch_size,
-                                                       self.config.d_model,
-                                                       self.args.device)
-
-             self.l2_linear.weight = torch.nn.Parameter(code_vectors, requires_grad=True)
-         torch.nn.init.normal_(self.l2_linear.weight, mean, std)
-         if self.l2_linear.bias is not None:
-             self.l2_linear.bias.data.fill_(0)
-
-     def forward(self, x):
-         # input: (batch_size, max_seq_length, transformer_hidden_size)
-         # output: (batch_size, num_labels, transformer_hidden_size)
-         # Z = tanh(WH)
-         l1_output = self.tanh(self.l1_linear(x))
-         # softmax(UZ)
-         # l2_linear output shape: (batch_size, max_seq_length, num_labels)
-         # attention_weight shape: (batch_size, num_labels, max_seq_length)
-         attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
-         # attention_output shape: (batch_size, num_labels, transformer_hidden_size)
-         attention_output = torch.matmul(attention_weight, x)
-
-         return attention_output, attention_weight
-
- class ChunkAttentionLayer(torch.nn.Module):
-     def __init__(self, coding_model_config, args):
-         super(ChunkAttentionLayer, self).__init__()
-
-         self.config = coding_model_config
-         self.args = args
-
-         # layers
-         self.l1_linear = torch.nn.Linear(self.config.d_model,
-                                          self.config.d_model, bias=False)
-         self.tanh = torch.nn.Tanh()
-         self.l2_linear = torch.nn.Linear(self.config.d_model, 1, bias=False)
-         self.softmax = torch.nn.Softmax(dim=1)
-
-         self._init_linear_weights(mean=self.config.linear_init_mean, std=self.config.linear_init_std)
-
-     def _init_linear_weights(self, mean, std):
-         # initialize the l1
-         torch.nn.init.normal_(self.l1_linear.weight, mean, std)
-         if self.l1_linear.bias is not None:
-             self.l1_linear.bias.data.fill_(0)
-         # initialize the l2
-         torch.nn.init.normal_(self.l2_linear.weight, mean, std)
-         if self.l2_linear.bias is not None:
-             self.l2_linear.bias.data.fill_(0)
-
-     def forward(self, x):
-         # input: (batch_size, num_chunks, transformer_hidden_size)
-         # output: (batch_size, 1, transformer_hidden_size)
-         # Z = tanh(WH)
-         l1_output = self.tanh(self.l1_linear(x))
-         # softmax(UZ)
-         # l2_linear output shape: (batch_size, num_chunks, 1)
-         # attention_weight shape: (batch_size, 1, num_chunks)
-         attention_weight = self.softmax(self.l2_linear(l1_output)).transpose(1, 2)
-         # attention_output shape: (batch_size, 1, transformer_hidden_size)
-         attention_output = torch.matmul(attention_weight, x)
-         return attention_output, attention_weight
-
-
- class CodingModel(torch.nn.Module):
-     def __init__(self, coding_model_config, args):
-         super(CodingModel, self).__init__()
-         self.coding_model_config = coding_model_config
-         self.args = args
-         # layers
-         self.transformer_layer = AutoModel.from_pretrained(self.coding_model_config.transformer_model_name_or_path)
-         if isinstance(self.transformer_layer, XLNetModel):
-             self.transformer_layer.config.use_mems_eval = False
-         self.dropout = Dropout(p=self.coding_model_config.dropout)
-
-         if self.coding_model_config.multi_head_att:
-             # initialize one label-wise attention head per chunk
-             self.label_wise_attention_layer = torch.nn.ModuleList(
-                 [LableWiseAttentionLayer(coding_model_config, args)
-                  for _ in range(self.coding_model_config.num_chunks)])
-         else:
-             self.label_wise_attention_layer = LableWiseAttentionLayer(coding_model_config, args)
-         self.dropout_att = Dropout(p=self.coding_model_config.dropout_att)
-
-         # initialize chunk attention
-         if self.coding_model_config.chunk_att:
-             if self.coding_model_config.multi_head_chunk_attention:
-                 self.chunk_attention_layer = torch.nn.ModuleList([ChunkAttentionLayer(coding_model_config, args)
-                                                                   for _ in range(self.coding_model_config.num_labels)])
-             else:
-                 self.chunk_attention_layer = ChunkAttentionLayer(coding_model_config, args)
-
-             self.classifier_layer = Linear(self.coding_model_config.d_model,
-                                            self.coding_model_config.num_labels)
-         else:
-             if self.coding_model_config.document_pooling_strategy == "flat":
-                 self.classifier_layer = Linear(self.coding_model_config.num_chunks * self.coding_model_config.d_model,
-                                                self.coding_model_config.num_labels)
-             else:  # max or mean pooling
-                 self.classifier_layer = Linear(self.coding_model_config.d_model,
-                                                self.coding_model_config.num_labels)
-         self.sigmoid = torch.nn.Sigmoid()
-
-         if self.coding_model_config.transformer_layer_update_strategy == "no":
-             self.freeze_all_transformer_layers()
-         elif self.coding_model_config.transformer_layer_update_strategy == "last":
-             self.freeze_all_transformer_layers()
-             self.unfreeze_transformer_last_layers()
-
-         # initialize the weights of the classifier
-         self._init_linear_weights(mean=self.coding_model_config.linear_init_mean, std=self.coding_model_config.linear_init_std)
-
-     def _init_linear_weights(self, mean, std):
-         torch.nn.init.normal_(self.classifier_layer.weight, mean, std)
-
-     def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, targets=None):
-         # input ids/mask/type_ids shape: (batch_size, num_chunks, max_seq_length)
-         # labels shape: (batch_size, num_labels)
-         transformer_output = []
-
-         # pass the chunks one by one through the transformer layer, in batches.
-         # input (batch_size, sequence_length)
-         for i in range(self.coding_model_config.num_chunks):
-             l1_output = self.transformer_layer(input_ids=input_ids[:, i, :],
-                                                attention_mask=attention_mask[:, i, :],
-                                                token_type_ids=token_type_ids[:, i, :])
-             # output hidden state shape: (batch_size, sequence_length, hidden_size)
-             transformer_output.append(l1_output[0])
-
-         # transpose the chunk and batch-size dimensions back
-         transformer_output = torch.stack(transformer_output)
-         transformer_output = transformer_output.transpose(0, 1)
-         # dropout on the transformer output
-         l2_dropout = self.dropout(transformer_output)
-
-         # Label-wise attention layers
-         # output: (batch_size, num_chunks, num_labels, hidden_size)
-         attention_output = []
-         attention_weights = []
-
-         for i in range(self.coding_model_config.num_chunks):
-             # input: (batch_size, max_seq_length, transformer_hidden_size)
-             if self.coding_model_config.multi_head_att:
-                 attention_layer = self.label_wise_attention_layer[i]
-             else:
-                 attention_layer = self.label_wise_attention_layer
-             l3_attention, attention_weight = attention_layer(l2_dropout[:, i, :])
-             # l3_attention shape: (batch_size, num_labels, hidden_size)
-             # attention_weight: (batch_size, num_labels, max_seq_length)
-             attention_output.append(l3_attention)
-             attention_weights.append(attention_weight)
-
-         attention_output = torch.stack(attention_output)
-         attention_output = attention_output.transpose(0, 1)
-         attention_weights = torch.stack(attention_weights)
-         attention_weights = attention_weights.transpose(0, 1)
-
-         config = LongformerConfig.from_pretrained("allenai/longformer-base-4096")
-         config.num_labels = 5
-         config.num_hidden_layers = 1
-         longformer_layer = LongformerLayer(config)
-         l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], l2_dropout.shape[1] * l2_dropout.shape[2], l2_dropout.shape[3])
-         attention_mask = attention_mask.reshape(attention_mask.shape[0], attention_mask.shape[1] * attention_mask.shape[2])
-         is_index_masked = attention_mask < 0
-         output = longformer_layer(l2_dropout, attention_mask=attention_mask, output_attentions=True, is_index_masked=is_index_masked)
-         l3_dropout = self.dropout_att(output[0])
-         l3_dropout = l3_dropout.reshape(l3_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
-         self.softmax = torch.nn.Softmax(dim=1)
-         self.l2_linear = torch.nn.Linear(self.coding_model_config.d_model, self.coding_model_config.num_labels, bias=False)
-         attention_weight = self.softmax(self.l2_linear(l3_dropout)).transpose(1, 2)
-         attention_weight = attention_weight.reshape(attention_weight.shape[0], self.coding_model_config.num_labels, self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length)
-         # attention_weight = attention_weight.permute(0, 2, 1)
-         l2_dropout = l2_dropout.reshape(l2_dropout.shape[0], self.coding_model_config.num_chunks, self.coding_model_config.max_seq_length, self.coding_model_config.d_model)
-
-         attention_output = []
-
-         for i in range(self.coding_model_config.num_chunks):
-             l3_attention = torch.matmul(attention_weight[:, :, i], l2_dropout[:, i, :])
-             attention_output.append(l3_attention)
-
-         attention_output = torch.stack(attention_output)
-         l3_dropout = self.dropout_att(attention_output)
-         l3_dropout = l3_dropout.transpose(0, 1)
-
-
-         if self.coding_model_config.chunk_att:
-             # Chunk attention layers
-             # output: (batch_size, num_labels, hidden_size)
-             chunk_attention_output = []
-             chunk_attention_weights = []
-
-             for i in range(self.coding_model_config.num_labels):
-                 if self.coding_model_config.multi_head_chunk_attention:
-                     chunk_attention = self.chunk_attention_layer[i]
-                 else:
-                     chunk_attention = self.chunk_attention_layer
-                 l4_chunk_attention, l4_chunk_attention_weights = chunk_attention(l3_dropout[:, :, i])
-                 chunk_attention_output.append(l4_chunk_attention.squeeze())
-                 chunk_attention_weights.append(l4_chunk_attention_weights.squeeze())
-
-             chunk_attention_output = torch.stack(chunk_attention_output)
-             chunk_attention_output = chunk_attention_output.transpose(0, 1)
-             chunk_attention_weights = torch.stack(chunk_attention_weights)
-             chunk_attention_weights = chunk_attention_weights.transpose(0, 1)
-             # output shape: (batch_size, num_labels, hidden_size)
-             l4_dropout = self.dropout_att(chunk_attention_output)
-         else:
-             # output shape: (batch_size, num_labels, hidden_size*num_chunks)
-             l4_dropout = l3_dropout.transpose(1, 2)
-             if self.coding_model_config.document_pooling_strategy == "flat":
-                 # Flatten layer: concatenate the representations per label
-                 l4_dropout = torch.flatten(l4_dropout, start_dim=2)
-             elif self.coding_model_config.document_pooling_strategy == "max":
-                 l4_dropout = torch.amax(l4_dropout, 2)
-             elif self.coding_model_config.document_pooling_strategy == "mean":
-                 l4_dropout = torch.mean(l4_dropout, 2)
-             else:
-                 raise ValueError("Not supported pooling strategy")
-
-         # classifier layer
-         # each code has its own binary linear classifier
-         logits = self.classifier_layer.weight.mul(l4_dropout).sum(dim=2).add(self.classifier_layer.bias)
-
-         loss_fct = BCEWithLogitsLoss()
-         loss = loss_fct(logits, targets)
-
-         return {
-             "loss": loss,
-             "logits": logits,
-             "label_attention_weights": attention_weights,
-             "chunk_attention_weights": chunk_attention_weights if self.coding_model_config.chunk_att else []
-         }
-
-     def freeze_all_transformer_layers(self):
-         """
-         Freeze all layer weight parameters. They will not be updated during training.
-         """
-         for param in self.transformer_layer.parameters():
-             param.requires_grad = False
-
-     def unfreeze_all_transformer_layers(self):
-         """
-         Unfreeze all layer weight parameters. They will be updated during training.
-         """
-         for param in self.transformer_layer.parameters():
-             param.requires_grad = True
-
-     def unfreeze_transformer_last_layers(self):
-         for name, param in self.transformer_layer.named_parameters():
-             if "layer.11" in name or "pooler" in name:
-                 param.requires_grad = True