bourdoiscatie committed on
Commit 948f07b · verified · 1 parent: 7e059b8

Update custom_heads_flash_t5.py

Files changed (1)
  1. custom_heads_flash_t5.py +0 -95
custom_heads_flash_t5.py CHANGED
@@ -307,98 +307,3 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
         )
-
-
-
-class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
-
-    def __init__(self, config: FlashT5Config):
-        super().__init__(config)
-        self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-        encoder_config = copy.deepcopy(config)
-        encoder_config.is_decoder = False
-        encoder_config.is_encoder_decoder = False
-        self.encoder = FlashT5Stack(encoder_config, self.shared)
-        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-        self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
-        self.qa_outputs.bias.data.zero_()
-
-        self.model_parallel = False
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.LongTensor] = None,
-        end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
-        r"""
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
-        >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
-        >>> input_ids = tokenizer(
-        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
-        ... ).input_ids  # Batch size 1
-        >>> outputs = model(input_ids=input_ids)
-        >>> start_logits = outputs.start_logits
-        >>> end_logits = outputs.end_logits
-        ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.encoder(
-            input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-        )
-        sequence_output = outputs[0]
-
-        logits = self.qa_outputs(sequence_output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1).contiguous()
-        end_logits = end_logits.squeeze(-1).contiguous()
-
-        total_loss = None
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1).to(start_logits.device)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1).to(end_logits.device)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions = start_positions.clamp(0, ignored_index)
-            end_positions = end_positions.clamp(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
-
-        if not return_dict:
-            output = (start_logits, end_logits) + outputs[1:]
-            return ((total_loss,) + output) if total_loss is not None else output
-
-        return QuestionAnsweringModelOutput(
-            loss=total_loss,
-            start_logits=start_logits,
-            end_logits=end_logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
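For reference, a minimal sketch of the span-extraction step performed by the `FlashT5ForQuestionAnswering` head kept in this file: each token's hidden state is projected to two logits, split into start/end scores, out-of-range gold positions are clamped and ignored, and the loss is the average of the two cross-entropy terms. The tensor sizes below are illustrative assumptions, not values read from `FlashT5Config`.

```python
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

# Illustrative sizes only; the real values come from the model config.
batch_size, seq_len, hidden_size = 2, 16, 512

# Mirrors nn.Linear(config.hidden_size, config.num_labels) with num_labels = 2
# (one logit for the span start, one for the span end).
qa_outputs = nn.Linear(hidden_size, 2)

sequence_output = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for the encoder output
logits = qa_outputs(sequence_output)                             # (batch, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()             # (batch, seq_len)
end_logits = end_logits.squeeze(-1).contiguous()

# Dummy gold span positions; positions outside the sequence are clamped to
# ignored_index and excluded from the loss, as in the forward() above.
start_positions = torch.tensor([3, 5])
end_positions = torch.tensor([7, 9])
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss.item(), start_logits.argmax(-1), end_logits.argmax(-1))
```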