bourdoiscatie committed on
Commit
9a58121
·
verified ·
1 Parent(s): 637332d

Add QA head

Browse files
Files changed (1) hide show
  1. custom_heads_flash_t5.py +143 -11
custom_heads_flash_t5.py CHANGED
@@ -1,3 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -220,19 +264,114 @@ class FlashT5ForSequenceClassification(FlashT5PreTrainedModel):
220
  )
221
 
222
 
223
-
224
- ################## Seq2Seq head ##################
225
  class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
226
  _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
227
 
228
  def __init__(self, config: FlashT5Config):
229
  super().__init__(config)
230
  self.transformer = FlashT5EncoderModel(config)
 
 
231
  self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
232
 
233
  # Initialize weights and apply final processing
234
  self.post_init()
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  self.model_parallel = False
237
 
238
  def forward(
@@ -249,12 +388,9 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
249
  ) -> Union[Tuple, QuestionAnsweringModelOutput]:
250
  r"""
251
  Returns:
252
-
253
  Example:
254
-
255
  ```python
256
  >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
257
-
258
  >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
259
  >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
260
  >>> input_ids = tokenizer(
@@ -265,15 +401,11 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
265
  >>> end_logits = outputs.end_logits
266
  ```"""
267
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
268
-
269
- outputs = self.transformer(
270
  input_ids,
271
  attention_mask=attention_mask,
272
- head_mask=head_mask,
273
  inputs_embeds=inputs_embeds,
274
- output_attentions=output_attentions,
275
- output_hidden_states=output_hidden_states,
276
- return_dict=return_dict,
277
  )
278
  sequence_output = outputs[0]
279
 
 
1
+
2
+ Hugging Face's logo Hugging Face
3
+
4
+ Models
5
+ Datasets
6
+ Spaces
7
+ Posts
8
+ Docs
9
+ Pricing
10
+
11
+ CATIE-AQ
12
+ /
13
+ FAT5-small-UL2-fr
14
+ private
15
+ Feature Extraction
16
+ Transformers
17
+ Safetensors
18
+ French
19
+ flash_t5
20
+ flash-attention
21
+ UL2
22
+ FAT5
23
+ custom_code
24
+ Carbon Emissions
25
+ 7 papers
26
+ Model card
27
+ Files and versions
28
+ Community
29
+ Settings
30
+ FAT5-small-UL2-fr
31
+ / custom_heads_flash_t5.py
32
+ bourdoiscatie's picture
33
+ bourdoiscatie
34
+ Update custom_heads_flash_t5.py
35
+ 3477ecb
36
+ verified
37
+ 8 days ago
38
+ raw
39
+ history
40
+ blame
41
+ edit
42
+ delete
43
+ No virus
44
+ 16.9 kB
45
  import torch
46
  import torch.nn as nn
47
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
264
  )
265
 
266
 
 
 
267
  class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
268
  _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
269
 
270
  def __init__(self, config: FlashT5Config):
271
  super().__init__(config)
272
  self.transformer = FlashT5EncoderModel(config)
273
+
274
+ self.num_labels = config.num_labels
275
  self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
276
 
277
  # Initialize weights and apply final processing
278
  self.post_init()
279
 
280
+ # Model parallel
281
+ self.model_parallel = False
282
+
283
+ def forward(
284
+ self,
285
+ input_ids: Optional[torch.LongTensor] = None,
286
+ attention_mask: Optional[torch.FloatTensor] = None,
287
+ head_mask: Optional[torch.FloatTensor] = None,
288
+ inputs_embeds: Optional[torch.FloatTensor] = None,
289
+ start_positions: Optional[torch.Tensor] = None,
290
+ end_positions: Optional[torch.Tensor] = None,
291
+ output_attentions: Optional[bool] = None,
292
+ output_hidden_states: Optional[bool] = None,
293
+ return_dict: Optional[bool] = None,
294
+ ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
295
+ r"""
296
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
297
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
298
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
299
+ are not taken into account for computing the loss.
300
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
301
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
302
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
303
+ are not taken into account for computing the loss.
304
+ Returns:
305
+ """
306
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
307
+
308
+ encoder_outputs = self.transformer(
309
+ input_ids=input_ids,
310
+ attention_mask=attention_mask,
311
+ inputs_embeds=inputs_embeds,
312
+ head_mask=head_mask,
313
+ output_attentions=output_attentions,
314
+ output_hidden_states=output_hidden_states,
315
+ return_dict=return_dict,
316
+ )
317
+
318
+ sequence_output = encoder_outputs[0]
319
+
320
+ logits = self.qa_outputs(sequence_output)
321
+ start_logits, end_logits = logits.split(1, dim=-1)
322
+ start_logits = start_logits.squeeze(-1).contiguous()
323
+ end_logits = end_logits.squeeze(-1).contiguous()
324
+
325
+ total_loss = None
326
+ if start_positions is not None and end_positions is not None:
327
+ # If we are on multi-GPU, split add a dimension
328
+ if len(start_positions.size()) > 1:
329
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
330
+ if len(end_positions.size()) > 1:
331
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
332
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
333
+ ignored_index = start_logits.size(1)
334
+ start_positions = start_positions.clamp(0, ignored_index)
335
+ end_positions = end_positions.clamp(0, ignored_index)
336
+
337
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
338
+ start_loss = loss_fct(start_logits, start_positions)
339
+ end_loss = loss_fct(end_logits, end_positions)
340
+ total_loss = (start_loss + end_loss) / 2
341
+
342
+ if not return_dict:
343
+ output = (start_logits, end_logits) + encoder_outputs[1:]
344
+ return ((total_loss,) + output) if total_loss is not None else output
345
+
346
+ return QuestionAnsweringModelOutput(
347
+ loss=total_loss,
348
+ start_logits=start_logits,
349
+ end_logits=end_logits,
350
+ hidden_states=encoder_outputs.hidden_states,
351
+ attentions=encoder_outputs.attentions,
352
+ )
353
+
354
+
355
+
356
+ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
357
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
358
+
359
+ def __init__(self, config: FlashT5Config):
360
+ super().__init__(config)
361
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
362
+
363
+ encoder_config = copy.deepcopy(config)
364
+ encoder_config.is_decoder = False
365
+ encoder_config.is_encoder_decoder = False
366
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
367
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
368
+
369
+ # Initialize weights and apply final processing
370
+ self.post_init()
371
+
372
+ self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
373
+ self.qa_outputs.bias.data.zero_()
374
+
375
  self.model_parallel = False
376
 
377
  def forward(
 
388
  ) -> Union[Tuple, QuestionAnsweringModelOutput]:
389
  r"""
390
  Returns:
 
391
  Example:
 
392
  ```python
393
  >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
 
394
  >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
395
  >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
396
  >>> input_ids = tokenizer(
 
401
  >>> end_logits = outputs.end_logits
402
  ```"""
403
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
404
+
405
+ outputs = self.encoder(
406
  input_ids,
407
  attention_mask=attention_mask,
 
408
  inputs_embeds=inputs_embeds,
 
 
 
409
  )
410
  sequence_output = outputs[0]
411