Markus28 commited on
Commit
4c68a4c
·
1 Parent(s): c35343d

try to simplify checkpointing

Browse files
Files changed (1) hide show
  1. modeling_bert.py +2 -247
modeling_bert.py CHANGED
@@ -329,9 +329,7 @@ class BertPreTrainedModel(nn.Module):
329
  """
330
  # Instantiate model.
331
  model = cls(config, *inputs, **kwargs)
332
- load_return = model.load_state_dict(
333
- remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
334
- )
335
  logger.info(load_return)
336
  return model
337
 
@@ -528,247 +526,4 @@ class BertForPreTraining(BertPreTrainedModel):
528
  loss=total_loss,
529
  prediction_logits=prediction_scores,
530
  seq_relationship_logits=seq_relationship_score,
531
- )
532
-
533
-
534
- def remap_state_dict(state_dict, config: PretrainedConfig):
535
- """
536
- Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
537
- """
538
-
539
- # LayerNorm
540
- def key_mapping_ln_gamma_beta(key):
541
- key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
542
- key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
543
- return key
544
-
545
- state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
546
-
547
- # Layers
548
- def key_mapping_layers(key):
549
- return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
550
-
551
- state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
552
-
553
- # LayerNorm
554
- def key_mapping_ln(key):
555
- key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
556
- key = re.sub(
557
- r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
558
- r"bert.encoder.layers.\1.norm1.\2",
559
- key,
560
- )
561
- key = re.sub(
562
- r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
563
- r"bert.encoder.layers.\1.norm2.\2",
564
- key,
565
- )
566
- key = re.sub(
567
- r"^cls.predictions.transform.LayerNorm.(weight|bias)",
568
- r"cls.predictions.transform.layer_norm.\1",
569
- key,
570
- )
571
- return key
572
-
573
- state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
574
-
575
- # MLP
576
- def key_mapping_mlp(key):
577
- key = re.sub(
578
- r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
579
- r"bert.encoder.layers.\1.mlp.fc1.\2",
580
- key,
581
- )
582
- key = re.sub(
583
- r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
584
- r"bert.encoder.layers.\1.mlp.fc2.\2",
585
- key,
586
- )
587
- return key
588
-
589
- state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
590
-
591
- # Attention
592
- last_layer_subset = getattr(config, "last_layer_subset", False)
593
- for d in range(config.num_hidden_layers):
594
- Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
595
- Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
596
- Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
597
- bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
598
- bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
599
- bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
600
- if not (last_layer_subset and d == config.num_hidden_layers - 1):
601
- state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
602
- [Wq, Wk, Wv], dim=0
603
- )
604
- state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
605
- else:
606
- state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
607
- state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
608
- state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
609
- state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
610
-
611
- def key_mapping_attn(key):
612
- return re.sub(
613
- r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
614
- r"bert.encoder.layers.\1.mixer.out_proj.\2",
615
- key,
616
- )
617
-
618
- state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
619
-
620
- def key_mapping_decoder_bias(key):
621
- return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
622
-
623
- state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
624
-
625
- # Word embedding
626
- pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
627
- if pad_vocab_size_multiple > 1:
628
- word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
629
- state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
630
- word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
631
- )
632
- decoder_weight = state_dict["cls.predictions.decoder.weight"]
633
- state_dict["cls.predictions.decoder.weight"] = F.pad(
634
- decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
635
- )
636
- # If the vocab was padded, we want to set the decoder bias for those padded indices to be
637
- # strongly negative (i.e. the decoder shouldn't predict those indices).
638
- # TD [2022-05-09]: I don't think it affects the MLPerf training.
639
- decoder_bias = state_dict["cls.predictions.decoder.bias"]
640
- state_dict["cls.predictions.decoder.bias"] = F.pad(
641
- decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
642
- )
643
-
644
- return state_dict
645
-
646
-
647
- def inv_remap_state_dict(state_dict, config: PretrainedConfig):
648
- """
649
- Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
650
-
651
- This function is meant to be the inverse of remap_state_dict.
652
- """
653
- # Word embedding
654
- pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
655
- if pad_vocab_size_multiple > 1:
656
- word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
657
- decoder_weight = state_dict["cls.predictions.decoder.weight"]
658
- decoder_bias = state_dict["cls.predictions.decoder.bias"]
659
- # unpad embeddings
660
- state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
661
- : config.orig_vocab_size, :
662
- ]
663
- state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
664
- state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
665
-
666
- for d in range(config.num_hidden_layers):
667
- last_layer_subset = getattr(config, "last_layer_subset", False)
668
- if not last_layer_subset or d != (config.num_hidden_layers - 1):
669
- Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
670
- Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
671
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
672
- : Wqkv_weights.shape[0] // 3, :
673
- ]
674
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
675
- Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
676
- ]
677
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
678
- 2 * Wqkv_weights.shape[0] // 3 :, :
679
- ]
680
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
681
- : Wqkv_biases.shape[0] // 3
682
- ]
683
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
684
- Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
685
- ]
686
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
687
- 2 * Wqkv_biases.shape[0] // 3 :
688
- ]
689
- else:
690
- Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
691
- Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
692
- Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
693
- Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
694
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
695
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
696
- : Wkv_weights.shape[0] // 2, :
697
- ]
698
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
699
- Wkv_weights.shape[0] // 2 :, :
700
- ]
701
- state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
702
- state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
703
- : Wkv_biases.shape[0] // 2
704
- ]
705
- state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
706
- Wkv_biases.shape[0] // 2 :
707
- ]
708
-
709
- def inv_key_mapping_ln(key):
710
- key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
711
- key = re.sub(
712
- r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
713
- r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
714
- key,
715
- )
716
- key = re.sub(
717
- r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
718
- r"bert.encoder.layers.\1.output.LayerNorm.\2",
719
- key,
720
- )
721
- key = re.sub(
722
- r"cls.predictions.transform.layer_norm.(weight|bias)",
723
- r"cls.predictions.transform.LayerNorm.\1",
724
- key,
725
- )
726
- return key
727
-
728
- def inv_key_mapping_ln_gamma_beta(key):
729
- key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
730
- key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
731
- return key
732
-
733
- def inv_key_mapping_layers(key):
734
- return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
735
-
736
- def inv_key_mapping_mlp(key):
737
- key = re.sub(
738
- r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
739
- r"bert.encoder.layer.\1.intermediate.dense.\2",
740
- key,
741
- )
742
- key = re.sub(
743
- r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
744
- r"bert.encoder.layer.\1.output.dense.\2",
745
- key,
746
- )
747
- return key
748
-
749
- def inv_key_mapping_attn(key):
750
- return re.sub(
751
- r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
752
- r"bert.encoder.layer.\1.attention.output.dense.\2",
753
- key,
754
- )
755
-
756
- def inv_key_mapping_decoder_bias(key):
757
- return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
758
-
759
- state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
760
- state_dict = OrderedDict(
761
- (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
762
- )
763
- state_dict = OrderedDict(
764
- (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
765
- )
766
- state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
767
- state_dict = OrderedDict(
768
- (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
769
- )
770
- state_dict = OrderedDict(
771
- (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
772
- )
773
-
774
- return state_dict
 
329
  """
330
  # Instantiate model.
331
  model = cls(config, *inputs, **kwargs)
332
+ load_return = model.load_state_dict(state_dict_from_pretrained(model_name), strict=False)
 
 
333
  logger.info(load_return)
334
  return model
335
 
 
526
  loss=total_loss,
527
  prediction_logits=prediction_scores,
528
  seq_relationship_logits=seq_relationship_score,
529
+ )