huxy912 committed on
Commit
4f1c13a
·
1 Parent(s): d2f45e6
README.md CHANGED
@@ -1,3 +1,88 @@
1
  ---
2
  license: apache-2.0
3
  ---
1
  ---
2
  license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - MoE
7
  ---
8
+ # LLaMA-MoE-v2-3.8B (2/8) SFT
9
+
10
+ [[💻 Code]](https://github.com/OpenSparseLLMs/LLaMA-MoE-v2) | [[📃 Technical Report]](https://arxiv.org/pdf/2411.15708)
11
+
12
+ LLaMA-MoE-v2 is a series of open-source Mixture-of-Experts (MoE) models based on [LLaMA3](https://github.com/facebookresearch/llama).
13
+ We build LLaMA-MoE-v2 in two steps:
14
+ 1. **Partition** LLaMA's FFN layers or attention layers into sparse experts and insert a top-K gate for each layer of experts.
15
+ 2. **Fine-tune** the constructed MoE models on open-source data using two-stage supervised fine-tuning.
16
+
17
+
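The top-K gating mentioned in step 1 can be illustrated with a minimal pure-Python sketch. This is not the repository's implementation: the names `top_k_gate` and `moe_forward` are hypothetical, and renormalizing the kept weights so they sum to 1 is an assumption (some routers skip that step).

```python
import math


def top_k_gate(logits, k=2):
    """Pick the k highest-scoring experts for one token (illustrative only)."""
    # Softmax over all experts; subtract the max for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the k most probable experts.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # Assumption: renormalize the kept probabilities so they sum to 1.
    kept = sum(probs[i] for i in top)
    return top, [probs[i] / kept for i in top]


def moe_forward(x, experts, logits, k=2):
    """Weighted sum of the outputs of the selected experts only."""
    indices, weights = top_k_gate(logits, k)
    return sum(w * experts[i](x) for i, w in zip(indices, weights))
```

With 8 experts and `k=2`, only two expert functions are evaluated per token, which is the idea behind the "2/8" naming: the activated parameter count stays well below the total.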
18
+ | Model | \#Activated Experts | \#Experts | \#Activated Params | SFT Model |
19
+ | :-----------------------: | :-----------------: | :-------: | :----------------: | :------------------------------------------------------------------------: |
20
+ | **LLaMA-MLP-MoE (2/8)** | 2 | 8 | 3.8B | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v2-3_8B-2_8-sft) |
21
+ | **LLaMA-MLP-MoE (1+1/7)** | 2 | 8 | 3.8B | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v2-3_8B-residual-sft) |
22
+
23
+
24
+ ## 🚀 QuickStart
25
+
26
+ ```python
27
+ # python>=3.10
28
+
29
+ import torch
30
+ from transformers import AutoTokenizer, AutoModelForCausalLM
31
+
32
+ model_dir = "llama-moe/LLaMA-MoE-v2-3_8B-2_8-sft"
33
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
34
+ model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True)
35
+ model.eval()
36
+ model.cuda()
37
+
38
+ input_text = "Could you recommend me some mystery novels?"
39
+ input_text = f"<|start_header_id|>user<|end_header_id|>\n\n{input_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
40
+ inputs = tokenizer(input_text, return_tensors="pt")
41
+ input_ids = inputs["input_ids"].cuda()
42
+
43
+ pred = model.generate(input_ids, max_length=200, temperature=1.0, do_sample=True, use_cache=True)
44
+ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
45
+ """
46
+ I'd be delighted to recommend some mystery novels to you! Here are a few suggestions across various sub-genres:
47
+
48
+ **Classic Whodunit**
49
+
50
+ 1. "And Then There Were None" by Agatha Christie - A timeless tale of ten strangers who are invited to an isolated island, only to be killed off one by one.
51
+ 2. "The Murder on the Orient Express" by Agatha Christie - A classic whodunit set on a luxurious train traveling from Istanbul to Paris, where a famous author goes missing.
52
+ 3. "The Devil in the White City" by Erik Larson - A non-fiction book that combines historical events with a mystery, exploring the 1893 World's Columbian Exposition in Chicago and the serial killer H.H. Holmes.
53
+
54
+ **Modern Whodunits**
55
+
56
+ 1. "Gone Girl" by Gillian Flynn - A twisty, psychological thriller about a couple whose seemingly perfect ...
57
+ """
58
+ ```
59
+
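The prompt string in the QuickStart is assembled by hand. A small helper can build the same string from a list of turns. This is a sketch under assumptions: `build_prompt` is a hypothetical name, and using the template for more than one turn is an assumption, since the README only shows a single user turn.

```python
def build_prompt(messages):
    """Format (role, text) pairs using the header/eot tokens from the QuickStart."""
    parts = []
    for role, text in messages:
        parts.append(
            f"<|start_header_id|>{role}<|end_header_id|>\n\n{text}<|eot_id|>"
        )
    # Open an assistant turn so that generation continues from there.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


prompt = build_prompt([("user", "Could you recommend me some mystery novels?")])
```

For a single user message, `prompt` is identical to the `input_text` string passed to the tokenizer in the QuickStart.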
60
+ ## 📊 Performance
61
+
62
+ | Model | #Training Tokens | MMLU(5) | GSM8k(8) | HumanEval(pass@10) | IFEval | BoolQ(32) | SciQ | PIQA | ARC-c(25) | TruthfulQA | HellaSwag(10) |
63
+ |:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
64
+ | [LLaMA3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 15T | 67.2 | 76.5 | 71.4 | 76.5 | 83.0 | 93.2 | 78.5 | 61.9 | 51.7 | 78.8 |
65
+ | [INCITE-3B](https://huggingface.co/togethercomputer/RedPajama-INCITE-Instruct-3B-v1) | 1T | 25.1 | 2.1 | 6.92 | 30.1 | 66.5 | 94.7 | 74.4 | 40.2 | 36.4 | 65.6 |
66
+ | [Sheared-LLaMA-2.7B](https://huggingface.co/princeton-nlp/Sheared-LLaMA-2.7B-ShareGPT) | 50B | 28.2 | 1.9 | 3.2 | 28.8 | 67.6 | 75.8 | 41.1 | 47.6 | 71.2 | 39.0 |
67
+ | [Gemma-2-2b](https://huggingface.co/google/gemma-2-2b-it) | 2T | 53.0 | 26.3 | 46.1 | 34.9 | 72.3 | 75.8 | 67.5 | 52.6 | 50.8 | 69.0 |
68
+ | [Salamandra-2b](https://huggingface.co/BSC-LT/salamandra-2b-instruct) | 7.8T | 25.1 | 1.90 | 5.82 | 27.7 | 68.0 | 89.8 | 74.7 | 46.3 | 43.4 | 62.3 |
69
+ | [SmolLM2-1.7B](https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B-Instruct) | 11T | 50.4 | 38.5 | 39.1 | 29.0 | 68.2 | 84.3 | 76.0 | 53.2 | 39.9 | 72.6 |
70
+ | [OpenMoE-3B-9B](https://huggingface.co/OrionZheng/openmoe-8b-chat) | 1T | 26.5 | 1.36 | 1.01 | 31.2 | 61.7 | 68.4 | 65.7 | 33.3 | 40.5 | 56.5 |
71
+ | [LLaMA-MoE-3B-7B](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8-sft) | 200B | 28.2 | 4.62 | 12.0 | 28.1 | 68.1 | 88.8 | 77.9 | 44.0 | 33.3 | 73.2 |
72
+ | [OLMoE-1B-7B](https://huggingface.co/allenai/OLMoE-1B-7B-0924-SFT) | 1T | 53.8 | 40.9 | 40.5 | 35.5 | 80.9 | 94.9 | 80.1 | 55.6 | 43.3 | 79.6 |
73
+ | **MLP-MoE (8top2)** | **7B** | 40.6 | 53.1 | 53.5 | 32.7 | 74.6 | 90.6 | 69.3 | 42.8 | 45.6 | 59.0 |
74
+ | **MLP-MoE (8top2)** | **8.4B** | 41.0 | **59.6** | **57.1** | 31.7 | 74.5 | 90.2 | 69.5 | 43.3 | 46.9 | 58.1 |
75
+ | **MLP-MoE (1+7top1)** | **7B** | 42.7 | 55.0 | 51.2 | **36.0** | 76.9 | 88.8 | 67.9 | 40.2 | 46.9 | 53.7 |
76
+
77
+
78
+ ## 📃 Citation
79
+
80
+ ```bibtex
81
+ @misc{llama-moe-v2,
82
+ title={LLaMA-MoE v2: Exploring Sparsity of LLaMA from Perspective of Mixture-of-Experts with Post-Training},
83
+ author={Xiaoye Qu and Daize Dong and Xuyang Hu and Tong Zhu and Weigao Sun and Yu Cheng},
84
+ year={2024},
85
+ month={Nov},
86
+ url={https://arxiv.org/abs/2411.15708}
87
+ }
88
+ ```
config.json CHANGED
@@ -1,11 +1,12 @@
1
  {
2
- "_name_or_path": "/mnt/petrelfs/huxuyang/LLaMA-MoE-v2/outputs/v2_mixtral/moe-res-droppad-nosys-all/3653852/checkpoint-3600",
3
  "add_rescale_bias": false,
4
  "architectures": [
5
  "MixtralForCausalLM"
6
  ],
7
  "attention_bias": false,
8
  "attention_dropout": 0.0,
 
9
  "auto_map": {
10
  "AutoConfig": "configuration_mixtral.MixtralConfig",
11
  "AutoModel": "modeling_mixtral.MixtralModel",
 
1
  {
2
+ "_name_or_path": "/mnt/petrelfs/quxiaoye/models/sft-v2/moe8top2_onestage",
3
  "add_rescale_bias": false,
4
  "architectures": [
5
  "MixtralForCausalLM"
6
  ],
7
  "attention_bias": false,
8
  "attention_dropout": 0.0,
9
+ "attn_experts": null,
10
  "auto_map": {
11
  "AutoConfig": "configuration_mixtral.MixtralConfig",
12
  "AutoModel": "modeling_mixtral.MixtralModel",
configuration_mixtral.py CHANGED
@@ -170,6 +170,7 @@ class MixtralConfig(PretrainedConfig):
170
  num_moe_contract_layers: int = 0, # 🔍 the number of layers that are not converted into MoE at each side of the model
171
  use_attn_moe: bool = False, # 🔍
172
  top_k_attn: int = None, # 🔍
 
173
  scale_factor_attn: float = None, # 🔍
174
  use_layer_wise_balance: bool = False, # ✨ whether to fix the balance loss bug for Mixtral
175
  add_rescale_bias: bool = False, # 🔍 whether to add bias to the AttentionMoE `o_proj` & MoE `down_proj` for distribution alignment
@@ -208,6 +209,7 @@ class MixtralConfig(PretrainedConfig):
208
  self.use_attn_moe = use_attn_moe
209
  self.top_k_attn = top_k_attn
210
  self.scale_factor_attn = scale_factor_attn
 
211
 
212
  # ✨ For balance loss bugfix
213
  self.use_layer_wise_balance = use_layer_wise_balance
@@ -232,11 +234,15 @@ class MixtralConfig(PretrainedConfig):
232
  if hasattr(self, "_attn_implementation_internal"):
233
  if self._attn_implementation_internal is None:
234
  # `config.attn_implementation` should never be None, for backward compatibility.
235
- return "eager"
 
236
  else:
237
  return self._attn_implementation_internal
238
  else:
239
- return "eager"
 
 
 
240
 
241
  @_attn_implementation.setter
242
  def _attn_implementation(self, value):
 
170
  num_moe_contract_layers: int = 0, # 🔍 the number of layers that are not converted into MoE at each side of the model
171
  use_attn_moe: bool = False, # 🔍
172
  top_k_attn: int = None, # 🔍
173
+ attn_experts: int = None,
174
  scale_factor_attn: float = None, # 🔍
175
  use_layer_wise_balance: bool = False, # ✨ whether to fix the balance loss bug for Mixtral
176
  add_rescale_bias: bool = False, # 🔍 whether to add bias to the AttentionMoE `o_proj` & MoE `down_proj` for distribution alignment
 
209
  self.use_attn_moe = use_attn_moe
210
  self.top_k_attn = top_k_attn
211
  self.scale_factor_attn = scale_factor_attn
212
+ self.attn_experts = attn_experts
213
 
214
  # ✨ For balance loss bugfix
215
  self.use_layer_wise_balance = use_layer_wise_balance
 
234
  if hasattr(self, "_attn_implementation_internal"):
235
  if self._attn_implementation_internal is None:
236
  # `config.attn_implementation` should never be None, for backward compatibility.
237
+ return "flash_attention_2"
238
+ # return "eager"
239
  else:
240
  return self._attn_implementation_internal
241
  else:
242
+ return "flash_attention_2"
243
+ # return "eager"
244
+
245
+
246
 
247
  @_attn_implementation.setter
248
  def _attn_implementation(self, value):
model-00001-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e5bc31ca5dbdc23c38713b734d2654cfa413133981c35bdb633ea0d310f90cb8
3
  size 4977314560
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5c37f87fd8cb399be7701cafd53561b462b68451df8888e37e27a87afd9cd80
3
  size 4977314560
model-00002-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:762ce2834feae9ba9f238e4d927104291da3d73198328f31ebd722c6429cae17
3
  size 4985941976
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f24b0cd37967f622d52fb345d76dcd0f26d41959a70d1bf940b9ca28f9f2bef
3
  size 4985941976
model-00003-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:504bebbc7dabc69a76ef584204bdcbbcef1f31b9e61e39aa5c96690aa9461522
3
  size 4990070968
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:501a1bdc13d200b85e7f9be67535da141dd39092f31dd91d90f220469d67d395
3
  size 4990070968
model-00004-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:59ac862ad596b661d4101aba2d2555e37c3ab8617e4bb4107737cfe63e7aca40
3
  size 1109418960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:086b141ea5a5163e6fcbaf0db9fa5439476fa4eac16c8b1cb0f4de33d8ceebb7
3
  size 1109418960
modeling_mixtral.py CHANGED
@@ -49,8 +49,6 @@ from transformers.utils.import_utils import (
49
  is_torchdynamo_compiling,
50
  )
51
 
52
- from smoe.utils.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
53
-
54
  from .configuration_mixtral import MixtralConfig
55
 
56
  logger = logging.get_logger(__name__)
@@ -123,6 +121,338 @@ def is_flash_attn_available():
123
  return is_flash_attn_2_available()
124
 
125
 
126
  @dataclass
127
  class MoeCausalLMOutputWithPast(ModelOutput):
128
  """
@@ -270,7 +600,7 @@ def load_balancing_loss_func(
270
  Returns:
271
  The auxiliary loss.
272
  """
273
- if gate_logits is None:
274
  return 0
275
 
276
  # ✨ Here is the fix for balance loss in Mixtral.
@@ -812,16 +1142,20 @@ class MixtralAttentionMoE(MixtralAttention):
812
  )
813
 
814
  # 🔍
815
- self.gate = nn.Linear(self.hidden_size, self.num_key_value_heads, bias=False)
816
  self.softmax = nn.Softmax(dim=-1)
817
  self.top_k_attn = config.top_k_attn
 
818
  self.scale_factor_attn = config.scale_factor_attn
819
 
 
 
 
 
820
  # 🔍
821
- self.q_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_key_value_groups * self.head_dim, bias=False) for _ in range(self.num_key_value_heads)])
822
- self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.num_key_value_heads)])
823
- self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.num_key_value_heads)])
824
- self.o_proj = nn.ModuleList([nn.Linear(self.num_key_value_groups * self.head_dim, self.hidden_size, bias=config.add_rescale_bias) for _ in range(self.num_key_value_heads)]) # 🔍 (may add bias for rescaling)
825
 
826
  self.rotary_emb = MixtralRotaryEmbedding(
827
  self.head_dim,
@@ -847,6 +1181,7 @@ class MixtralAttentionMoE(MixtralAttention):
847
  raise TypeError(
848
  "`past_key_value` must be a `MoECache` instance for attention MoE!"
849
  )
 
850
  device = hidden_states.device
851
  dtype = hidden_states.dtype
852
  bsz, q_len, hidden_dim = hidden_states.size()
@@ -865,12 +1200,12 @@ class MixtralAttentionMoE(MixtralAttention):
865
 
866
  # One hot encode the selected experts to create an expert mask
867
  # this will be used to easily index which expert is going to be solicited
868
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_key_value_heads) # (bsz * q_len, top_k_attn, num_key_value_heads)
869
  expert_mask = expert_mask.permute(2, 1, 0) # (num_key_value_heads, top_k_attn, bsz * q_len)
870
 
871
  # Loop over all available experts in the model and perform the computation on each expert
872
  all_attn_weights = [] if output_attentions else None
873
- for expert_idx in range(self.num_key_value_heads):
874
  # expert_mask[expert_idx]: (top_k_attn, bsz * q_len)
875
  # idx: the topk position. (selected_num)
876
  # top_x: token index. (selected_num)
@@ -911,7 +1246,7 @@ class MixtralAttentionMoE(MixtralAttention):
911
  key_states = self.k_proj[expert_idx](current_state) # 🔍 specify expert
912
  value_states = self.v_proj[expert_idx](current_state) # 🔍 specify expert
913
 
914
- query_states = query_states.view(bsz, this_q_len, self.num_key_value_groups, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups
915
  key_states = key_states.view(bsz, this_q_len, 1, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_key_value_heads -> 1
916
  value_states = value_states.view(bsz, this_q_len, 1, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_key_value_heads -> 1
917
 
@@ -946,8 +1281,8 @@ class MixtralAttentionMoE(MixtralAttention):
946
 
947
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) # softmax temperature
948
 
949
- if attn_weights.size() != (bsz, self.num_key_value_groups, this_q_len, kv_seq_len): # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups
950
- raise ValueError(f"Attention weights should be of size {(bsz, self.num_key_value_groups, this_q_len, kv_seq_len)}, but is {attn_weights.size()}")
951
 
952
  # 🔍 create `current_attention_mask` with reduced `seq_len`
953
  # Notice that the `attention_mask` is passed intact during both training & generation, so we need to adjust the `top_x` by `past_key_values_length`.
@@ -961,11 +1296,12 @@ class MixtralAttentionMoE(MixtralAttention):
961
  temp_attention_mask = attention_mask[:, previous_seen_tokens_total:].flatten() # select along dimension 1 so that we get tokens in this iteration
962
  else:
963
  temp_attention_mask = attention_mask.flatten() # flatten the dim
964
- current_attention_mask[current_batch_ids, current_seq_ids] = temp_attention_mask[top_x] # assign masks sparsely
965
 
966
  else:
967
  current_attention_mask[current_batch_ids, current_seq_ids] = True # assign masks sparsely
968
 
 
969
  if past_key_value is not None: # 🔍 we need to update with cached attention mask
970
  current_attention_mask = past_key_value.update_attention_mask(current_attention_mask, self.layer_idx, expert_idx)
971
 
@@ -983,17 +1319,17 @@ class MixtralAttentionMoE(MixtralAttention):
983
  raise ValueError(f"Attention mask should be of size {(bsz, 1, this_q_len, kv_seq_len)}, but is {current_attention_mask.size()}")
984
 
985
  attn_weights = attn_weights + current_attention_mask # 🔍
986
-
987
  # upcast attention to fp32
988
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
989
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
990
  attn_output = torch.matmul(attn_weights, value_states)
991
 
992
- if attn_output.size() != (bsz, self.num_key_value_groups, this_q_len, self.head_dim): # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups
993
- raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}")
994
 
995
  attn_output = attn_output.transpose(1, 2).contiguous()
996
- attn_output = attn_output.reshape(bsz, this_q_len, self.num_key_value_groups * self.head_dim) # 🔍 q_len -> this_q_len, hidden_size -> num_key_value_groups * head_dim
997
  attn_output = self.o_proj[expert_idx](attn_output)
998
  # ---------------------------------------------- #
999
 
@@ -1026,27 +1362,16 @@ class MixtralAttentionMoE(MixtralAttention):
1026
  # init
1027
  attention_moe = MixtralAttentionMoE(config, layer_idx)
1028
 
 
1029
  # copy weights
1030
- num_key_value_groups = attention_moe.num_key_value_groups
1031
  head_dim = attention_moe.head_dim
1032
 
1033
- # attention
1034
- # q_proj: (self.hidden_size, self.num_heads * self.head_dim)
1035
- # k_proj: (self.hidden_size, self.num_key_value_heads * self.head_dim)
1036
- # v_proj: (self.hidden_size, self.num_key_value_heads * self.head_dim)
1037
- # o_proj: (self.num_heads * self.head_dim, self.hidden_size)
1038
-
1039
- # attention_moe
1040
- # q_proj: (self.hidden_size, self.num_key_value_groups * self.head_dim)
1041
- # k_proj: (self.hidden_size, self.head_dim)
1042
- # v_proj: (self.hidden_size, self.head_dim)
1043
- # o_proj: (self.num_key_value_groups * self.head_dim, self.hidden_size)
1044
-
1045
- for i in range(config.num_key_value_heads):
1046
  indices_q_o = [j for j in range(head_dim * num_key_value_groups * i, head_dim * num_key_value_groups * (i + 1))]
1047
- indices_k_v = [j for j in range(head_dim * i, head_dim * (i + 1))]
1048
 
1049
- # print(i, "indices_q_o", indices_q_o)
1050
  # print(i, "indices_k_v", indices_k_v)
1051
 
1052
  attention_moe.q_proj[i].weight.data = attention.q_proj.weight.data[indices_q_o].clone()
@@ -1204,6 +1529,7 @@ class MixtralFlashAttention2(MixtralAttention):
1204
  key_states = key_states.transpose(1, 2)
1205
  value_states = value_states.transpose(1, 2)
1206
 
 
1207
  attn_output = self._flash_attention_forward(
1208
  query_states,
1209
  key_states,
@@ -1341,7 +1667,6 @@ class MixtralFlashAttention2(MixtralAttention):
1341
  self, query_layer, key_layer, value_layer, attention_mask, query_length
1342
  ):
1343
  batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
1344
-
1345
  # On the first iteration we need to properly re-create the padding mask
1346
  # by slicing it on the proper place
1347
  if kv_seq_len != attention_mask.shape[-1]:
@@ -1389,6 +1714,517 @@ class MixtralFlashAttention2(MixtralAttention):
1389
  )
1390
 
1391
 
1392
  class MixtralBLockSparseTop2MLP(nn.Module):
1393
  def __init__(self, config: MixtralConfig, ffn_dim, add_rescale_bias=False): # 🔍
1394
  super().__init__()
@@ -1419,7 +2255,7 @@ MISTRAL_ATTENTION_CLASSES = {
1419
  # 🔍
1420
  MISTRAL_ATTENTION_MOE_CLASSES = {
1421
  "eager": MixtralAttentionMoE,
1422
- "flash_attention_2": None,
1423
  }
1424
 
1425
 
@@ -1698,13 +2534,14 @@ class MixtralDecoderLayer(nn.Module):
1698
  )
1699
  self.use_attn_moe = config.use_attn_moe
1700
 
 
 
 
 
 
 
 
1701
  if self.is_moe:
1702
- attn_class = (
1703
- MISTRAL_ATTENTION_MOE_CLASSES[config._attn_implementation]
1704
- if self.use_attn_moe
1705
- else MISTRAL_ATTENTION_CLASSES[config._attn_implementation]
1706
- )
1707
- self.self_attn = attn_class(config, layer_idx)
1708
  self.block_sparse_moe = MixtralSparseMoeBlock(config)
1709
  self.mlp_residual = (
1710
  MixtralBLockSparseTop2MLP(config, config.intermediate_size_residual)
@@ -1713,8 +2550,6 @@ class MixtralDecoderLayer(nn.Module):
1713
  )
1714
 
1715
  else:
1716
- attn_class = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]
1717
- self.self_attn = attn_class(config, layer_idx)
1718
  self.block_sparse_moe = MixtralBLockSparseTop2MLP(
1719
  config, config.intermediate_size * config.num_local_experts
1720
  )
@@ -1766,7 +2601,7 @@ class MixtralDecoderLayer(nn.Module):
1766
  hidden_states = self.input_layernorm(hidden_states)
1767
 
1768
  # 🔍 Self Attention
1769
- if self.is_moe and self.use_attn_moe:
1770
  (
1771
  hidden_states,
1772
  self_attn_weights,
@@ -1795,18 +2630,18 @@ class MixtralDecoderLayer(nn.Module):
1795
 
1796
  # Fully Connected
1797
  residual = hidden_states
1798
- hidden_states = self.post_attention_layernorm(hidden_states)
1799
 
1800
  # 🔍
1801
  if self.is_moe:
1802
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
1803
  else:
1804
- hidden_states = self.block_sparse_moe(hidden_states)
1805
  router_logits = None
1806
 
1807
  if self.mlp_residual is not None:
1808
- # hidden_states += self.mlp_residual(hidden_states) # 🔍
1809
- hidden_states = self.mlp_residual(hidden_states) + hidden_states # 🔍
1810
  hidden_states = residual + hidden_states
1811
 
1812
  outputs = (hidden_states,)
@@ -2223,7 +3058,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
2223
  if len(valid_attn_router_logits) > 0: # exist logits that is not None
2224
  attn_aux_loss = load_balancing_loss_func(
2225
  valid_attn_router_logits,
2226
- self.config.num_key_value_heads,
2227
  self.config.top_k_attn,
2228
  use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨
2229
  )
@@ -2632,7 +3467,8 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
2632
  if past is None:
2633
  if self.config.use_attn_moe: # 🔍
2634
  model_kwargs["past_key_values"] = MoECache(
2635
- self.config.num_key_value_heads
 
2636
  )
2637
  else: # 🔍
2638
  model_kwargs["past_key_values"] = DynamicCache()
 
49
  is_torchdynamo_compiling,
50
  )
51
 
 
 
52
  from .configuration_mixtral import MixtralConfig
53
 
54
  logger = logging.get_logger(__name__)
 
121
  return is_flash_attn_2_available()
122
 
123
 
124
+ @dataclass
125
+ class AttentionMaskConverter:
126
+ """
127
+ A utility attention mask class that allows one to:
128
+ - Create a causal 4d mask
129
+ - Create a causal 4d mask with slided window
130
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
131
+ key_value_length) that can be multiplied with attention scores
132
+
133
+ Examples:
134
+
135
+ ```python
136
+ >>> import torch
137
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
138
+
139
+ >>> converter = AttentionMaskConverter(True)
140
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
141
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
142
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
143
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
144
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
145
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
146
+ ```
147
+
148
+ Parameters:
149
+ is_causal (`bool`):
150
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
151
+
152
+ sliding_window (`int`, *optional*):
153
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
154
+ """
155
+
156
+ is_causal: bool
157
+ sliding_window: int
158
+
159
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
160
+ self.is_causal = is_causal
161
+ self.sliding_window = sliding_window
162
+
163
+ if self.sliding_window is not None and self.sliding_window <= 0:
164
+ raise ValueError(
165
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
166
+ )
167
+
168
+ def to_causal_4d(
169
+ self,
170
+ batch_size: int,
171
+ query_length: int,
172
+ key_value_length: int,
173
+ dtype: torch.dtype,
174
+ device: Union[torch.device, "str"] = "cpu",
175
+ ) -> Optional[torch.Tensor]:
176
+ """
177
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
178
+ bias to upper right hand triangular matrix (causal mask).
179
+ """
180
+ if not self.is_causal:
181
+ raise ValueError(
182
+ f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
183
+ )
184
+
185
+ # If shape is not cached, create a new causal mask and cache it
186
+ input_shape = (batch_size, query_length)
187
+ past_key_values_length = key_value_length - query_length
188
+
189
+ # create causal mask
190
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
191
+ causal_4d_mask = None
192
+ if input_shape[-1] > 1 or self.sliding_window is not None:
193
+ causal_4d_mask = self._make_causal_mask(
194
+ input_shape,
195
+ dtype,
196
+ device=device,
197
+ past_key_values_length=past_key_values_length,
198
+ sliding_window=self.sliding_window,
199
+ )
200
+
201
+ return causal_4d_mask
202
+
203
+ def to_4d(
204
+ self,
205
+ attention_mask_2d: torch.Tensor,
206
+ query_length: int,
207
+ dtype: torch.dtype,
208
+ key_value_length: Optional[int] = None,
209
+ ) -> torch.Tensor:
210
+ """
211
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
212
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
213
+ causal, a causal mask will be added.
214
+ """
215
+ input_shape = (attention_mask_2d.shape[0], query_length)
216
+
217
+ # create causal mask
218
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
219
+ causal_4d_mask = None
220
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
221
+ if key_value_length is None:
222
+ raise ValueError(
223
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
224
+ )
225
+
226
+ past_key_values_length = key_value_length - query_length
227
+ causal_4d_mask = self._make_causal_mask(
228
+ input_shape,
229
+ dtype,
230
+ device=attention_mask_2d.device,
231
+ past_key_values_length=past_key_values_length,
232
+ sliding_window=self.sliding_window,
233
+ )
234
+ elif self.sliding_window is not None:
235
+ raise NotImplementedError(
236
+ "Sliding window is currently only implemented for causal masking"
237
+ )
238
+
239
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
240
+ expanded_attn_mask = self._expand_mask(
241
+ attention_mask_2d, dtype, tgt_len=input_shape[-1]
242
+ ).to(attention_mask_2d.device)
243
+ if causal_4d_mask is not None:
244
+ expanded_attn_mask = causal_4d_mask.masked_fill(
245
+ expanded_attn_mask.bool(), torch.finfo(dtype).min
246
+ )
247
+
248
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
249
+ expanded_4d_mask = expanded_attn_mask
250
+
251
+ return expanded_4d_mask
252
+
253
+ @staticmethod
254
+ def _make_causal_mask(
255
+ input_ids_shape: torch.Size,
256
+ dtype: torch.dtype,
257
+ device: torch.device,
258
+ past_key_values_length: int = 0,
259
+ sliding_window: Optional[int] = None,
260
+ ):
261
+ """
262
+ Make the causal mask used for uni-directional (causal) self-attention.
263
+ """
264
+ bsz, tgt_len = input_ids_shape
265
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
266
+ mask_cond = torch.arange(mask.size(-1), device=device)
267
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
268
+
269
+ mask = mask.to(dtype)
270
+
271
+ if past_key_values_length > 0:
272
+ mask = torch.cat(
273
+ [
274
+ torch.zeros(
275
+ tgt_len, past_key_values_length, dtype=dtype, device=device
276
+ ),
277
+ mask,
278
+ ],
279
+ dim=-1,
280
+ )
281
+
282
+ # add lower triangular sliding window mask if necessary
283
+ if sliding_window is not None:
284
+ diagonal = past_key_values_length - sliding_window + 1
285
+
286
+ context_mask = 1 - torch.triu(
287
+ torch.ones_like(mask, dtype=torch.int), diagonal=diagonal
288
+ )
289
+ mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
290
+
291
+ return mask[None, None, :, :].expand(
292
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
293
+ )
294
+
295
+ @staticmethod
296
+ def _expand_mask(
297
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
298
+ ):
299
+ """
300
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
301
+ """
302
+ bsz, src_len = mask.size()
303
+ tgt_len = tgt_len if tgt_len is not None else src_len
304
+
305
+ expanded_mask = (
306
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
307
+ )
308
+
309
+ inverted_mask = 1.0 - expanded_mask
310
+
311
+ return inverted_mask.masked_fill(
312
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
313
+ )
314
+
315
+ @staticmethod
316
+ def _unmask_unattended(
317
+ expanded_mask: torch.Tensor,
318
+ attention_mask: torch.Tensor,
319
+ unmasked_value: Union[bool, float],
320
+ ):
321
+ # fmt: off
322
+ """
323
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
324
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
325
+ Details: https://github.com/pytorch/pytorch/issues/110213
326
+
327
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
328
+ `attention_mask` is [bsz, src_seq_len].
329
+
330
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
331
+
332
+ For example, if `attention_mask` is
333
+ ```
334
+ [[0, 0, 1],
335
+ [1, 1, 1],
336
+ [0, 1, 1]]
337
+ ```
338
+ and `expanded_mask` is (e.g. here left-padding case)
339
+ ```
340
+ [[[[0, 0, 0],
341
+ [0, 0, 0],
342
+ [0, 0, 1]]],
343
+ [[[1, 0, 0],
344
+ [1, 1, 0],
345
+ [1, 1, 1]]],
346
+ [[[0, 0, 0],
347
+ [0, 1, 0],
348
+ [0, 1, 1]]]]
349
+ ```
350
+ then the modified `expanded_mask` will be
351
+ ```
352
+ [[[[1, 1, 1], <-- modified
353
+ [1, 1, 1], <-- modified
354
+ [0, 0, 1]]],
355
+ [[[1, 0, 0],
356
+ [1, 1, 0],
357
+ [1, 1, 1]]],
358
+ [[[1, 1, 1], <-- modified
359
+ [0, 1, 0],
360
+ [0, 1, 1]]]]
361
+ ```
362
+ """
363
+ # fmt: on
364
+
365
+ # Get the index of the first non-zero value for every sample in the batch.
366
+ # In the above example, indices = [[2], [0], [1]]
367
+ tmp = torch.arange(attention_mask.shape[1], 0, -1)
368
+ indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
369
+
370
+ # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
371
+ # expanded mask will be completely unattended.
372
+ left_masked_rows = torch.where(indices > 0)[0]
373
+
374
+ if left_masked_rows.shape[0] == 0:
375
+ return expanded_mask
376
+ indices = indices[left_masked_rows]
377
+
378
+ max_len = torch.max(indices)
379
+ range_tensor = torch.arange(max_len).unsqueeze(0)
380
+ range_tensor = range_tensor.repeat(indices.size(0), 1)
381
+
382
+ # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
383
+ range_tensor[range_tensor >= indices] = 0
384
+
385
+ # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
386
+ if expanded_mask.dim() == 4:
387
+ num_masks = expanded_mask.shape[1]
388
+ if num_masks == 1:
389
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
390
+ mask_slice = (left_masked_rows[:, None], 0, range_tensor)
391
+ else:
392
+ # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
393
+ mask_slice = (
394
+ left_masked_rows[:, None, None],
395
+ torch.arange(num_masks)[None, :, None],
396
+ range_tensor[:, None, :],
397
+ )
398
+ else:
399
+ # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
400
+ mask_slice = (left_masked_rows[:, None], range_tensor)
401
+
402
+ expanded_mask[mask_slice] = unmasked_value
403
+
404
+ return expanded_mask
405
+
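The index trick at the start of `_unmask_unattended` (multiply the mask by a descending range, then take the argmax) can be checked on the docstring example. The sketch below is plain Python with lists instead of tensors; `first_attended_index` is a hypothetical helper name, and it relies on the argmax returning the first maximal position.

```python
def first_attended_index(mask_row):
    """Index of the first non-zero entry of a 0/1 mask row.

    Weights n, n-1, ..., 1 make the earliest attended position the largest
    product, so the argmax of the products is that position.
    """
    n = len(mask_row)
    scores = [m * w for m, w in zip(mask_row, range(n, 0, -1))]
    return scores.index(max(scores))  # first position of the maximum


attention_mask = [[0, 0, 1],
                  [1, 1, 1],
                  [0, 1, 1]]
indices = [first_attended_index(row) for row in attention_mask]
# Samples whose first attended token is not at position 0 are left-padded.
left_masked_rows = [b for b, i in enumerate(indices) if i > 0]
```

Only the samples listed in `left_masked_rows` have fully masked leading rows in the expanded mask, so only those rows need to be overwritten with `unmasked_value`.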
406
+
407
+ def _prepare_4d_causal_attention_mask(
408
+ attention_mask: Optional[torch.Tensor],
409
+ input_shape: Union[torch.Size, Tuple, List],
410
+ inputs_embeds: torch.Tensor,
411
+ past_key_values_length: int,
412
+ sliding_window: Optional[int] = None,
413
+ ):
414
+ """
415
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
416
+ `(batch_size, key_value_length)`
417
+
418
+ Args:
419
+ attention_mask (`torch.Tensor` or `None`):
420
+ A 2D attention mask of shape `(batch_size, key_value_length)`
421
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
422
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
423
+ inputs_embeds (`torch.Tensor`):
424
+ The embedded inputs as a torch Tensor.
425
+ past_key_values_length (`int`):
426
+ The length of the key value cache.
427
+ sliding_window (`int`, *optional*):
428
+ If the model uses windowed attention, a sliding window should be passed.
429
+ """
430
+ attn_mask_converter = AttentionMaskConverter(
431
+ is_causal=True, sliding_window=sliding_window
432
+ )
433
+
434
+ key_value_length = input_shape[-1] + past_key_values_length
435
+
436
+ # 4d mask is passed through the layers
437
+ if attention_mask is not None:
438
+ attention_mask = attn_mask_converter.to_4d(
439
+ attention_mask,
440
+ input_shape[-1],
441
+ key_value_length=key_value_length,
442
+ dtype=inputs_embeds.dtype,
443
+ )
444
+ else:
445
+ attention_mask = attn_mask_converter.to_causal_4d(
446
+ input_shape[0],
447
+ input_shape[-1],
448
+ key_value_length,
449
+ dtype=inputs_embeds.dtype,
450
+ device=inputs_embeds.device,
451
+ )
452
+
453
+ return attention_mask
454
+
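As a rough illustration of the mask that `_prepare_4d_causal_attention_mask` is documented to return, the sketch below combines a causal constraint with a 2D padding mask in plain Python. It assumes the converter masks every key that is either in the future or padded, and it ignores `sliding_window`. `causal_4d_mask` and `MIN_VALUE` are hypothetical names, not part of this commit.

```python
# Assumed stand-in for torch.finfo(dtype).min.
MIN_VALUE = -3.4e38


def causal_4d_mask(padding_mask, query_length, past_length=0):
    """Build a [bsz][1][query_length][kv_len] additive mask.

    padding_mask: [bsz][kv_len] of 0/1 with kv_len = query_length + past_length.
    A key is visible when it is not padded and not after the query position.
    """
    out = []
    for row in padding_mask:
        kv_len = len(row)
        rows = []
        for q in range(query_length):
            q_abs = q + past_length  # absolute position of this query token
            rows.append([
                0.0 if (k <= q_abs and row[k]) else MIN_VALUE
                for k in range(kv_len)
            ])
        out.append([rows])
    return out


prefill = causal_4d_mask([[0, 1, 1]], query_length=3)
decode = causal_4d_mask([[1, 1, 1]], query_length=1, past_length=2)
```

During prefill with left padding, the first query row ends up fully masked, which is the situation `_unmask_unattended` exists to repair. During decoding with two cached tokens, the single query row can see all three keys.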
455
+
456
  @dataclass
457
  class MoeCausalLMOutputWithPast(ModelOutput):
458
  """
 
600
  Returns:
601
  The auxiliary loss.
602
  """
603
+ if gate_logits is None or (isinstance(gate_logits, Iterable) and len(gate_logits) == 0):
604
  return 0
605
 
606
  # ✨ Here is the fix for balance loss in Mixtral.
 
1142
  )
1143
 
1144
  # 🔍
 
1145
  self.softmax = nn.Softmax(dim=-1)
1146
  self.top_k_attn = config.top_k_attn
1147
+ self.attn_experts = config.attn_experts
1148
  self.scale_factor_attn = config.scale_factor_attn
1149
 
1150
+ self.split_ratio = self.attn_experts // self.num_key_value_heads
1151
+
1152
+ self.gate = nn.Linear(self.hidden_size, self.attn_experts, bias=False)
1153
+
1154
  # 🔍
1155
+ self.q_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_key_value_groups * self.head_dim // self.split_ratio, bias=False) for _ in range(self.attn_experts)])
1156
+ self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.attn_experts)])
1157
+ self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.attn_experts)])
1158
+ self.o_proj = nn.ModuleList([nn.Linear(self.num_key_value_groups * self.head_dim // self.split_ratio, self.hidden_size, bias=config.add_rescale_bias) for _ in range(self.attn_experts)]) # 🔍 (may add bias for rescaling)
1159
 
1160
  self.rotary_emb = MixtralRotaryEmbedding(
1161
  self.head_dim,
 
1181
  raise TypeError(
1182
  "`past_key_value` must be a `MoECache` instance for attention MoE!"
1183
  )
1184
+ # print("attention_mask", attention_mask, attention_mask.shape)
1185
  device = hidden_states.device
1186
  dtype = hidden_states.dtype
1187
  bsz, q_len, hidden_dim = hidden_states.size()
 
1200
 
1201
  # One hot encode the selected experts to create an expert mask
1202
  # this will be used to easily index which expert is going to be solicited
1203
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.attn_experts) # (bsz * q_len, top_k_attn, attn_experts)
1204
  expert_mask = expert_mask.permute(2, 1, 0) # (attn_experts, top_k_attn, bsz * q_len)
1205
 
1206
  # Loop over all available experts in the model and perform the computation on each expert
1207
  all_attn_weights = [] if output_attentions else None
1208
+ for expert_idx in range(self.attn_experts):
1209
  # expert_mask[expert_idx]: (top_k_attn, bsz * q_len)
1210
  # idx: the topk position. (selected_num)
1211
  # top_x: token index. (selected_num)
 
1246
  key_states = self.k_proj[expert_idx](current_state) # 🔍 specify expert
1247
  value_states = self.v_proj[expert_idx](current_state) # 🔍 specify expert
1248
 
1249
+ query_states = query_states.view(bsz, this_q_len, self.num_key_value_groups // self.split_ratio, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups
1250
  key_states = key_states.view(bsz, this_q_len, 1, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_key_value_heads -> 1
1251
  value_states = value_states.view(bsz, this_q_len, 1, self.head_dim).transpose(1, 2) # 🔍 q_len -> this_q_len, num_key_value_heads -> 1
1252
 
 
1281
 
1282
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) # softmax temperature
1283
 
1284
+ if attn_weights.size() != (bsz, self.num_key_value_groups // self.split_ratio, this_q_len, kv_seq_len): # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups
1285
+ raise ValueError(f"Attention weights should be of size {(bsz, self.num_key_value_groups // self.split_ratio, this_q_len, kv_seq_len)}, but is {attn_weights.size()}")
1286
 
1287
  # 🔍 create `current_attention_mask` with reduced `seq_len`
1288
  # Notice that the `attention_mask` is passed intact during both training & generation, so we need to adjust the `top_x` by `past_key_values_length`.
 
1296
  temp_attention_mask = attention_mask[:, previous_seen_tokens_total:].flatten() # select along dimension 1 so that we get tokens in this iteration
1297
  else:
1298
  temp_attention_mask = attention_mask.flatten() # flatten the dim
1299
+ current_attention_mask[current_batch_ids, current_seq_ids] = temp_attention_mask[top_x].bool() # assign masks sparsely
1300
 
1301
  else:
1302
  current_attention_mask[current_batch_ids, current_seq_ids] = True # assign masks sparsely
1303
 
1304
+ # print("current_attention_mask", current_attention_mask, current_attention_mask.shape)
1305
  if past_key_value is not None: # 🔍 we need to update with cached attention mask
1306
  current_attention_mask = past_key_value.update_attention_mask(current_attention_mask, self.layer_idx, expert_idx)
1307
 
 
1319
  raise ValueError(f"Attention mask should be of size {(bsz, 1, this_q_len, kv_seq_len)}, but is {current_attention_mask.size()}")
1320
 
1321
  attn_weights = attn_weights + current_attention_mask # 🔍
1322
+ # print("current_attention_mask", current_attention_mask.shape, current_attention_mask[0])
1323
  # upcast attention to fp32
1324
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
1325
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
1326
  attn_output = torch.matmul(attn_weights, value_states)
1327
 
1328
+ # if attn_output.size() != (bsz, self.num_key_value_groups // self.split_ratio, this_q_len, self.head_dim): # 🔍 q_len -> this_q_len, num_heads -> num_key_value_groups
1329
+ # raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}")
1330
 
1331
  attn_output = attn_output.transpose(1, 2).contiguous()
1332
+ attn_output = attn_output.reshape(bsz, this_q_len, self.num_key_value_groups * self.head_dim // self.split_ratio) # 🔍 q_len -> this_q_len, hidden_size -> num_key_value_groups * head_dim
1333
  attn_output = self.o_proj[expert_idx](attn_output)
1334
  # ---------------------------------------------- #
1335
 
 
1362
  # init
1363
  attention_moe = MixtralAttentionMoE(config, layer_idx)
1364
 
1365
+ split = 1 # split the hidden_size, support split=1 --> 8/2, split=2 --> 16/4, split=4 --> 32/8
1366
  # copy weights
1367
+ num_key_value_groups = attention_moe.num_key_value_groups // split
1368
  head_dim = attention_moe.head_dim
1369
 
1370
+ for i in range(config.num_key_value_heads * split):
1371
  indices_q_o = [j for j in range(head_dim * num_key_value_groups * i, head_dim * num_key_value_groups * (i + 1))]
1372
+ indices_k_v = [j for j in range(head_dim * (i // split), head_dim * ((i // split) + 1))]
1373
 
1374
+ # print(i, "indices_q_o", indices_q_o)
1375
  # print(i, "indices_k_v", indices_k_v)
1376
 
1377
  attention_moe.q_proj[i].weight.data = attention.q_proj.weight.data[indices_q_o].clone()
 
1529
  key_states = key_states.transpose(1, 2)
1530
  value_states = value_states.transpose(1, 2)
1531
 
1532
+ # print("attention_mask", attention_mask, attention_mask.shape)
1533
  attn_output = self._flash_attention_forward(
1534
  query_states,
1535
  key_states,
 
1667
  self, query_layer, key_layer, value_layer, attention_mask, query_length
1668
  ):
1669
  batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
1670
  # On the first iteration we need to properly re-create the padding mask
1671
  # by slicing it on the proper place
1672
  if kv_seq_len != attention_mask.shape[-1]:
 
1714
  )
1715
 
1716
 
1717
+
1718
+ class MixtralFlashAttention2MoE(MixtralFlashAttention2):
1719
+ def __init__(self, *args, **kwargs):
1720
+ super().__init__(*args, **kwargs)
1721
+
1722
+ self.top_k_attn = self.config.top_k_attn
1723
+ self.attn_experts = self.config.attn_experts
1724
+ self.scale_factor_attn = self.config.scale_factor_attn
1725
+ self.split_ratio = self.attn_experts // self.num_key_value_heads
1726
+
1727
+ self.gate = nn.Linear(self.hidden_size, self.attn_experts, bias=False)
1728
+
1729
+ self.q_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_key_value_groups * self.head_dim // self.split_ratio, bias=False) for _ in range(self.attn_experts)])
1730
+ self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.attn_experts)])
1731
+ self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.head_dim, bias=False) for _ in range(self.attn_experts)])
1732
+ self.o_proj = nn.ModuleList([nn.Linear(self.num_key_value_groups * self.head_dim // self.split_ratio, self.hidden_size, bias=self.config.add_rescale_bias) for _ in range(self.attn_experts)])
1733
+
1734
+ def forward(
1735
+ self,
1736
+ hidden_states: torch.Tensor,
1737
+ attention_mask: Optional[torch.Tensor] = None,
1738
+ position_ids: Optional[torch.LongTensor] = None,
1739
+ past_key_value: Optional[Cache] = None,
1740
+ output_attentions: bool = False,
1741
+ use_cache: bool = False,
1742
+ **kwargs,
1743
+ ):
1744
+
1745
+ if "padding_mask" in kwargs:
1746
+ warnings.warn(
1747
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
1748
+ )
1749
+
1750
+ # overwrite attention_mask with padding_mask
1751
+ # attention_mask = kwargs.pop("padding_mask")
1752
+
1753
+ if past_key_value is not None and not isinstance(past_key_value, MoECache): # 🔍 type check
1754
+ raise TypeError(
1755
+ "`past_key_value` must be a `MoECache` instance for attention MoE!"
1756
+ )
1757
+
1758
+ bsz, q_len, hidden_dim = hidden_states.size()
1759
+ device = hidden_states.device
1760
+ dtype = hidden_states.dtype
1761
+
1762
+ hidden_states = hidden_states.reshape(-1, hidden_dim)
1763
+ # gate compute
1764
+ router_logits = self.gate(hidden_states)
1765
+ router_scores = F.softmax(router_logits, dim=1, dtype=torch.float)
1766
+ routing_weights, selected_experts = torch.topk(router_scores, self.top_k_attn, dim=-1)
1767
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
1768
+ routing_weights = routing_weights.to(dtype)
1769
+
1770
+ final_attn_output = torch.zeros_like(hidden_states).reshape(-1, hidden_dim)
1771
+
1772
+ expert_mask = F.one_hot(selected_experts, num_classes=self.attn_experts).permute(2, 1, 0) # (attn_experts, top_k_attn, bsz * q_len)
1773
+
1774
+ all_attn_weights = [] if output_attentions else None
1775
+
1776
+ for expert_idx in range(self.attn_experts):
1777
+ idx, top_x = torch.nonzero(expert_mask[expert_idx], as_tuple=True)
1778
+ # top_x_list = top_x.tolist()
1779
+ # idx_list = idx.tolist()
1780
+
1781
+ if top_x.shape[0] == 0 and not self.training: # only skip empty experts at inference: skipping during training would desynchronize the GPUs and block training
1782
+ if output_attentions:
1783
+ all_attn_weights.append(None)
1784
+ continue
1785
+
1786
+ # create position_ids for selected tokens
1787
+ current_batch_ids = (top_x // q_len)
1788
+ each_batch_selected_token_num = torch.bincount(current_batch_ids, minlength=bsz) # (bsz)
1789
+ this_q_len = each_batch_selected_token_num.max().item()
1790
+
1791
+ selection_mask = torch.zeros((bsz * q_len,), device=device, dtype=torch.bool)
1792
+ selection_mask[top_x] = True
1793
+ selection_mask = selection_mask.reshape(bsz, q_len)
1794
+ token_position_indices = torch.cumsum(selection_mask, dim=1) - 1
1795
+ token_position_indices = token_position_indices.flatten()
1796
+ current_seq_ids = token_position_indices[top_x]
1797
+
1798
+
1799
+ # 🔍 initialize hidden_states for this expert
1800
+ current_state = torch.zeros((bsz, this_q_len, hidden_dim), dtype=dtype, device=device)
1801
+ current_state[current_batch_ids, current_seq_ids] = hidden_states[top_x] # assign tokens sparsely
1802
+
1803
+ # for attention forward
1804
+ # expert_inputs = viewed_hidden_states[None, top_x_list].reshape(-1, self.hidden_size)
1805
+
1806
+ query_states = self.q_proj[expert_idx](current_state)
1807
+ key_states = self.k_proj[expert_idx](current_state)
1808
+ value_states = self.v_proj[expert_idx](current_state)
1809
+
1810
+ # seq_len = query_states.numel() // (bsz * self.num_key_value_groups * self.head_dim)
1811
+ query_states = query_states.view(bsz, -1, self.num_key_value_groups // self.split_ratio, self.head_dim).transpose(1, 2)
1812
+ key_states = key_states.view(bsz, -1, 1, self.head_dim).transpose(1, 2)
1813
+ value_states = value_states.view(bsz, -1, 1, self.head_dim).transpose(1, 2)
1814
+
1815
+ # for moe kv cache
1816
+ past_key_values_length = 0
1817
+ kv_seq_len = key_states.shape[-2]
1818
+ if past_key_value is not None:
1819
+ if self.layer_idx is None:
1820
+ raise ValueError(
1821
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
1822
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
1823
+ "with a layer index."
1824
+ )
1825
+ past_key_values_length = past_key_value.get_usable_length(kv_seq_len, self.layer_idx, expert_idx) # 🔍 specify expert index
1826
+ kv_seq_len += past_key_values_length
1827
+
1828
+ current_position_ids = torch.zeros((bsz, this_q_len), device=hidden_states.device, dtype=torch.long)
1829
+ current_position_ids[current_batch_ids, current_seq_ids] = position_ids.expand(bsz, q_len).flatten()[top_x]
1830
+
1831
+ if top_x.shape[0] > 0: # apply only when there are tokens
1832
+ cos, sin = self.rotary_emb(value_states, seq_len=current_position_ids.max().item() + 1) # 🔍 adjust the seq_len to the maximum possible value
1833
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, current_position_ids)
1834
+
1835
+ if past_key_value is not None:
1836
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1837
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, expert_idx, cache_kwargs) # 🔍 specify expert index
1838
+
1839
+ # print("attention_mask", attention_mask.shape, attention_mask)
1840
+ # for current attention mask
1841
+
1842
+ '''
1843
+ current_attention_mask = torch.zeros((bsz, this_q_len), dtype=torch.bool, device=device)
1844
+
1845
+ if attention_mask is not None:
1846
+ if past_key_values_length > 0: # 🔍 we need to exclude previous tokens
1847
+ previous_seen_tokens_total = past_key_value._seen_tokens_total - q_len
1848
+ temp_attention_mask = attention_mask[:, previous_seen_tokens_total:].flatten() # select along dimension 1 so that we get tokens in this iteration
1849
+ else:
1850
+ temp_attention_mask = attention_mask.flatten() # flatten the dim
1851
+ current_attention_mask[current_batch_ids, current_seq_ids] = temp_attention_mask[top_x] # bug here !!!
1852
+
1853
+ else:
1854
+ current_attention_mask[current_batch_ids, current_seq_ids] = True # assign masks sparsely
1855
+
1856
+ if past_key_value is not None: # 🔍 we need to update with cached attention mask
1857
+ current_attention_mask = past_key_value.update_attention_mask(current_attention_mask, self.layer_idx, expert_idx)
1858
+
1859
+
1860
+ current_attention_mask = _prepare_4d_causal_attention_mask(
1861
+ current_attention_mask,
1862
+ (bsz, this_q_len),
1863
+ current_state,
1864
+ past_key_values_length,
1865
+ sliding_window=self.config.sliding_window,
1866
+ )
1867
+
1868
+ if current_attention_mask.size() != (bsz, 1, this_q_len, kv_seq_len): # 🔍 q_len -> this_q_len
1869
+ raise ValueError(f"Attention mask should be of size {(bsz, 1, this_q_len, kv_seq_len)}, but is {current_attention_mask.size()}")
1870
+
1871
+ '''
1872
+
1873
+ # for sliding window
1874
+ use_sliding_windows = (
1875
+ _flash_supports_window_size
1876
+ and getattr(self.config, "sliding_window", None) is not None
1877
+ and kv_seq_len > self.config.sliding_window
1878
+ )
1879
+
1880
+ if not _flash_supports_window_size:
1881
+ logger.warning_once(
1882
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
1883
+ " make sure to upgrade flash-attn library."
1884
+ )
1885
+
1886
+ # wait for change! sliding_window=4096
1887
+ if past_key_value is not None:
1888
+ # Activate cache slicing only if the config has a `sliding_window` attribute
1889
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
1890
+ if (
1891
+ getattr(self.config, "sliding_window", None) is not None
1892
+ and kv_seq_len > self.config.sliding_window
1893
+ and cache_has_contents
1894
+ ):
1895
+ slicing_tokens = 1 - self.config.sliding_window
1896
+
1897
+ past_key = past_key_value[self.layer_idx][0]
1898
+ past_value = past_key_value[self.layer_idx][1]
1899
+
1900
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
1901
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
1902
+
1903
+ if past_key.shape[-2] != self.config.sliding_window - 1:
1904
+ raise ValueError(
1905
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
1906
+ f" {past_key.shape}"
1907
+ )
1908
+
1909
+ if attention_mask is not None:
1910
+ attention_mask = attention_mask[:, slicing_tokens:]
1911
+ attention_mask = torch.cat(
1912
+ [attention_mask, torch.ones_like(attention_mask[:, -1:])],
1913
+ dim=-1,
1914
+ )
1915
+
1916
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1917
+ key_states, value_states = past_key_value.update(
1918
+ key_states, value_states, self.layer_idx, cache_kwargs
1919
+ )
1920
+
1921
+ # for input dtype
1922
+ input_dtype = query_states.dtype
1923
+ if input_dtype == torch.float32:
1924
+ # Handle the case where the model is quantized
1925
+ if hasattr(self.config, "_pre_quantization_dtype"):
1926
+ target_dtype = self.config._pre_quantization_dtype
1927
+ else:
1928
+ target_dtype = self.q_proj[0].weight.dtype
1929
+
1930
+ logger.warning_once(
1931
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
1932
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1933
+ f" {target_dtype}."
1934
+ )
1935
+
1936
+ query_states = query_states.to(target_dtype)
1937
+ key_states = key_states.to(target_dtype)
1938
+ value_states = value_states.to(target_dtype)
1939
+
1940
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
1941
+
1942
+ repeat_num = query_states.shape[1]
1943
+ key_states = repeat_kv(key_states, repeat_num)
1944
+ value_states = repeat_kv(value_states, repeat_num)
1945
+
1946
+ # print("repeat_num", repeat_num)
1947
+ # print("query_states shape", query_states.shape, key_states.shape, value_states.shape)
1948
+
1949
+ # Reshape to the expected shape for Flash Attention
1950
+ query_states = query_states.transpose(1, 2)
1951
+ key_states = key_states.transpose(1, 2)
1952
+ value_states = value_states.transpose(1, 2)
1953
+
1954
+ attn_output = self._flash_attention_forward(
1955
+ query_states,
1956
+ key_states,
1957
+ value_states,
1958
+ attention_mask,
1959
+ this_q_len,
1960
+ dropout=dropout_rate,
1961
+ use_sliding_windows=use_sliding_windows,
1962
+ )
1963
+
1964
+ attn_output = attn_output.reshape(bsz, this_q_len, self.num_key_value_groups * self.head_dim // self.split_ratio).contiguous()
1965
+ attn_output = self.o_proj[expert_idx](attn_output)
1966
+ attn_output = attn_output[current_batch_ids, current_seq_ids] * (routing_weights[top_x, idx, None] * self.scale_factor_attn)
1967
+
1968
+ final_attn_output.index_add_(0, top_x, attn_output)
1969
+
1970
+ final_attn_output = final_attn_output.reshape(bsz, q_len, hidden_dim)
1971
+
1972
+ if not output_attentions:
1973
+ attn_weights = None
1974
+
1975
+ return final_attn_output, attn_weights, past_key_value, router_logits # 🔍 return an extra `router_logits`
1976
+
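The per-expert packing used in the `forward` loop above (`current_batch_ids`, `current_seq_ids`, `this_q_len`) can be sketched in plain Python. This is an illustration under simplifying assumptions: lists replace tensors, and `pack_tokens_for_expert` is a hypothetical helper name. `top_x` holds flat token indices into the `(bsz * q_len)` axis, as in the code above.

```python
def pack_tokens_for_expert(top_x, bsz, q_len):
    """Map flat token indices routed to one expert onto a dense
    (bsz, this_q_len) grid, keeping the tokens of each sample in order."""
    # Boolean grid marking which (batch, position) pairs this expert received.
    selected = [[False] * q_len for _ in range(bsz)]
    for t in top_x:
        selected[t // q_len][t % q_len] = True

    batch_ids, seq_ids = [], []
    for t in top_x:
        b, pos = divmod(t, q_len)
        batch_ids.append(b)
        # Equivalent of cumsum(selection_mask, dim=1) - 1 at this position:
        # how many selected tokens precede this one in the same sample.
        seq_ids.append(sum(selected[b][: pos + 1]) - 1)

    # The densest sample decides the packed sequence length.
    this_q_len = max(sum(row) for row in selected)
    return batch_ids, seq_ids, this_q_len


batch_ids, seq_ids, this_q_len = pack_tokens_for_expert([1, 3, 4], bsz=2, q_len=3)
```

With `bsz=2` and `q_len=3`, flat indices 1, 3 and 4 correspond to one token in the first sample and two in the second, so the expert sees a `(2, 2)` grid in which the unused slot of the first sample stays zero-filled.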
1977
+
1978
+
1979
+ class MixtralFlashAttention2MoE_zt(MixtralFlashAttention2):
1980
+ def __init__(self, *args, **kwargs):
1981
+ super().__init__(*args, **kwargs)
1982
+
1983
+ self.top_k_attn = self.config.top_k_attn
1984
+ self.scale_factor_attn = self.config.scale_factor_attn
1985
+ # self.num_heads
1986
+ # self.head_dim
1987
+ # self.num_key_value_heads
1988
+ # self.num_key_value_groups # total number of experts
1989
+ assert self.top_k_attn <= self.num_key_value_groups
1990
+ # assert self.top_k_attn % self.num_key_value_heads == 0
1991
+ self.attn_hsz = self.hidden_size // self.num_key_value_groups * self.top_k_attn
1992
+ self.kv_repeat_num = self.attn_hsz // (self.num_key_value_heads * self.head_dim)
1993
+ self.simulated_attn_head_num = self.attn_hsz // self.head_dim
1994
+ assert self.attn_hsz % (self.num_key_value_heads * self.head_dim) == 0
1995
+ assert self.simulated_attn_head_num == self.num_heads * (self.top_k_attn / self.num_key_value_groups)
1996
+ assert self.kv_repeat_num * self.num_key_value_heads == self.simulated_attn_head_num
1997
+
1998
+ self.gate = nn.Linear(self.hidden_size, self.num_key_value_groups, bias=False)
1999
+ # tzhu: there are self.num_key_value_groups experts
2000
+ # each expert has a size of self.attn_hsz
2001
+ self.q_proj = nn.ModuleList(
2002
+ [nn.Linear(self.hidden_size, self.attn_hsz) for _ in range(self.num_key_value_groups)]
2003
+ )
2004
+ self.o_proj = nn.ModuleList(
2005
+ [nn.Linear(self.attn_hsz, self.hidden_size) for _ in range(self.num_key_value_groups)]
2006
+ )
2007
+
2008
+ def forward(
2009
+ self,
2010
+ hidden_states: torch.Tensor,
2011
+ attention_mask: Optional[torch.Tensor] = None,
2012
+ position_ids: Optional[torch.LongTensor] = None,
2013
+ past_key_value: Optional[Cache] = None,
2014
+ output_attentions: bool = False,
2015
+ use_cache: bool = False,
2016
+ **kwargs,
2017
+ ):
2018
+ if "padding_mask" in kwargs:
2019
+ warnings.warn(
2020
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
2021
+ )
2022
+
2023
+ # overwrite attention_mask with padding_mask
2024
+ attention_mask = kwargs.pop("padding_mask")
2025
+ bsz, q_len, _ = hidden_states.size()
2026
+
2027
+ key_states = self.k_proj(hidden_states)
2028
+ value_states = self.v_proj(hidden_states)
2029
+
2030
+ # tzhu: attn-moe on q_proj
2031
+ viewed_hidden_states = hidden_states.view(bsz * q_len, self.hidden_size)
2032
+ # router
2033
+ router_logits = self.gate(viewed_hidden_states)
2034
+ router_scores = F.softmax(router_logits, dim=-1, dtype=torch.float)
2035
+ routing_weights, selected_experts = torch.topk(router_scores, self.top_k_attn, dim=-1)
2036
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
2037
+ routing_weights = routing_weights.to(hidden_states.dtype)
2038
+ query_states = torch.zeros(
2039
+ (bsz * q_len, self.attn_hsz),
2040
+ dtype=hidden_states.dtype,
2041
+ device=hidden_states.device,
2042
+ )
2043
+ # expert_mask: (num_experts, top_k_attn, bsz * q_len)
2044
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_key_value_groups).permute(2, 1, 0)
2045
+ for expert_idx in range(self.num_key_value_groups):
2046
+ expert_layer = self.q_proj[expert_idx]
2047
+ idx, top_x = torch.where(expert_mask[expert_idx])
2048
+ top_x_list = top_x.tolist()
2049
+ idx_list = idx.tolist()
2050
+ expert_inputs = viewed_hidden_states[None, top_x_list].reshape(-1, self.hidden_size)
2051
+ # inputs (-1, hidden_size) -> outputs (-1, attn_hsz)
2052
+ expert_outs = expert_layer(expert_inputs) * routing_weights[top_x_list, idx_list, None] * self.scale_factor_attn
2053
+ query_states.index_add_(0, top_x, expert_outs.to(query_states.dtype))
2054
+ query_states = query_states.view(bsz, q_len, self.attn_hsz)
2055
+ # query_states = query_states.view(
2056
+ # bsz, q_len, self.num_heads, self.simulated_attn_head_num
2057
+ # ).transpose(1, 2)
2058
+ query_states = query_states.view(
2059
+ bsz, q_len, self.simulated_attn_head_num, self.head_dim
2060
+ ).transpose(1, 2)
2061
+ key_states = key_states.view(
2062
+ bsz, q_len, self.num_key_value_heads, self.head_dim
2063
+ ).transpose(1, 2)
2064
+ value_states = value_states.view(
2065
+ bsz, q_len, self.num_key_value_heads, self.head_dim
2066
+ ).transpose(1, 2)
2067
+
2068
+ kv_seq_len = key_states.shape[-2]
2069
+ if past_key_value is not None:
2070
+ if self.layer_idx is None:
2071
+ raise ValueError(
2072
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
2073
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
2074
+ "with a layer index."
2075
+ )
2076
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
2077
+
2078
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
2079
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
2080
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
2081
+
2082
+ query_states, key_states = apply_rotary_pos_emb(
2083
+ query_states, key_states, cos, sin, position_ids
2084
+ )
2085
+
2086
+ use_sliding_windows = (
2087
+ _flash_supports_window_size
2088
+ and getattr(self.config, "sliding_window", None) is not None
2089
+ and kv_seq_len > self.config.sliding_window
2090
+ )
2091
+
2092
+ if not _flash_supports_window_size:
2093
+ logger.warning_once(
2094
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
2095
+ " make sure to upgrade flash-attn library."
2096
+ )
2097
+
2098
+ if past_key_value is not None:
2099
+ # Activate cache slicing only if the config has a `sliding_window` attribute
2100
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
2101
+ if (
2102
+ getattr(self.config, "sliding_window", None) is not None
2103
+ and kv_seq_len > self.config.sliding_window
2104
+ and cache_has_contents
2105
+ ):
2106
+ slicing_tokens = 1 - self.config.sliding_window
2107
+
2108
+ past_key = past_key_value[self.layer_idx][0]
2109
+ past_value = past_key_value[self.layer_idx][1]
2110
+
2111
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
2112
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
2113
+
2114
+ if past_key.shape[-2] != self.config.sliding_window - 1:
2115
+ raise ValueError(
2116
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
2117
+ f" {past_key.shape}"
2118
+ )
2119
+
2120
+ if attention_mask is not None:
2121
+ attention_mask = attention_mask[:, slicing_tokens:]
2122
+ attention_mask = torch.cat(
2123
+ [attention_mask, torch.ones_like(attention_mask[:, -1:])],
2124
+ dim=-1,
2125
+ )
2126
+
2127
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
2128
+ key_states, value_states = past_key_value.update(
2129
+ key_states, value_states, self.layer_idx, cache_kwargs
2130
+ )
2131
+
2132
+ # repeat k/v heads if n_kv_heads < n_heads
2133
+ key_states = repeat_kv(key_states, self.kv_repeat_num)
2134
+ value_states = repeat_kv(value_states, self.kv_repeat_num)
2135
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
2136
+
2137
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
2138
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
2139
+ # cast them back in float16 just to be sure everything works as expected.
2140
+ input_dtype = query_states.dtype
2141
+ if input_dtype == torch.float32:
2142
+ # Handle the case where the model is quantized
2143
+ if hasattr(self.config, "_pre_quantization_dtype"):
2144
+ target_dtype = self.config._pre_quantization_dtype
2145
+ else:
2146
+ target_dtype = self.q_proj[0].weight.dtype
2147
+
2148
+ logger.warning_once(
2149
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
2150
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
2151
+ f" {target_dtype}."
2152
+ )
2153
+
2154
+ query_states = query_states.to(target_dtype)
2155
+ key_states = key_states.to(target_dtype)
2156
+ value_states = value_states.to(target_dtype)
2157
+
2158
+ # Reshape to the expected shape for Flash Attention
2159
+ query_states = query_states.transpose(1, 2)
2160
+ key_states = key_states.transpose(1, 2)
2161
+ value_states = value_states.transpose(1, 2)
2162
+
2163
+ attn_output = self._flash_attention_forward(
2164
+ query_states,
2165
+ key_states,
2166
+ value_states,
2167
+ attention_mask,
2168
+ q_len,
2169
+ dropout=dropout_rate,
2170
+ use_sliding_windows=use_sliding_windows,
2171
+ )
2172
+
2173
+ attn_output = attn_output.reshape(bsz * q_len, self.attn_hsz).contiguous()
2174
+ final_attn_output = torch.zeros(
2175
+ (bsz * q_len, self.hidden_size),
2176
+ dtype=hidden_states.dtype,
2177
+ device=hidden_states.device,
2178
+ )
2179
+ for expert_idx in range(self.num_key_value_groups):
2180
+ expert_layer = self.o_proj[expert_idx]
2181
+ idx, top_x = torch.where(expert_mask[expert_idx])
2182
+ top_x_list = top_x.tolist()
2183
+ idx_list = idx.tolist()
2184
+ expert_inputs = attn_output[None, top_x_list].reshape(-1, self.attn_hsz)
2185
+ expert_outs = expert_layer(expert_inputs) * routing_weights[top_x_list, idx_list, None] * self.scale_factor_attn
2186
+ final_attn_output.index_add_(0, top_x, expert_outs.to(final_attn_output.dtype))
2187
+ final_attn_output = final_attn_output.view(bsz, q_len, self.hidden_size)
2188
+
2189
+ if not output_attentions:
2190
+ attn_weights = None
2191
+
2192
+ return final_attn_output, attn_weights, past_key_value, router_logits
2193
+
2194
+
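The expert loop at the end of the forward pass above (find the tokens routed to each expert, run that expert's output projection, scale by the routing weight, and accumulate) can be sketched in plain Python. This is a minimal illustration, not the code from this commit: `moe_combine`, its list-based inputs, and the toy experts are hypothetical stand-ins for the tensors, `torch.where`, and `index_add_` used in the diff.

```python
def moe_combine(token_inputs, selected_experts, routing_weights, experts, scale_factor=1.0):
    """Accumulate weighted expert outputs per token.

    token_inputs:     list of per-token vectors (lists of floats)
    selected_experts: selected_experts[t] lists the expert ids chosen for token t
    routing_weights:  routing_weights[t][slot] is the weight of that choice
    experts:          list of callables mapping a vector to a vector of equal length
    """
    num_tokens = len(token_inputs)
    dim = len(token_inputs[0])
    final = [[0.0] * dim for _ in range(num_tokens)]

    for expert_idx, expert in enumerate(experts):
        # Analogue of torch.where(expert_mask[expert_idx]): collect every
        # (top-k slot, token) pair that was routed to this expert.
        hits = [
            (slot, tok)
            for tok in range(num_tokens)
            for slot, chosen in enumerate(selected_experts[tok])
            if chosen == expert_idx
        ]
        for slot, tok in hits:
            out = expert(token_inputs[tok])
            weight = routing_weights[tok][slot] * scale_factor
            # Analogue of index_add_: add into the row of the owning token.
            for d in range(dim):
                final[tok][d] += out[d] * weight
    return final


# Toy experts standing in for the per-expert output projections.
experts = [
    lambda x: [2.0 * v for v in x],
    lambda x: [v + 1.0 for v in x],
    lambda x: [-v for v in x],
]
tokens = [[1.0, 2.0], [3.0, 4.0]]
selected = [[0, 2], [1, 0]]          # top-2 experts per token
weights = [[0.75, 0.25], [0.5, 0.5]]  # matching routing weights
result = moe_combine(tokens, selected, weights, experts)
print(result)
```

Iterating over experts rather than tokens mirrors the diff: each expert processes all of its tokens in one batch, and tokens that picked several experts receive the sum of their weighted outputs.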
2195
+ @torch.no_grad()
2196
+ def from_vanilla_attention(attention: MixtralAttention, top_k_attn, scale_factor_attn):
2197
+ # config
2198
+ layer_idx = attention.layer_idx
2199
+ config = attention.config
2200
+ config.top_k_attn = top_k_attn
2201
+ config.scale_factor_attn = scale_factor_attn
2202
+
2203
+ # init
2204
+ attention_moe = MixtralFlashAttention2MoE(config, layer_idx)
2205
+
2206
+ # copy weights
2207
+ num_key_value_groups = attention_moe.num_key_value_groups
2208
+ head_dim = attention_moe.head_dim
2209
+
2210
+ for i in range(num_key_value_groups):
2211
+ indices_q_o = []
2212
+ for j in range(attention_moe.num_key_value_heads):
2213
+ k = i + j * num_key_value_groups
2214
+ indices_q_o.extend(
2215
+ list(range(k * head_dim, (k + 1) * head_dim))
2216
+ )
2217
+
2218
+ print(i, "indices_q_o", indices_q_o)
2219
+
2220
+ attention_moe.q_proj[i].weight.data = attention.q_proj.weight.data[indices_q_o].clone()
2221
+ attention_moe.o_proj[i].weight.data = attention.o_proj.weight.data[:, indices_q_o].clone()
2222
+
2223
+ return attention_moe
2224
+
2225
+
2226
+
2227
+
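The index arithmetic in `from_vanilla_attention` above decides which rows of the dense `q_proj` (and which columns of `o_proj`) each attention expert inherits. A minimal standalone sketch of that partition follows; `expert_head_indices` is a hypothetical helper name, and the small sizes are chosen only for readability. It assumes the usual layout in which the query heads that share a key/value head are stored contiguously.

```python
def expert_head_indices(num_key_value_groups, num_key_value_heads, head_dim):
    """Return, for each expert i, the projection indices it takes over.

    Expert i collects query head k = i + j * num_key_value_groups for every
    key/value head j, i.e. the i-th query head of each key/value group.
    """
    per_expert = []
    for i in range(num_key_value_groups):
        indices = []
        for j in range(num_key_value_heads):
            k = i + j * num_key_value_groups
            indices.extend(range(k * head_dim, (k + 1) * head_dim))
        per_expert.append(indices)
    return per_expert


# Toy configuration: 2 query heads per KV head, 3 KV heads, head_dim of 2.
parts = expert_head_indices(num_key_value_groups=2, num_key_value_heads=3, head_dim=2)
for expert_id, indices in enumerate(parts):
    print(expert_id, indices)
```

Under that layout assumption, every expert ends up with exactly one query head per key/value head, so the key and value projections can stay shared across experts, and the experts together cover each index of the dense projection exactly once.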
2228
  class MixtralBLockSparseTop2MLP(nn.Module):
2229
  def __init__(self, config: MixtralConfig, ffn_dim, add_rescale_bias=False): # 🔍
2230
  super().__init__()
 
2255
  # 🔍
2256
  MISTRAL_ATTENTION_MOE_CLASSES = {
2257
  "eager": MixtralAttentionMoE,
2258
+ "flash_attention_2": MixtralFlashAttention2MoE,
2259
  }
2260
 
2261
 
 
2534
  )
2535
  self.use_attn_moe = config.use_attn_moe
2536
 
2537
+ if self.use_attn_moe:
2538
+ attn_class = MISTRAL_ATTENTION_MOE_CLASSES[config._attn_implementation]
2539
+ else:
2540
+ attn_class = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]
2541
+ self.self_attn = attn_class(config, layer_idx)
2542
+
2543
+
2544
  if self.is_moe:
 
 
 
 
 
 
2545
  self.block_sparse_moe = MixtralSparseMoeBlock(config)
2546
  self.mlp_residual = (
2547
  MixtralBLockSparseTop2MLP(config, config.intermediate_size_residual)
 
2550
  )
2551
 
2552
  else:
 
 
2553
  self.block_sparse_moe = MixtralBLockSparseTop2MLP(
2554
  config, config.intermediate_size * config.num_local_experts
2555
  )
 
2601
  hidden_states = self.input_layernorm(hidden_states)
2602
 
2603
  # 🔍 Self Attention
2604
+ if self.use_attn_moe:
2605
  (
2606
  hidden_states,
2607
  self_attn_weights,
 
2630
 
2631
  # Fully Connected
2632
  residual = hidden_states
2633
+ hidden_states_input = self.post_attention_layernorm(hidden_states)
2634
 
2635
  # 🔍
2636
  if self.is_moe:
2637
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states_input)
2638
  else:
2639
+ hidden_states = self.block_sparse_moe(hidden_states_input)
2640
  router_logits = None
2641
 
2642
  if self.mlp_residual is not None:
2643
+ hidden_states += self.mlp_residual(hidden_states_input)
2644
+
2645
  hidden_states = residual + hidden_states
2646
 
2647
  outputs = (hidden_states,)
 
3058
  if len(valid_attn_router_logits) > 0: # exist logits that is not None
3059
  attn_aux_loss = load_balancing_loss_func(
3060
  valid_attn_router_logits,
3061
+ self.config.attn_experts,
3062
  self.config.top_k_attn,
3063
  use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨
3064
  )
 
3467
  if past is None:
3468
  if self.config.use_attn_moe: # 🔍
3469
  model_kwargs["past_key_values"] = MoECache(
3470
+ # self.config.num_key_value_heads
3471
+ self.config.attn_experts
3472
  )
3473
  else: # 🔍
3474
  model_kwargs["past_key_values"] = DynamicCache()
trainer_state.json CHANGED
@@ -1,1278 +1,2398 @@
1
  {
2
  "best_metric": null,
3
  "best_model_checkpoint": null,
4
- "epoch": 1.8575851393188856,
5
  "eval_steps": 500,
6
- "global_step": 1800,
7
  "is_hyper_param_search": false,
8
  "is_local_process_zero": true,
9
  "is_world_process_zero": true,
10
  "log_history": [
11
  {
12
- "epoch": 0.010319917440660475,
13
- "grad_norm": 2.8247148990631104,
14
- "learning_rate": 2.2727272727272728e-06,
15
- "loss": 0.8288,
16
- "step": 10
17
  },
18
  {
19
- "epoch": 0.02063983488132095,
20
- "grad_norm": 1.1619349718093872,
21
- "learning_rate": 4.5454545454545455e-06,
22
- "loss": 0.7902,
23
- "step": 20
24
  },
25
  {
26
- "epoch": 0.030959752321981424,
27
- "grad_norm": 0.7691543698310852,
28
- "learning_rate": 6.818181818181818e-06,
29
- "loss": 0.7388,
30
- "step": 30
31
  },
32
  {
33
- "epoch": 0.0412796697626419,
34
- "grad_norm": 0.687256395816803,
35
- "learning_rate": 9.090909090909091e-06,
36
- "loss": 0.7177,
37
- "step": 40
38
  },
39
  {
40
- "epoch": 0.05159958720330237,
41
- "grad_norm": 0.6163066029548645,
42
- "learning_rate": 1.1363636363636366e-05,
43
- "loss": 0.701,
44
- "step": 50
45
  },
46
  {
47
- "epoch": 0.06191950464396285,
48
- "grad_norm": 0.6468276381492615,
49
- "learning_rate": 1.3636363636363637e-05,
50
- "loss": 0.6853,
51
- "step": 60
52
  },
53
  {
54
- "epoch": 0.07223942208462332,
55
- "grad_norm": 0.9129849672317505,
56
- "learning_rate": 1.590909090909091e-05,
57
- "loss": 0.6749,
58
- "step": 70
59
  },
60
  {
61
- "epoch": 0.0825593395252838,
62
- "grad_norm": 0.9610547423362732,
63
- "learning_rate": 1.8181818181818182e-05,
64
- "loss": 0.664,
65
- "step": 80
66
  },
67
  {
68
- "epoch": 0.09287925696594428,
69
- "grad_norm": 0.9436660408973694,
70
- "learning_rate": 1.9999975160696756e-05,
71
- "loss": 0.6637,
72
- "step": 90
73
  },
74
  {
75
- "epoch": 0.10319917440660474,
76
- "grad_norm": 0.828860878944397,
77
- "learning_rate": 1.999910579803988e-05,
78
- "loss": 0.6578,
79
- "step": 100
80
  },
81
  {
82
- "epoch": 0.11351909184726522,
83
- "grad_norm": 0.8615094423294067,
84
- "learning_rate": 1.9996994593616145e-05,
85
- "loss": 0.6473,
86
- "step": 110
87
  },
88
  {
89
- "epoch": 0.1238390092879257,
90
- "grad_norm": 0.8153389096260071,
91
- "learning_rate": 1.9993641809627166e-05,
92
- "loss": 0.6402,
93
- "step": 120
94
  },
95
  {
96
- "epoch": 0.13415892672858618,
97
- "grad_norm": 0.8015602827072144,
98
- "learning_rate": 1.9989047862472904e-05,
99
- "loss": 0.6378,
100
- "step": 130
101
  },
102
  {
103
- "epoch": 0.14447884416924664,
104
- "grad_norm": 0.7367793321609497,
105
- "learning_rate": 1.9983213322699926e-05,
106
- "loss": 0.6346,
107
- "step": 140
108
  },
109
  {
110
- "epoch": 0.15479876160990713,
111
- "grad_norm": 0.913837194442749,
112
- "learning_rate": 1.997613891493054e-05,
113
- "loss": 0.6322,
114
- "step": 150
115
  },
116
  {
117
- "epoch": 0.1651186790505676,
118
- "grad_norm": 0.7812018990516663,
119
- "learning_rate": 1.996782551777282e-05,
120
- "loss": 0.6206,
121
- "step": 160
122
  },
123
  {
124
- "epoch": 0.17543859649122806,
125
- "grad_norm": 0.7282320857048035,
126
- "learning_rate": 1.995827416371147e-05,
127
- "loss": 0.6127,
128
- "step": 170
129
  },
130
  {
131
- "epoch": 0.18575851393188855,
132
- "grad_norm": 0.7379522919654846,
133
- "learning_rate": 1.9947486038979606e-05,
134
- "loss": 0.6098,
135
- "step": 180
136
  },
137
  {
138
- "epoch": 0.19607843137254902,
139
- "grad_norm": 0.7425850033760071,
140
- "learning_rate": 1.993546248341142e-05,
141
- "loss": 0.6079,
142
- "step": 190
143
  },
144
  {
145
- "epoch": 0.20639834881320948,
146
- "grad_norm": 0.6972795724868774,
147
- "learning_rate": 1.9922204990275788e-05,
148
- "loss": 0.6006,
149
- "step": 200
150
  },
151
  {
152
- "epoch": 0.21671826625386997,
153
- "grad_norm": 0.7257381677627563,
154
- "learning_rate": 1.9907715206090817e-05,
155
- "loss": 0.6042,
156
- "step": 210
157
  },
158
  {
159
- "epoch": 0.22703818369453044,
160
- "grad_norm": 0.6420859098434448,
161
- "learning_rate": 1.989199493041935e-05,
162
- "loss": 0.593,
163
- "step": 220
164
  },
165
  {
166
- "epoch": 0.23735810113519093,
167
- "grad_norm": 0.7002107501029968,
168
- "learning_rate": 1.9875046115645443e-05,
169
- "loss": 0.5931,
170
- "step": 230
171
  },
172
  {
173
- "epoch": 0.2476780185758514,
174
- "grad_norm": 0.7185678482055664,
175
- "learning_rate": 1.9856870866731946e-05,
176
- "loss": 0.5926,
177
- "step": 240
178
  },
179
  {
180
- "epoch": 0.2579979360165119,
181
- "grad_norm": 0.6459465026855469,
182
- "learning_rate": 1.983747144095902e-05,
183
- "loss": 0.5878,
184
- "step": 250
185
  },
186
  {
187
- "epoch": 0.26831785345717235,
188
- "grad_norm": 0.6379982233047485,
189
- "learning_rate": 1.9816850247643834e-05,
190
- "loss": 0.5796,
191
- "step": 260
192
  },
193
  {
194
- "epoch": 0.2786377708978328,
195
- "grad_norm": 0.7094199061393738,
196
- "learning_rate": 1.97950098478413e-05,
197
- "loss": 0.5771,
198
- "step": 270
199
  },
200
  {
201
- "epoch": 0.2889576883384933,
202
- "grad_norm": 0.6646308302879333,
203
- "learning_rate": 1.9771952954026038e-05,
204
- "loss": 0.5767,
205
- "step": 280
206
  },
207
  {
208
- "epoch": 0.29927760577915374,
209
- "grad_norm": 0.6392974257469177,
210
- "learning_rate": 1.9747682429755493e-05,
211
- "loss": 0.5737,
212
- "step": 290
213
  },
214
  {
215
- "epoch": 0.30959752321981426,
216
- "grad_norm": 0.5905966758728027,
217
- "learning_rate": 1.972220128931427e-05,
218
- "loss": 0.576,
219
- "step": 300
220
  },
221
  {
222
- "epoch": 0.31991744066047473,
223
- "grad_norm": 0.8001016974449158,
224
- "learning_rate": 1.9695512697339797e-05,
225
- "loss": 0.5698,
226
- "step": 310
227
  },
228
  {
229
- "epoch": 0.3302373581011352,
230
- "grad_norm": 0.5997283458709717,
231
- "learning_rate": 1.966761996842929e-05,
232
- "loss": 0.5703,
233
- "step": 320
234
  },
235
  {
236
- "epoch": 0.34055727554179566,
237
- "grad_norm": 0.6440294981002808,
238
- "learning_rate": 1.9638526566728088e-05,
239
- "loss": 0.5584,
240
- "step": 330
241
  },
242
  {
243
- "epoch": 0.3508771929824561,
244
- "grad_norm": 0.7667876482009888,
245
- "learning_rate": 1.960823610549943e-05,
246
- "loss": 0.5585,
247
- "step": 340
248
  },
249
  {
250
- "epoch": 0.36119711042311664,
251
- "grad_norm": 0.6358545422554016,
252
- "learning_rate": 1.9576752346675692e-05,
253
- "loss": 0.5578,
254
- "step": 350
255
  },
256
  {
257
- "epoch": 0.3715170278637771,
258
- "grad_norm": 0.6375100612640381,
259
- "learning_rate": 1.954407920039119e-05,
260
- "loss": 0.5621,
261
- "step": 360
262
  },
263
  {
264
- "epoch": 0.38183694530443757,
265
- "grad_norm": 0.7324113845825195,
266
- "learning_rate": 1.951022072449655e-05,
267
- "loss": 0.5527,
268
- "step": 370
269
  },
270
  {
271
- "epoch": 0.39215686274509803,
272
- "grad_norm": 0.658400297164917,
273
- "learning_rate": 1.9475181124054742e-05,
274
- "loss": 0.5538,
275
- "step": 380
276
  },
277
  {
278
- "epoch": 0.4024767801857585,
279
- "grad_norm": 0.7300146222114563,
280
- "learning_rate": 1.9438964750818833e-05,
281
- "loss": 0.5494,
282
- "step": 390
283
  },
284
  {
285
- "epoch": 0.41279669762641896,
286
- "grad_norm": 0.7315788865089417,
287
- "learning_rate": 1.940157610269152e-05,
288
- "loss": 0.5493,
289
- "step": 400
290
  },
291
  {
292
- "epoch": 0.4231166150670795,
293
- "grad_norm": 0.6689688563346863,
294
- "learning_rate": 1.9363019823166506e-05,
295
- "loss": 0.5509,
296
- "step": 410
297
  },
298
  {
299
- "epoch": 0.43343653250773995,
300
- "grad_norm": 0.6882718205451965,
301
- "learning_rate": 1.9323300700751816e-05,
302
- "loss": 0.5473,
303
- "step": 420
304
  },
305
  {
306
- "epoch": 0.4437564499484004,
307
- "grad_norm": 0.6466957330703735,
308
- "learning_rate": 1.9282423668375064e-05,
309
- "loss": 0.5435,
310
- "step": 430
311
  },
312
  {
313
- "epoch": 0.4540763673890609,
314
- "grad_norm": 0.6492331624031067,
315
- "learning_rate": 1.9240393802770824e-05,
316
- "loss": 0.5449,
317
- "step": 440
318
  },
319
  {
320
- "epoch": 0.46439628482972134,
321
- "grad_norm": 0.5815872550010681,
322
- "learning_rate": 1.9197216323850122e-05,
323
- "loss": 0.5398,
324
- "step": 450
325
  },
326
  {
327
- "epoch": 0.47471620227038186,
328
- "grad_norm": 0.6003971099853516,
329
- "learning_rate": 1.9152896594052134e-05,
330
- "loss": 0.533,
331
- "step": 460
332
  },
333
  {
334
- "epoch": 0.4850361197110423,
335
- "grad_norm": 0.5987655520439148,
336
- "learning_rate": 1.910744011767821e-05,
337
- "loss": 0.5309,
338
- "step": 470
339
  },
340
  {
341
- "epoch": 0.4953560371517028,
342
- "grad_norm": 0.6432524919509888,
343
- "learning_rate": 1.9060852540208277e-05,
344
- "loss": 0.5344,
345
- "step": 480
346
  },
347
  {
348
- "epoch": 0.5056759545923633,
349
- "grad_norm": 0.5650415420532227,
350
- "learning_rate": 1.9013139647599656e-05,
351
- "loss": 0.5333,
352
- "step": 490
353
  },
354
  {
355
- "epoch": 0.5159958720330238,
356
- "grad_norm": 0.6225659847259521,
357
- "learning_rate": 1.8964307365568513e-05,
358
- "loss": 0.5231,
359
- "step": 500
360
  },
361
  {
362
- "epoch": 0.5263157894736842,
363
- "grad_norm": 0.6020525097846985,
364
- "learning_rate": 1.89143617588539e-05,
365
- "loss": 0.5241,
366
- "step": 510
367
  },
368
  {
369
- "epoch": 0.5366357069143447,
370
- "grad_norm": 0.5726006031036377,
371
- "learning_rate": 1.886330903046454e-05,
372
- "loss": 0.5278,
373
- "step": 520
374
  },
375
  {
376
- "epoch": 0.5469556243550051,
377
- "grad_norm": 0.5783742666244507,
378
- "learning_rate": 1.8811155520908445e-05,
379
- "loss": 0.5253,
380
- "step": 530
381
  },
382
  {
383
- "epoch": 0.5572755417956656,
384
- "grad_norm": 0.5478541254997253,
385
- "learning_rate": 1.8757907707405456e-05,
386
- "loss": 0.5166,
387
- "step": 540
388
  },
389
  {
390
- "epoch": 0.5675954592363261,
391
- "grad_norm": 0.5668419003486633,
392
- "learning_rate": 1.8703572203082795e-05,
393
- "loss": 0.5206,
394
- "step": 550
395
  },
396
  {
397
- "epoch": 0.5779153766769866,
398
- "grad_norm": 0.5729948282241821,
399
- "learning_rate": 1.8648155756153768e-05,
400
- "loss": 0.516,
401
- "step": 560
402
  },
403
  {
404
- "epoch": 0.5882352941176471,
405
- "grad_norm": 0.651300311088562,
406
- "learning_rate": 1.859166524907963e-05,
407
- "loss": 0.5183,
408
- "step": 570
409
  },
410
  {
411
- "epoch": 0.5985552115583075,
412
- "grad_norm": 0.6236013174057007,
413
- "learning_rate": 1.8534107697714864e-05,
414
- "loss": 0.5242,
415
- "step": 580
416
  },
417
  {
418
- "epoch": 0.608875128998968,
419
- "grad_norm": 0.5427743196487427,
420
- "learning_rate": 1.84754902504358e-05,
421
- "loss": 0.5291,
422
- "step": 590
423
  },
424
  {
425
- "epoch": 0.6191950464396285,
426
- "grad_norm": 0.5849993824958801,
427
- "learning_rate": 1.8415820187252847e-05,
428
- "loss": 0.5213,
429
- "step": 600
430
  },
431
  {
432
- "epoch": 0.6295149638802889,
433
- "grad_norm": 0.6405364274978638,
434
- "learning_rate": 1.8355104918906353e-05,
435
- "loss": 0.5187,
436
- "step": 610
437
  },
438
  {
439
- "epoch": 0.6398348813209495,
440
- "grad_norm": 0.5616128444671631,
441
- "learning_rate": 1.8293351985946194e-05,
442
- "loss": 0.5108,
443
- "step": 620
444
  },
445
  {
446
- "epoch": 0.6501547987616099,
447
- "grad_norm": 0.5770090222358704,
448
- "learning_rate": 1.823056905779532e-05,
449
- "loss": 0.5172,
450
- "step": 630
451
  },
452
  {
453
- "epoch": 0.6604747162022704,
454
- "grad_norm": 0.5251275300979614,
455
- "learning_rate": 1.816676393179721e-05,
456
- "loss": 0.5116,
457
- "step": 640
458
  },
459
  {
460
- "epoch": 0.6707946336429309,
461
- "grad_norm": 0.5879736542701721,
462
- "learning_rate": 1.8101944532247495e-05,
463
- "loss": 0.5157,
464
- "step": 650
465
  },
466
  {
467
- "epoch": 0.6811145510835913,
468
- "grad_norm": 0.5661890506744385,
469
- "learning_rate": 1.80361189094098e-05,
470
- "loss": 0.5088,
471
- "step": 660
472
  },
473
  {
474
- "epoch": 0.6914344685242518,
475
- "grad_norm": 0.5618740916252136,
476
- "learning_rate": 1.796929523851593e-05,
477
- "loss": 0.5111,
478
- "step": 670
479
  },
480
  {
481
- "epoch": 0.7017543859649122,
482
- "grad_norm": 0.5378845930099487,
483
- "learning_rate": 1.790148181875055e-05,
484
- "loss": 0.5118,
485
- "step": 680
486
  },
487
  {
488
- "epoch": 0.7120743034055728,
489
- "grad_norm": 0.5547090172767639,
490
- "learning_rate": 1.783268707222048e-05,
491
- "loss": 0.5088,
492
- "step": 690
493
  },
494
  {
495
- "epoch": 0.7223942208462333,
496
- "grad_norm": 0.5933310389518738,
497
- "learning_rate": 1.776291954290867e-05,
498
- "loss": 0.5063,
499
- "step": 700
500
  },
501
  {
502
- "epoch": 0.7327141382868937,
503
- "grad_norm": 0.5393312573432922,
504
- "learning_rate": 1.769218789561312e-05,
505
- "loss": 0.5014,
506
- "step": 710
507
  },
508
  {
509
- "epoch": 0.7430340557275542,
510
- "grad_norm": 0.5515422821044922,
511
- "learning_rate": 1.7620500914870734e-05,
512
- "loss": 0.5116,
513
- "step": 720
514
  },
515
  {
516
- "epoch": 0.7533539731682146,
517
- "grad_norm": 0.5601432919502258,
518
- "learning_rate": 1.7547867503866315e-05,
519
- "loss": 0.5024,
520
- "step": 730
521
  },
522
  {
523
- "epoch": 0.7636738906088751,
524
- "grad_norm": 0.5876237154006958,
525
- "learning_rate": 1.7474296683326844e-05,
526
- "loss": 0.5098,
527
- "step": 740
528
  },
529
  {
530
- "epoch": 0.7739938080495357,
531
- "grad_norm": 0.518947184085846,
532
- "learning_rate": 1.739979759040114e-05,
533
- "loss": 0.5017,
534
- "step": 750
535
  },
536
  {
537
- "epoch": 0.7843137254901961,
538
- "grad_norm": 0.5550107955932617,
539
- "learning_rate": 1.7324379477525086e-05,
540
- "loss": 0.5044,
541
- "step": 760
542
  },
543
  {
544
- "epoch": 0.7946336429308566,
545
- "grad_norm": 0.5430490374565125,
546
- "learning_rate": 1.724805171127249e-05,
547
- "loss": 0.5029,
548
- "step": 770
549
  },
550
  {
551
- "epoch": 0.804953560371517,
552
- "grad_norm": 0.5498166680335999,
553
- "learning_rate": 1.7170823771191824e-05,
554
- "loss": 0.499,
555
- "step": 780
556
  },
557
  {
558
- "epoch": 0.8152734778121775,
559
- "grad_norm": 0.5843333601951599,
560
- "learning_rate": 1.709270524862891e-05,
561
- "loss": 0.4968,
562
- "step": 790
563
  },
564
  {
565
- "epoch": 0.8255933952528379,
566
- "grad_norm": 0.5710884928703308,
567
- "learning_rate": 1.7013705845535704e-05,
568
- "loss": 0.5024,
569
- "step": 800
570
  },
571
  {
572
- "epoch": 0.8359133126934984,
573
- "grad_norm": 0.5185025930404663,
574
- "learning_rate": 1.6933835373265373e-05,
575
- "loss": 0.503,
576
- "step": 810
577
  },
578
  {
579
- "epoch": 0.846233230134159,
580
- "grad_norm": 0.5252718329429626,
581
- "learning_rate": 1.685310375135376e-05,
582
- "loss": 0.5028,
583
- "step": 820
584
  },
585
  {
586
- "epoch": 0.8565531475748194,
587
- "grad_norm": 0.5351059436798096,
588
- "learning_rate": 1.6771521006287442e-05,
589
- "loss": 0.4927,
590
- "step": 830
591
  },
592
  {
593
- "epoch": 0.8668730650154799,
594
- "grad_norm": 0.5176792740821838,
595
- "learning_rate": 1.6689097270258463e-05,
596
- "loss": 0.5012,
597
- "step": 840
598
  },
599
  {
600
- "epoch": 0.8771929824561403,
601
- "grad_norm": 0.5016619563102722,
602
- "learning_rate": 1.6605842779905984e-05,
603
- "loss": 0.4941,
604
- "step": 850
605
  },
606
  {
607
- "epoch": 0.8875128998968008,
608
- "grad_norm": 0.536718487739563,
609
- "learning_rate": 1.6521767875044935e-05,
610
- "loss": 0.488,
611
- "step": 860
612
  },
613
  {
614
- "epoch": 0.8978328173374613,
615
- "grad_norm": 0.49594587087631226,
616
- "learning_rate": 1.643688299738186e-05,
617
- "loss": 0.4901,
618
- "step": 870
619
  },
620
  {
621
- "epoch": 0.9081527347781218,
622
- "grad_norm": 0.5281170606613159,
623
- "learning_rate": 1.635119868921809e-05,
624
- "loss": 0.4979,
625
- "step": 880
626
  },
627
  {
628
- "epoch": 0.9184726522187823,
629
- "grad_norm": 0.5000081658363342,
630
- "learning_rate": 1.6264725592140468e-05,
631
- "loss": 0.4935,
632
- "step": 890
633
  },
634
  {
635
- "epoch": 0.9287925696594427,
636
- "grad_norm": 0.5359088182449341,
637
- "learning_rate": 1.6177474445699695e-05,
638
- "loss": 0.4854,
639
- "step": 900
640
  },
641
  {
642
- "epoch": 0.9391124871001032,
643
- "grad_norm": 0.5657668709754944,
644
- "learning_rate": 1.6089456086076527e-05,
645
- "loss": 0.4877,
646
- "step": 910
647
  },
648
  {
649
- "epoch": 0.9494324045407637,
650
- "grad_norm": 0.507234513759613,
651
- "learning_rate": 1.6000681444735976e-05,
652
- "loss": 0.4903,
653
- "step": 920
654
  },
655
  {
656
- "epoch": 0.9597523219814241,
657
- "grad_norm": 0.5578757524490356,
658
- "learning_rate": 1.5911161547069688e-05,
659
- "loss": 0.4884,
660
- "step": 930
661
  },
662
  {
663
- "epoch": 0.9700722394220846,
664
- "grad_norm": 0.5635477304458618,
665
- "learning_rate": 1.582090751102662e-05,
666
- "loss": 0.4973,
667
- "step": 940
668
  },
669
  {
670
- "epoch": 0.9803921568627451,
671
- "grad_norm": 0.5168154835700989,
672
- "learning_rate": 1.5729930545732247e-05,
673
- "loss": 0.4818,
674
- "step": 950
675
  },
676
  {
677
- "epoch": 0.9907120743034056,
678
- "grad_norm": 0.5357134342193604,
679
- "learning_rate": 1.5638241950096458e-05,
680
- "loss": 0.4863,
681
- "step": 960
682
  },
683
  {
684
- "epoch": 1.001031991744066,
685
- "grad_norm": 1.1038967370986938,
686
- "learning_rate": 1.554585311141027e-05,
687
- "loss": 0.4791,
688
- "step": 970
689
  },
690
  {
691
- "epoch": 1.0113519091847265,
692
- "grad_norm": 0.6728698015213013,
693
- "learning_rate": 1.5452775503931566e-05,
694
- "loss": 0.4229,
695
- "step": 980
696
  },
697
  {
698
- "epoch": 1.021671826625387,
699
- "grad_norm": 0.5582284331321716,
700
- "learning_rate": 1.5359020687460096e-05,
701
- "loss": 0.4193,
702
- "step": 990
703
  },
704
  {
705
- "epoch": 1.0319917440660475,
706
- "grad_norm": 0.5344264507293701,
707
- "learning_rate": 1.5264600305901744e-05,
708
- "loss": 0.4241,
709
- "step": 1000
710
  },
711
  {
712
- "epoch": 1.0423116615067078,
713
- "grad_norm": 0.5118332505226135,
714
- "learning_rate": 1.5169526085822451e-05,
715
- "loss": 0.4178,
716
- "step": 1010
717
  },
718
  {
719
- "epoch": 1.0526315789473684,
720
- "grad_norm": 0.54106605052948,
721
- "learning_rate": 1.5073809834991816e-05,
722
- "loss": 0.4167,
723
- "step": 1020
724
  },
725
  {
726
- "epoch": 1.0629514963880289,
727
- "grad_norm": 0.591042697429657,
728
- "learning_rate": 1.4977463440916621e-05,
729
- "loss": 0.4154,
730
- "step": 1030
731
  },
732
  {
733
- "epoch": 1.0732714138286894,
734
- "grad_norm": 0.5546119809150696,
735
- "learning_rate": 1.4880498869364482e-05,
736
- "loss": 0.4211,
737
- "step": 1040
738
  },
739
  {
740
- "epoch": 1.08359133126935,
741
- "grad_norm": 0.5102314352989197,
742
- "learning_rate": 1.4782928162877722e-05,
743
- "loss": 0.4187,
744
- "step": 1050
745
  },
746
  {
747
- "epoch": 1.0939112487100102,
748
- "grad_norm": 0.5234063863754272,
749
- "learning_rate": 1.468476343927778e-05,
750
- "loss": 0.4177,
751
- "step": 1060
752
  },
753
  {
754
- "epoch": 1.1042311661506707,
755
- "grad_norm": 0.5099871158599854,
756
- "learning_rate": 1.4586016890160208e-05,
757
- "loss": 0.4213,
758
- "step": 1070
759
  },
760
  {
761
- "epoch": 1.1145510835913313,
762
- "grad_norm": 0.5453868508338928,
763
- "learning_rate": 1.4486700779380547e-05,
764
- "loss": 0.4192,
765
- "step": 1080
766
  },
767
  {
768
- "epoch": 1.1248710010319918,
769
- "grad_norm": 0.5475857257843018,
770
- "learning_rate": 1.4386827441531202e-05,
771
- "loss": 0.4178,
772
- "step": 1090
773
  },
774
  {
775
- "epoch": 1.1351909184726523,
776
- "grad_norm": 0.5636183619499207,
777
- "learning_rate": 1.4286409280409558e-05,
778
- "loss": 0.4167,
779
- "step": 1100
780
  },
781
  {
782
- "epoch": 1.1455108359133126,
783
- "grad_norm": 0.5477967262268066,
784
- "learning_rate": 1.4185458767477487e-05,
785
- "loss": 0.4184,
786
- "step": 1110
787
  },
788
  {
789
- "epoch": 1.1558307533539731,
790
- "grad_norm": 0.5478163361549377,
791
- "learning_rate": 1.4083988440312429e-05,
792
- "loss": 0.419,
793
- "step": 1120
794
  },
795
  {
796
- "epoch": 1.1661506707946336,
797
- "grad_norm": 0.5689426064491272,
798
- "learning_rate": 1.3982010901050305e-05,
799
- "loss": 0.4239,
800
- "step": 1130
801
  },
802
  {
803
- "epoch": 1.1764705882352942,
804
- "grad_norm": 0.5106656551361084,
805
- "learning_rate": 1.3879538814820395e-05,
806
- "loss": 0.4135,
807
- "step": 1140
808
  },
809
  {
810
- "epoch": 1.1867905056759547,
811
- "grad_norm": 0.5251624584197998,
812
- "learning_rate": 1.3776584908172364e-05,
813
- "loss": 0.4202,
814
- "step": 1150
815
  },
816
  {
817
- "epoch": 1.197110423116615,
818
- "grad_norm": 0.5535441040992737,
819
- "learning_rate": 1.3673161967495708e-05,
820
- "loss": 0.4181,
821
- "step": 1160
822
  },
823
  {
824
- "epoch": 1.2074303405572755,
825
- "grad_norm": 0.5619220733642578,
826
- "learning_rate": 1.3569282837431737e-05,
827
- "loss": 0.4202,
828
- "step": 1170
829
  },
830
  {
831
- "epoch": 1.217750257997936,
832
- "grad_norm": 0.5495029091835022,
833
- "learning_rate": 1.3464960419278332e-05,
834
- "loss": 0.4135,
835
- "step": 1180
836
  },
837
  {
838
- "epoch": 1.2280701754385965,
839
- "grad_norm": 0.5409591197967529,
840
- "learning_rate": 1.336020766938766e-05,
841
- "loss": 0.4099,
842
- "step": 1190
843
  },
844
  {
845
- "epoch": 1.238390092879257,
846
- "grad_norm": 0.5582126379013062,
847
- "learning_rate": 1.3255037597557057e-05,
848
- "loss": 0.4168,
849
- "step": 1200
850
  },
851
  {
852
- "epoch": 1.2487100103199174,
853
- "grad_norm": 0.5315924882888794,
854
- "learning_rate": 1.3149463265413282e-05,
855
- "loss": 0.4163,
856
- "step": 1210
857
  },
858
  {
859
- "epoch": 1.2590299277605779,
860
- "grad_norm": 0.5000606775283813,
861
- "learning_rate": 1.3043497784790315e-05,
862
- "loss": 0.4155,
863
- "step": 1220
864
  },
865
  {
866
- "epoch": 1.2693498452012384,
867
- "grad_norm": 0.5188019275665283,
868
- "learning_rate": 1.2937154316100927e-05,
869
- "loss": 0.4155,
870
- "step": 1230
871
  },
872
  {
873
- "epoch": 1.279669762641899,
874
- "grad_norm": 0.5054394006729126,
875
- "learning_rate": 1.283044606670223e-05,
876
- "loss": 0.4079,
877
- "step": 1240
878
  },
879
  {
880
- "epoch": 1.2899896800825594,
881
- "grad_norm": 0.5096462368965149,
882
- "learning_rate": 1.2723386289255374e-05,
883
- "loss": 0.4149,
884
- "step": 1250
885
  },
886
  {
887
- "epoch": 1.3003095975232197,
888
- "grad_norm": 0.5191652178764343,
889
- "learning_rate": 1.2615988280079645e-05,
890
- "loss": 0.4103,
891
- "step": 1260
892
  },
893
  {
894
- "epoch": 1.3106295149638802,
895
- "grad_norm": 0.4963880777359009,
896
- "learning_rate": 1.2508265377501102e-05,
897
- "loss": 0.4117,
898
- "step": 1270
899
  },
900
  {
901
- "epoch": 1.3209494324045408,
902
- "grad_norm": 0.5644184947013855,
903
- "learning_rate": 1.240023096019603e-05,
904
- "loss": 0.4139,
905
- "step": 1280
906
  },
907
  {
908
- "epoch": 1.3312693498452013,
909
- "grad_norm": 0.521536111831665,
910
- "learning_rate": 1.2291898445529384e-05,
911
- "loss": 0.4107,
912
- "step": 1290
913
  },
914
  {
915
- "epoch": 1.3415892672858618,
916
- "grad_norm": 0.5256720781326294,
917
- "learning_rate": 1.2183281287888398e-05,
918
- "loss": 0.4104,
919
- "step": 1300
920
  },
921
  {
922
- "epoch": 1.351909184726522,
923
- "grad_norm": 0.531589686870575,
924
- "learning_rate": 1.2074392977011629e-05,
925
- "loss": 0.4111,
926
- "step": 1310
927
  },
928
  {
929
- "epoch": 1.3622291021671826,
930
- "grad_norm": 0.534598171710968,
931
- "learning_rate": 1.1965247036313573e-05,
932
- "loss": 0.416,
933
- "step": 1320
934
  },
935
  {
936
- "epoch": 1.3725490196078431,
937
- "grad_norm": 0.5281124711036682,
938
- "learning_rate": 1.185585702120515e-05,
939
- "loss": 0.4041,
940
- "step": 1330
941
  },
942
  {
943
- "epoch": 1.3828689370485037,
944
- "grad_norm": 0.5332800149917603,
945
- "learning_rate": 1.1746236517410155e-05,
946
- "loss": 0.4076,
947
- "step": 1340
948
  },
949
  {
950
- "epoch": 1.3931888544891642,
951
- "grad_norm": 0.4961317181587219,
952
- "learning_rate": 1.1636399139277998e-05,
953
- "loss": 0.4067,
954
- "step": 1350
955
  },
956
  {
957
- "epoch": 1.4035087719298245,
958
- "grad_norm": 0.5210182070732117,
959
- "learning_rate": 1.1526358528092861e-05,
960
- "loss": 0.4071,
961
- "step": 1360
962
  },
963
  {
964
- "epoch": 1.413828689370485,
965
- "grad_norm": 0.518181324005127,
966
- "learning_rate": 1.1416128350379503e-05,
967
- "loss": 0.4118,
968
- "step": 1370
969
  },
970
  {
971
- "epoch": 1.4241486068111455,
972
- "grad_norm": 0.5396980047225952,
973
- "learning_rate": 1.1305722296205968e-05,
974
- "loss": 0.4073,
975
- "step": 1380
976
  },
977
  {
978
- "epoch": 1.434468524251806,
979
- "grad_norm": 0.5073665976524353,
980
- "learning_rate": 1.1195154077483313e-05,
981
- "loss": 0.4083,
982
- "step": 1390
983
  },
984
  {
985
- "epoch": 1.4447884416924666,
986
- "grad_norm": 0.5103346705436707,
987
- "learning_rate": 1.1084437426262666e-05,
988
- "loss": 0.4094,
989
- "step": 1400
990
  },
991
  {
992
- "epoch": 1.4551083591331269,
993
- "grad_norm": 0.5441737174987793,
994
- "learning_rate": 1.097358609302978e-05,
995
- "loss": 0.4124,
996
- "step": 1410
997
  },
998
  {
999
- "epoch": 1.4654282765737874,
1000
- "grad_norm": 0.49091413617134094,
1001
- "learning_rate": 1.0862613844997272e-05,
1002
- "loss": 0.4059,
1003
- "step": 1420
1004
  },
1005
  {
1006
- "epoch": 1.475748194014448,
1007
- "grad_norm": 0.49451103806495667,
1008
- "learning_rate": 1.0751534464394809e-05,
1009
- "loss": 0.4028,
1010
- "step": 1430
1011
  },
1012
  {
1013
- "epoch": 1.4860681114551084,
1014
- "grad_norm": 0.5205165147781372,
1015
- "learning_rate": 1.0640361746757413e-05,
1016
- "loss": 0.4038,
1017
- "step": 1440
1018
  },
1019
  {
1020
- "epoch": 1.496388028895769,
1021
- "grad_norm": 0.5233325958251953,
1022
- "learning_rate": 1.0529109499212137e-05,
1023
- "loss": 0.4097,
1024
- "step": 1450
1025
  },
1026
  {
1027
- "epoch": 1.5067079463364292,
1028
- "grad_norm": 0.5237818956375122,
1029
- "learning_rate": 1.0417791538763269e-05,
1030
- "loss": 0.4059,
1031
- "step": 1460
1032
  },
1033
  {
1034
- "epoch": 1.5170278637770898,
1035
- "grad_norm": 0.5263275504112244,
1036
- "learning_rate": 1.0306421690576318e-05,
1037
- "loss": 0.4074,
1038
- "step": 1470
1039
  },
1040
  {
1041
- "epoch": 1.5273477812177503,
1042
- "grad_norm": 0.5042173862457275,
1043
- "learning_rate": 1.0195013786261017e-05,
1044
- "loss": 0.4061,
1045
- "step": 1480
1046
  },
1047
  {
1048
- "epoch": 1.5376676986584106,
1049
- "grad_norm": 0.48727792501449585,
1050
- "learning_rate": 1.0083581662153488e-05,
1051
- "loss": 0.4021,
1052
- "step": 1490
1053
  },
1054
  {
1055
- "epoch": 1.5479876160990713,
1056
- "grad_norm": 0.5014871954917908,
1057
- "learning_rate": 9.972139157597836e-06,
1058
- "loss": 0.411,
1059
- "step": 1500
1060
  },
1061
  {
1062
- "epoch": 1.5583075335397316,
1063
- "grad_norm": 0.49665823578834534,
1064
- "learning_rate": 9.86070011322737e-06,
1065
- "loss": 0.4069,
1066
- "step": 1510
1067
  },
1068
  {
1069
- "epoch": 1.5686274509803921,
1070
- "grad_norm": 0.48189592361450195,
1071
- "learning_rate": 9.749278369245658e-06,
1072
- "loss": 0.4055,
1073
- "step": 1520
1074
  },
1075
  {
1076
- "epoch": 1.5789473684210527,
1077
- "grad_norm": 0.5003267526626587,
1078
- "learning_rate": 9.637887763707649e-06,
1079
- "loss": 0.4023,
1080
- "step": 1530
1081
  },
1082
  {
1083
- "epoch": 1.589267285861713,
1084
- "grad_norm": 0.4762038290500641,
1085
- "learning_rate": 9.52654213080103e-06,
1086
- "loss": 0.4063,
1087
- "step": 1540
1088
  },
1089
  {
1090
- "epoch": 1.5995872033023737,
1091
- "grad_norm": 0.48036977648735046,
1092
- "learning_rate": 9.415255299128115e-06,
1093
- "loss": 0.3991,
1094
- "step": 1550
1095
  },
1096
  {
1097
- "epoch": 1.609907120743034,
1098
- "grad_norm": 1.7054091691970825,
1099
- "learning_rate": 9.304041089988367e-06,
1100
- "loss": 0.4099,
1101
- "step": 1560
1102
  },
1103
  {
1104
- "epoch": 1.6202270381836945,
1105
- "grad_norm": 0.5128041505813599,
1106
- "learning_rate": 9.192913315661887e-06,
1107
- "loss": 0.4093,
1108
- "step": 1570
1109
  },
1110
  {
1111
- "epoch": 1.630546955624355,
1112
- "grad_norm": 0.5168408751487732,
1113
- "learning_rate": 9.081885777693969e-06,
1114
- "loss": 0.4012,
1115
- "step": 1580
1116
  },
1117
  {
1118
- "epoch": 1.6408668730650153,
1119
- "grad_norm": 0.4789281189441681,
1120
- "learning_rate": 8.97097226518103e-06,
1121
- "loss": 0.4024,
1122
- "step": 1590
1123
  },
1124
  {
1125
- "epoch": 1.651186790505676,
1126
- "grad_norm": 0.4675295650959015,
1127
- "learning_rate": 8.860186553058066e-06,
1128
- "loss": 0.3992,
1129
- "step": 1600
1130
  },
1131
  {
1132
- "epoch": 1.6615067079463364,
1133
- "grad_norm": 0.4954163730144501,
1134
- "learning_rate": 8.749542400387861e-06,
1135
- "loss": 0.3986,
1136
- "step": 1610
1137
  },
1138
  {
1139
- "epoch": 1.671826625386997,
1140
- "grad_norm": 0.4895382523536682,
1141
- "learning_rate": 8.639053548652183e-06,
1142
- "loss": 0.3949,
1143
- "step": 1620
1144
  },
1145
  {
1146
- "epoch": 1.6821465428276574,
1147
- "grad_norm": 0.49679800868034363,
1148
- "learning_rate": 8.528733720045162e-06,
1149
- "loss": 0.4042,
1150
- "step": 1630
1151
  },
1152
  {
1153
- "epoch": 1.6924664602683177,
1154
- "grad_norm": 0.470292866230011,
1155
- "learning_rate": 8.418596615769048e-06,
1156
- "loss": 0.3977,
1157
- "step": 1640
1158
  },
1159
  {
1160
- "epoch": 1.7027863777089784,
1161
- "grad_norm": 0.46729475259780884,
1162
- "learning_rate": 8.308655914332599e-06,
1163
- "loss": 0.4022,
1164
- "step": 1650
1165
  },
1166
  {
1167
- "epoch": 1.7131062951496387,
1168
- "grad_norm": 0.49843648076057434,
1169
- "learning_rate": 8.198925269852251e-06,
1170
- "loss": 0.3953,
1171
- "step": 1660
1172
  },
1173
  {
1174
- "epoch": 1.7234262125902993,
1175
- "grad_norm": 0.4577590227127075,
1176
- "learning_rate": 8.089418310356379e-06,
1177
- "loss": 0.398,
1178
- "step": 1670
1179
  },
1180
  {
1181
- "epoch": 1.7337461300309598,
1182
- "grad_norm": 0.45520010590553284,
1183
- "learning_rate": 7.980148636092719e-06,
1184
- "loss": 0.3986,
1185
- "step": 1680
1186
  },
1187
  {
1188
- "epoch": 1.74406604747162,
1189
- "grad_norm": 0.48741379380226135,
1190
- "learning_rate": 7.871129817839304e-06,
1191
- "loss": 0.3926,
1192
- "step": 1690
1193
  },
1194
  {
1195
- "epoch": 1.7543859649122808,
1196
- "grad_norm": 0.47943034768104553,
1197
- "learning_rate": 7.762375395219045e-06,
1198
- "loss": 0.403,
1199
- "step": 1700
1200
  },
1201
  {
1202
- "epoch": 1.7647058823529411,
1203
- "grad_norm": 0.4822390675544739,
1204
- "learning_rate": 7.653898875018151e-06,
1205
- "loss": 0.3967,
1206
- "step": 1710
1207
  },
1208
  {
1209
- "epoch": 1.7750257997936016,
1210
- "grad_norm": 0.47492411732673645,
1211
- "learning_rate": 7.545713729508673e-06,
1212
- "loss": 0.3955,
1213
- "step": 1720
1214
  },
1215
  {
1216
- "epoch": 1.7853457172342622,
1217
- "grad_norm": 0.48685282468795776,
1218
- "learning_rate": 7.437833394775283e-06,
1219
- "loss": 0.3974,
1220
- "step": 1730
1221
  },
1222
  {
1223
- "epoch": 1.7956656346749225,
1224
- "grad_norm": 0.47495120763778687,
1225
- "learning_rate": 7.330271269046614e-06,
1226
- "loss": 0.3997,
1227
- "step": 1740
1228
  },
1229
  {
1230
- "epoch": 1.8059855521155832,
1231
- "grad_norm": 0.4861559271812439,
1232
- "learning_rate": 7.223040711031225e-06,
1233
- "loss": 0.3972,
1234
- "step": 1750
1235
  },
1236
  {
1237
- "epoch": 1.8163054695562435,
1238
- "grad_norm": 0.4717768728733063,
1239
- "learning_rate": 7.116155038258531e-06,
1240
- "loss": 0.3963,
1241
- "step": 1760
1242
  },
1243
  {
1244
- "epoch": 1.826625386996904,
1245
- "grad_norm": 0.47078821063041687,
1246
- "learning_rate": 7.009627525424836e-06,
1247
- "loss": 0.3962,
1248
- "step": 1770
1249
  },
1250
  {
1251
- "epoch": 1.8369453044375645,
1252
- "grad_norm": 0.4606710374355316,
1253
- "learning_rate": 6.903471402744662e-06,
1254
- "loss": 0.3929,
1255
- "step": 1780
1256
  },
1257
  {
1258
- "epoch": 1.8472652218782248,
1259
- "grad_norm": 0.45694735646247864,
1260
- "learning_rate": 6.797699854307631e-06,
1261
- "loss": 0.3897,
1262
- "step": 1790
1263
  },
1264
  {
1265
- "epoch": 1.8575851393188856,
1266
- "grad_norm": 0.4747222661972046,
1267
- "learning_rate": 6.692326016441054e-06,
1268
- "loss": 0.3904,
1269
- "step": 1800
1270
  }
1271
  ],
1272
- "logging_steps": 10,
1273
- "max_steps": 2907,
1274
  "num_input_tokens_seen": 0,
1275
- "num_train_epochs": 3,
1276
  "save_steps": 200,
1277
  "stateful_callbacks": {
1278
  "TrainerControl": {
@@ -1286,8 +2406,8 @@
1286
  "attributes": {}
1287
  }
1288
  },
1289
- "total_flos": 8.247296645602371e+19,
1290
- "train_batch_size": 2,
1291
  "trial_name": null,
1292
  "trial_params": null
1293
  }
 
1
  {
2
  "best_metric": null,
3
  "best_model_checkpoint": null,
4
+ "epoch": 1.9293516810895164,
5
  "eval_steps": 500,
6
+ "global_step": 6800,
7
  "is_hyper_param_search": false,
8
  "is_local_process_zero": true,
9
  "is_world_process_zero": true,
10
  "log_history": [
11
  {
12
+ "epoch": 0.005674563767910342,
13
+ "grad_norm": 1.8945719003677368,
14
+ "learning_rate": 2.830188679245283e-06,
15
+ "loss": 0.9878,
16
+ "step": 20
17
  },
18
  {
19
+ "epoch": 0.011349127535820683,
20
+ "grad_norm": 0.8699278235435486,
21
+ "learning_rate": 5.660377358490566e-06,
22
+ "loss": 0.9338,
23
+ "step": 40
24
  },
25
  {
26
+ "epoch": 0.017023691303731027,
27
+ "grad_norm": 0.9612842798233032,
28
+ "learning_rate": 8.49056603773585e-06,
29
+ "loss": 0.8992,
30
+ "step": 60
31
  },
32
  {
33
+ "epoch": 0.022698255071641367,
34
+ "grad_norm": 1.0209581851959229,
35
+ "learning_rate": 1.1320754716981132e-05,
36
+ "loss": 0.8802,
37
+ "step": 80
38
  },
39
  {
40
+ "epoch": 0.02837281883955171,
41
+ "grad_norm": 1.1397087574005127,
42
+ "learning_rate": 1.4150943396226415e-05,
43
+ "loss": 0.8636,
44
+ "step": 100
45
  },
46
  {
47
+ "epoch": 0.034047382607462054,
48
+ "grad_norm": 1.0688011646270752,
49
+ "learning_rate": 1.69811320754717e-05,
50
+ "loss": 0.8589,
51
+ "step": 120
52
  },
53
  {
54
+ "epoch": 0.039721946375372394,
55
+ "grad_norm": 1.0701323747634888,
56
+ "learning_rate": 1.981132075471698e-05,
57
+ "loss": 0.8445,
58
+ "step": 140
59
  },
60
  {
61
+ "epoch": 0.045396510143282734,
62
+ "grad_norm": 1.0749995708465576,
63
+ "learning_rate": 2.2641509433962265e-05,
64
+ "loss": 0.8438,
65
+ "step": 160
66
  },
67
  {
68
+ "epoch": 0.051071073911193074,
69
+ "grad_norm": 1.2973322868347168,
70
+ "learning_rate": 2.547169811320755e-05,
71
+ "loss": 0.8399,
72
+ "step": 180
73
  },
74
  {
75
+ "epoch": 0.05674563767910342,
76
+ "grad_norm": 0.9941120743751526,
77
+ "learning_rate": 2.830188679245283e-05,
78
+ "loss": 0.8459,
79
+ "step": 200
80
  },
81
  {
82
+ "epoch": 0.06242020144701376,
83
+ "grad_norm": 1.1092499494552612,
84
+ "learning_rate": 2.9999898623711896e-05,
85
+ "loss": 0.8396,
86
+ "step": 220
87
  },
88
  {
89
+ "epoch": 0.06809476521492411,
90
+ "grad_norm": 1.10667085647583,
91
+ "learning_rate": 2.999875815620755e-05,
92
+ "loss": 0.8403,
93
+ "step": 240
94
  },
95
  {
96
+ "epoch": 0.07376932898283445,
97
+ "grad_norm": 1.0986227989196777,
98
+ "learning_rate": 2.999635059750628e-05,
99
+ "loss": 0.8296,
100
+ "step": 260
101
  },
102
  {
103
+ "epoch": 0.07944389275074479,
104
+ "grad_norm": 0.9648028612136841,
105
+ "learning_rate": 2.9992676150998032e-05,
106
+ "loss": 0.8187,
107
+ "step": 280
108
  },
109
  {
110
+ "epoch": 0.08511845651865513,
111
+ "grad_norm": 0.8029258251190186,
112
+ "learning_rate": 2.998773512709909e-05,
113
+ "loss": 0.8224,
114
+ "step": 300
115
  },
116
  {
117
+ "epoch": 0.09079302028656547,
118
+ "grad_norm": 0.888502299785614,
119
+ "learning_rate": 2.9981527943225862e-05,
120
+ "loss": 0.8178,
121
+ "step": 320
122
  },
123
  {
124
+ "epoch": 0.09646758405447581,
125
+ "grad_norm": 0.7894881963729858,
126
+ "learning_rate": 2.997405512375964e-05,
127
+ "loss": 0.8153,
128
+ "step": 340
129
  },
130
  {
131
+ "epoch": 0.10214214782238615,
132
+ "grad_norm": 0.8492247462272644,
133
+ "learning_rate": 2.996531730000227e-05,
134
+ "loss": 0.8105,
135
+ "step": 360
136
  },
137
  {
138
+ "epoch": 0.1078167115902965,
139
+ "grad_norm": 0.8247759938240051,
140
+ "learning_rate": 2.9955315210122842e-05,
141
+ "loss": 0.8,
142
+ "step": 380
143
  },
144
  {
145
+ "epoch": 0.11349127535820684,
146
+ "grad_norm": 0.8270812034606934,
147
+ "learning_rate": 2.99440496990953e-05,
148
+ "loss": 0.802,
149
+ "step": 400
150
  },
151
  {
152
+ "epoch": 0.11916583912611718,
153
+ "grad_norm": 0.8336136937141418,
154
+ "learning_rate": 2.9931521718627107e-05,
155
+ "loss": 0.7932,
156
+ "step": 420
157
  },
158
  {
159
+ "epoch": 0.12484040289402752,
160
+ "grad_norm": 0.7927630543708801,
161
+ "learning_rate": 2.991773232707879e-05,
162
+ "loss": 0.7903,
163
+ "step": 440
164
  },
165
  {
166
+ "epoch": 0.13051496666193788,
167
+ "grad_norm": 0.8075955510139465,
168
+ "learning_rate": 2.9902682689374578e-05,
169
+ "loss": 0.7897,
170
+ "step": 460
171
  },
172
  {
173
+ "epoch": 0.13618953042984822,
174
+ "grad_norm": 0.7381598353385925,
175
+ "learning_rate": 2.9886374076903945e-05,
176
+ "loss": 0.785,
177
+ "step": 480
178
  },
179
  {
180
+ "epoch": 0.14186409419775856,
181
+ "grad_norm": 0.799022912979126,
182
+ "learning_rate": 2.986880786741426e-05,
183
+ "loss": 0.7862,
184
+ "step": 500
185
  },
186
  {
187
+ "epoch": 0.1475386579656689,
188
+ "grad_norm": 0.7515665292739868,
189
+ "learning_rate": 2.9849985544894333e-05,
190
+ "loss": 0.7845,
191
+ "step": 520
192
  },
193
  {
194
+ "epoch": 0.15321322173357924,
195
+ "grad_norm": 0.8161646723747253,
196
+ "learning_rate": 2.982990869944908e-05,
197
+ "loss": 0.7745,
198
+ "step": 540
199
  },
200
  {
201
+ "epoch": 0.15888778550148958,
202
+ "grad_norm": 0.671816885471344,
203
+ "learning_rate": 2.9808579027165204e-05,
204
+ "loss": 0.7786,
205
+ "step": 560
206
  },
207
  {
208
+ "epoch": 0.16456234926939992,
209
+ "grad_norm": 0.7310769557952881,
210
+ "learning_rate": 2.978599832996788e-05,
211
+ "loss": 0.7742,
212
+ "step": 580
213
  },
214
  {
215
+ "epoch": 0.17023691303731026,
216
+ "grad_norm": 0.7568747401237488,
217
+ "learning_rate": 2.9762168515468548e-05,
218
+ "loss": 0.7691,
219
+ "step": 600
220
  },
221
  {
222
+ "epoch": 0.1759114768052206,
223
+ "grad_norm": 0.6345218420028687,
224
+ "learning_rate": 2.973709159680375e-05,
225
+ "loss": 0.7695,
226
+ "step": 620
227
  },
228
  {
229
+ "epoch": 0.18158604057313094,
230
+ "grad_norm": 0.7218050360679626,
231
+ "learning_rate": 2.9710769692465073e-05,
232
+ "loss": 0.7681,
233
+ "step": 640
234
  },
235
  {
236
+ "epoch": 0.18726060434104128,
237
+ "grad_norm": 0.7665095925331116,
238
+ "learning_rate": 2.9683205026120163e-05,
239
+ "loss": 0.7667,
240
+ "step": 660
241
  },
242
  {
243
+ "epoch": 0.19293516810895162,
244
+ "grad_norm": 0.6717973947525024,
245
+ "learning_rate": 2.9654399926424884e-05,
246
+ "loss": 0.7684,
247
+ "step": 680
248
  },
249
  {
250
+ "epoch": 0.19860973187686196,
251
+ "grad_norm": 0.7454754114151001,
252
+ "learning_rate": 2.9624356826826577e-05,
253
+ "loss": 0.7622,
254
+ "step": 700
255
  },
256
  {
257
+ "epoch": 0.2042842956447723,
258
+ "grad_norm": 0.6865426898002625,
259
+ "learning_rate": 2.9593078265358498e-05,
260
+ "loss": 0.761,
261
+ "step": 720
262
  },
263
  {
264
+ "epoch": 0.20995885941268266,
265
+ "grad_norm": 0.7075285315513611,
266
+ "learning_rate": 2.956056688442541e-05,
267
+ "loss": 0.7578,
268
+ "step": 740
269
  },
270
  {
271
+ "epoch": 0.215633423180593,
272
+ "grad_norm": 0.7438149452209473,
273
+ "learning_rate": 2.9526825430580337e-05,
274
+ "loss": 0.7571,
275
+ "step": 760
276
  },
277
  {
278
+ "epoch": 0.22130798694850334,
279
+ "grad_norm": 0.6830400228500366,
280
+ "learning_rate": 2.949185675429254e-05,
281
+ "loss": 0.759,
282
+ "step": 780
283
  },
284
  {
285
+ "epoch": 0.22698255071641368,
286
+ "grad_norm": 0.7147162556648254,
287
+ "learning_rate": 2.9455663809706725e-05,
288
+ "loss": 0.756,
289
+ "step": 800
290
  },
291
  {
292
+ "epoch": 0.23265711448432402,
293
+ "grad_norm": 0.7116013765335083,
294
+ "learning_rate": 2.9418249654393443e-05,
295
+ "loss": 0.7538,
296
+ "step": 820
297
  },
298
  {
299
+ "epoch": 0.23833167825223436,
300
+ "grad_norm": 0.64736407995224,
301
+ "learning_rate": 2.9379617449090847e-05,
302
+ "loss": 0.7513,
303
+ "step": 840
304
  },
305
  {
306
+ "epoch": 0.2440062420201447,
307
+ "grad_norm": 0.6453843116760254,
308
+ "learning_rate": 2.93397704574376e-05,
309
+ "loss": 0.7538,
310
+ "step": 860
311
  },
312
  {
313
+ "epoch": 0.24968080578805504,
314
+ "grad_norm": 0.6253499388694763,
315
+ "learning_rate": 2.929871204569722e-05,
316
+ "loss": 0.7463,
317
+ "step": 880
318
  },
319
  {
320
+ "epoch": 0.2553553695559654,
321
+ "grad_norm": 0.6677010655403137,
322
+ "learning_rate": 2.9256445682473683e-05,
323
+ "loss": 0.7419,
324
+ "step": 900
325
  },
326
  {
327
+ "epoch": 0.26102993332387575,
328
+ "grad_norm": 0.7070403695106506,
329
+ "learning_rate": 2.9212974938418385e-05,
330
+ "loss": 0.7449,
331
+ "step": 920
332
  },
333
  {
334
+ "epoch": 0.26670449709178606,
335
+ "grad_norm": 0.6784743070602417,
336
+ "learning_rate": 2.9168303485928495e-05,
337
+ "loss": 0.7453,
338
+ "step": 940
339
  },
340
  {
341
+ "epoch": 0.27237906085969643,
342
+ "grad_norm": 0.6076740026473999,
343
+ "learning_rate": 2.912243509883673e-05,
344
+ "loss": 0.7457,
345
+ "step": 960
346
  },
347
  {
348
+ "epoch": 0.27805362462760674,
349
+ "grad_norm": 0.6722409129142761,
350
+ "learning_rate": 2.9075373652092535e-05,
351
+ "loss": 0.7373,
352
+ "step": 980
353
  },
354
  {
355
+ "epoch": 0.2837281883955171,
356
+ "grad_norm": 0.7188818454742432,
357
+ "learning_rate": 2.9027123121434714e-05,
358
+ "loss": 0.7343,
359
+ "step": 1000
360
  },
361
  {
362
+ "epoch": 0.2894027521634274,
363
+ "grad_norm": 0.657289981842041,
364
+ "learning_rate": 2.897768758305558e-05,
365
+ "loss": 0.7336,
366
+ "step": 1020
367
  },
368
  {
369
+ "epoch": 0.2950773159313378,
370
+ "grad_norm": 0.6076385378837585,
371
+ "learning_rate": 2.892707121325658e-05,
372
+ "loss": 0.7331,
373
+ "step": 1040
374
  },
375
  {
376
+ "epoch": 0.3007518796992481,
377
+ "grad_norm": 0.6217896342277527,
378
+ "learning_rate": 2.8875278288095507e-05,
379
+ "loss": 0.7339,
380
+ "step": 1060
381
  },
382
  {
383
+ "epoch": 0.30642644346715847,
384
+ "grad_norm": 0.6453694701194763,
385
+ "learning_rate": 2.882231318302523e-05,
386
+ "loss": 0.7334,
387
+ "step": 1080
388
  },
389
  {
390
+ "epoch": 0.3121010072350688,
391
+ "grad_norm": 0.6069263219833374,
392
+ "learning_rate": 2.8768180372524093e-05,
393
+ "loss": 0.734,
394
+ "step": 1100
395
  },
396
  {
397
+ "epoch": 0.31777557100297915,
398
+ "grad_norm": 0.6342785358428955,
399
+ "learning_rate": 2.8712884429717873e-05,
400
+ "loss": 0.7254,
401
+ "step": 1120
402
  },
403
  {
404
+ "epoch": 0.32345013477088946,
405
+ "grad_norm": 0.5936433672904968,
406
+ "learning_rate": 2.8656430025993464e-05,
407
+ "loss": 0.7232,
408
+ "step": 1140
409
  },
410
  {
411
+ "epoch": 0.32912469853879983,
412
+ "grad_norm": 0.5988269448280334,
413
+ "learning_rate": 2.8598821930604252e-05,
414
+ "loss": 0.726,
415
+ "step": 1160
416
  },
417
  {
418
+ "epoch": 0.3347992623067102,
419
+ "grad_norm": 0.6247944235801697,
420
+ "learning_rate": 2.8540065010267183e-05,
421
+ "loss": 0.729,
422
+ "step": 1180
423
  },
424
  {
425
+ "epoch": 0.3404738260746205,
426
+ "grad_norm": 0.6017037034034729,
427
+ "learning_rate": 2.848016422875164e-05,
428
+ "loss": 0.7216,
429
+ "step": 1200
430
  },
431
  {
432
+ "epoch": 0.3461483898425309,
433
+ "grad_norm": 0.7368952631950378,
434
+ "learning_rate": 2.84191246464601e-05,
435
+ "loss": 0.7331,
436
+ "step": 1220
437
  },
438
  {
439
+ "epoch": 0.3518229536104412,
440
+ "grad_norm": 0.6655734777450562,
441
+ "learning_rate": 2.835695142000064e-05,
442
+ "loss": 0.7233,
443
+ "step": 1240
444
  },
445
  {
446
+ "epoch": 0.35749751737835156,
447
+ "grad_norm": 0.6325275301933289,
448
+ "learning_rate": 2.8293649801751288e-05,
449
+ "loss": 0.7208,
450
+ "step": 1260
451
  },
452
  {
453
+ "epoch": 0.36317208114626187,
454
+ "grad_norm": 0.6046157479286194,
455
+ "learning_rate": 2.822922513941634e-05,
456
+ "loss": 0.7156,
457
+ "step": 1280
458
  },
459
  {
460
+ "epoch": 0.36884664491417224,
461
+ "grad_norm": 0.6081031560897827,
462
+ "learning_rate": 2.816368287557454e-05,
463
+ "loss": 0.722,
464
+ "step": 1300
465
  },
466
  {
467
+ "epoch": 0.37452120868208255,
468
+ "grad_norm": 0.6153631806373596,
469
+ "learning_rate": 2.809702854721934e-05,
470
+ "loss": 0.7171,
471
+ "step": 1320
472
  },
473
  {
474
+ "epoch": 0.3801957724499929,
475
+ "grad_norm": 0.6361656188964844,
476
+ "learning_rate": 2.8029267785291092e-05,
477
+ "loss": 0.7134,
478
+ "step": 1340
479
  },
480
  {
481
+ "epoch": 0.38587033621790323,
482
+ "grad_norm": 0.6033869981765747,
483
+ "learning_rate": 2.796040631420139e-05,
484
+ "loss": 0.7171,
485
+ "step": 1360
486
  },
487
  {
488
+ "epoch": 0.3915448999858136,
489
+ "grad_norm": 0.6300106644630432,
490
+ "learning_rate": 2.789044995134944e-05,
491
+ "loss": 0.7139,
492
+ "step": 1380
493
  },
494
  {
495
+ "epoch": 0.3972194637537239,
496
+ "grad_norm": 0.5989068150520325,
497
+ "learning_rate": 2.781940460663062e-05,
498
+ "loss": 0.7142,
499
+ "step": 1400
500
  },
501
  {
502
+ "epoch": 0.4028940275216343,
503
+ "grad_norm": 0.5790150761604309,
504
+ "learning_rate": 2.774727628193721e-05,
505
+ "loss": 0.7126,
506
+ "step": 1420
507
  },
508
  {
509
+ "epoch": 0.4085685912895446,
510
+ "grad_norm": 0.5948804616928101,
511
+ "learning_rate": 2.7674071070651378e-05,
512
+ "loss": 0.7103,
513
+ "step": 1440
514
  },
515
  {
516
+ "epoch": 0.41424315505745496,
517
+ "grad_norm": 0.6838712096214294,
518
+ "learning_rate": 2.7599795157130364e-05,
519
+ "loss": 0.7169,
520
+ "step": 1460
521
  },
522
  {
523
+ "epoch": 0.4199177188253653,
524
+ "grad_norm": 0.6502018570899963,
525
+ "learning_rate": 2.7524454816184076e-05,
526
+ "loss": 0.7094,
527
+ "step": 1480
528
  },
529
  {
530
+ "epoch": 0.42559228259327564,
531
+ "grad_norm": 0.6322967410087585,
532
+ "learning_rate": 2.7448056412544956e-05,
533
+ "loss": 0.7134,
534
+ "step": 1500
535
  },
536
  {
537
+ "epoch": 0.431266846361186,
538
+ "grad_norm": 0.5761287212371826,
539
+ "learning_rate": 2.7370606400330334e-05,
540
+ "loss": 0.7067,
541
+ "step": 1520
542
  },
543
  {
544
+ "epoch": 0.4369414101290963,
545
+ "grad_norm": 0.6147580742835999,
546
+ "learning_rate": 2.729211132249713e-05,
547
+ "loss": 0.7078,
548
+ "step": 1540
549
  },
550
  {
551
+ "epoch": 0.4426159738970067,
552
+ "grad_norm": 0.6231666207313538,
553
+ "learning_rate": 2.7212577810289157e-05,
554
+ "loss": 0.7066,
555
+ "step": 1560
556
  },
557
  {
558
+ "epoch": 0.448290537664917,
559
+ "grad_norm": 0.5739862322807312,
560
+ "learning_rate": 2.713201258267689e-05,
561
+ "loss": 0.708,
562
+ "step": 1580
563
  },
564
  {
565
+ "epoch": 0.45396510143282737,
566
+ "grad_norm": 0.7059602737426758,
567
+ "learning_rate": 2.7050422445789843e-05,
568
+ "loss": 0.7043,
569
+ "step": 1600
570
  },
571
  {
572
+ "epoch": 0.4596396652007377,
573
+ "grad_norm": 0.6156895160675049,
574
+ "learning_rate": 2.696781429234162e-05,
575
+ "loss": 0.7118,
576
+ "step": 1620
577
  },
578
  {
579
+ "epoch": 0.46531422896864805,
580
+ "grad_norm": 0.5444714426994324,
581
+ "learning_rate": 2.6884195101047567e-05,
582
+ "loss": 0.7031,
583
+ "step": 1640
584
  },
585
  {
586
+ "epoch": 0.47098879273655836,
587
+ "grad_norm": 0.6431369185447693,
588
+ "learning_rate": 2.6799571936035284e-05,
589
+ "loss": 0.7056,
590
+ "step": 1660
591
  },
592
  {
593
+ "epoch": 0.4766633565044687,
594
+ "grad_norm": 0.6375367641448975,
595
+ "learning_rate": 2.671395194624779e-05,
596
+ "loss": 0.6991,
597
+ "step": 1680
598
  },
599
  {
600
+ "epoch": 0.48233792027237904,
601
+ "grad_norm": 0.6311667561531067,
602
+ "learning_rate": 2.6627342364839604e-05,
603
+ "loss": 0.6991,
604
+ "step": 1700
605
  },
606
  {
607
+ "epoch": 0.4880124840402894,
608
+ "grad_norm": 0.580328643321991,
609
+ "learning_rate": 2.6539750508565683e-05,
610
+ "loss": 0.7027,
611
+ "step": 1720
612
  },
613
  {
614
+ "epoch": 0.4936870478081997,
615
+ "grad_norm": 0.6254743933677673,
616
+ "learning_rate": 2.6451183777163316e-05,
617
+ "loss": 0.6977,
618
+ "step": 1740
619
  },
620
  {
621
+ "epoch": 0.4993616115761101,
622
+ "grad_norm": 0.8747753500938416,
623
+ "learning_rate": 2.636164965272699e-05,
624
+ "loss": 0.6974,
625
+ "step": 1760
626
  },
627
  {
628
+ "epoch": 0.5050361753440205,
629
+ "grad_norm": 0.5931680798530579,
630
+ "learning_rate": 2.6271155699076305e-05,
631
+ "loss": 0.7001,
632
+ "step": 1780
633
  },
634
  {
635
+ "epoch": 0.5107107391119308,
636
+ "grad_norm": 0.5763223767280579,
637
+ "learning_rate": 2.6179709561116983e-05,
638
+ "loss": 0.7023,
639
+ "step": 1800
640
  },
641
  {
642
+ "epoch": 0.5163853028798411,
643
+ "grad_norm": 0.5211492776870728,
644
+ "learning_rate": 2.6087318964195032e-05,
645
+ "loss": 0.6957,
646
+ "step": 1820
647
  },
648
  {
649
+ "epoch": 0.5220598666477515,
650
+ "grad_norm": 0.5684000253677368,
651
+ "learning_rate": 2.59939917134441e-05,
652
+ "loss": 0.6916,
653
+ "step": 1840
654
  },
655
  {
656
+ "epoch": 0.5277344304156618,
657
+ "grad_norm": 0.6029589176177979,
658
+ "learning_rate": 2.5899735693126113e-05,
659
+ "loss": 0.6942,
660
+ "step": 1860
661
  },
662
  {
663
+ "epoch": 0.5334089941835721,
664
+ "grad_norm": 0.5765926837921143,
665
+ "learning_rate": 2.5804558865965206e-05,
666
+ "loss": 0.6973,
667
+ "step": 1880
668
  },
669
  {
670
+ "epoch": 0.5390835579514824,
671
+ "grad_norm": 0.5227144956588745,
672
+ "learning_rate": 2.5708469272475044e-05,
673
+ "loss": 0.6929,
674
+ "step": 1900
675
  },
676
  {
677
+ "epoch": 0.5447581217193929,
678
+ "grad_norm": 0.6175386309623718,
679
+ "learning_rate": 2.5611475030279546e-05,
680
+ "loss": 0.6908,
681
+ "step": 1920
682
  },
683
  {
684
+ "epoch": 0.5504326854873032,
685
+ "grad_norm": 0.5724866986274719,
686
+ "learning_rate": 2.5513584333427125e-05,
687
+ "loss": 0.6893,
688
+ "step": 1940
689
  },
690
  {
691
+ "epoch": 0.5561072492552135,
692
+ "grad_norm": 0.5964395403862,
693
+ "learning_rate": 2.541480545169846e-05,
694
+ "loss": 0.6944,
695
+ "step": 1960
696
  },
697
  {
698
+ "epoch": 0.5617818130231238,
699
+ "grad_norm": 0.6019209027290344,
700
+ "learning_rate": 2.5315146729907827e-05,
701
+ "loss": 0.6899,
702
+ "step": 1980
703
  },
704
  {
705
+ "epoch": 0.5674563767910342,
706
+ "grad_norm": 0.6371375918388367,
707
+ "learning_rate": 2.521461658719819e-05,
708
+ "loss": 0.6904,
709
+ "step": 2000
710
  },
711
  {
712
+ "epoch": 0.5731309405589445,
713
+ "grad_norm": 0.5762882232666016,
714
+ "learning_rate": 2.5113223516329924e-05,
715
+ "loss": 0.6887,
716
+ "step": 2020
717
  },
718
  {
719
+ "epoch": 0.5788055043268548,
720
+ "grad_norm": 0.591663122177124,
721
+ "learning_rate": 2.501097608296334e-05,
722
+ "loss": 0.6894,
723
+ "step": 2040
724
  },
725
  {
726
+ "epoch": 0.5844800680947652,
727
+ "grad_norm": 0.5833630561828613,
728
+ "learning_rate": 2.4907882924935072e-05,
729
+ "loss": 0.6866,
730
+ "step": 2060
731
  },
732
  {
733
+ "epoch": 0.5901546318626756,
734
+ "grad_norm": 0.5615355968475342,
735
+ "learning_rate": 2.4803952751528363e-05,
736
+ "loss": 0.6927,
737
+ "step": 2080
738
  },
739
  {
740
+ "epoch": 0.5958291956305859,
741
+ "grad_norm": 0.5507014989852905,
742
+ "learning_rate": 2.4699194342737295e-05,
743
+ "loss": 0.6934,
744
+ "step": 2100
745
  },
746
  {
747
+ "epoch": 0.6015037593984962,
748
+ "grad_norm": 0.5132161974906921,
749
+ "learning_rate": 2.459361654852505e-05,
750
+ "loss": 0.688,
751
+ "step": 2120
752
  },
753
  {
754
+ "epoch": 0.6071783231664066,
755
+ "grad_norm": 0.5238850116729736,
756
+ "learning_rate": 2.4487228288076293e-05,
757
+ "loss": 0.6804,
758
+ "step": 2140
759
  },
760
  {
761
+ "epoch": 0.6128528869343169,
762
+ "grad_norm": 0.5849164724349976,
763
+ "learning_rate": 2.438003854904366e-05,
764
+ "loss": 0.6911,
765
+ "step": 2160
766
  },
767
  {
768
+ "epoch": 0.6185274507022273,
769
+ "grad_norm": 0.5290674567222595,
770
+ "learning_rate": 2.4272056386788485e-05,
771
+ "loss": 0.6838,
772
+ "step": 2180
773
  },
774
  {
775
+ "epoch": 0.6242020144701376,
776
+ "grad_norm": 0.5804121494293213,
777
+ "learning_rate": 2.4163290923615814e-05,
778
+ "loss": 0.6894,
779
+ "step": 2200
780
  },
781
  {
782
+ "epoch": 0.629876578238048,
783
+ "grad_norm": 0.5559779405593872,
784
+ "learning_rate": 2.4053751348003757e-05,
785
+ "loss": 0.6859,
786
+ "step": 2220
787
  },
788
  {
789
+ "epoch": 0.6355511420059583,
790
+ "grad_norm": 0.5486791133880615,
791
+ "learning_rate": 2.394344691382723e-05,
792
+ "loss": 0.6836,
793
+ "step": 2240
794
  },
795
  {
796
+ "epoch": 0.6412257057738686,
797
+ "grad_norm": 0.5544127225875854,
798
+ "learning_rate": 2.3832386939576214e-05,
799
+ "loss": 0.681,
800
+ "step": 2260
801
  },
802
  {
803
+ "epoch": 0.6469002695417789,
804
+ "grad_norm": 0.5256103277206421,
805
+ "learning_rate": 2.3720580807568513e-05,
806
+ "loss": 0.6823,
807
+ "step": 2280
808
  },
809
  {
810
+ "epoch": 0.6525748333096894,
811
+ "grad_norm": 0.5488288402557373,
812
+ "learning_rate": 2.3608037963157142e-05,
813
+ "loss": 0.6818,
814
+ "step": 2300
815
  },
816
  {
817
+ "epoch": 0.6582493970775997,
818
+ "grad_norm": 0.5254908204078674,
819
+ "learning_rate": 2.3494767913932393e-05,
820
+ "loss": 0.6774,
821
+ "step": 2320
822
  },
823
  {
824
+ "epoch": 0.66392396084551,
825
+ "grad_norm": 0.5880591869354248,
826
+ "learning_rate": 2.338078022891864e-05,
827
+ "loss": 0.6795,
828
+ "step": 2340
829
  },
830
  {
831
+ "epoch": 0.6695985246134204,
832
+ "grad_norm": 0.5331950783729553,
833
+ "learning_rate": 2.3266084537765924e-05,
834
+ "loss": 0.6777,
835
+ "step": 2360
836
  },
837
  {
838
+ "epoch": 0.6752730883813307,
839
+ "grad_norm": 0.5736955404281616,
840
+ "learning_rate": 2.3150690529936475e-05,
841
+ "loss": 0.6792,
842
+ "step": 2380
843
  },
844
  {
845
+ "epoch": 0.680947652149241,
846
+ "grad_norm": 0.5705032348632812,
847
+ "learning_rate": 2.303460795388613e-05,
848
+ "loss": 0.6736,
849
+ "step": 2400
850
  },
851
  {
852
+ "epoch": 0.6866222159171513,
853
+ "grad_norm": 0.569355845451355,
854
+ "learning_rate": 2.2917846616240784e-05,
855
+ "loss": 0.6767,
856
+ "step": 2420
857
  },
858
  {
859
+ "epoch": 0.6922967796850618,
860
+ "grad_norm": 1.2819143533706665,
861
+ "learning_rate": 2.2800416380967952e-05,
862
+ "loss": 0.6772,
863
+ "step": 2440
864
  },
865
  {
866
+ "epoch": 0.6979713434529721,
867
+ "grad_norm": 0.5238373279571533,
868
+ "learning_rate": 2.268232716854343e-05,
869
+ "loss": 0.674,
870
+ "step": 2460
871
  },
872
  {
873
+ "epoch": 0.7036459072208824,
874
+ "grad_norm": 0.5886688828468323,
875
+ "learning_rate": 2.2563588955113246e-05,
876
+ "loss": 0.6757,
877
+ "step": 2480
878
  },
879
  {
880
+ "epoch": 0.7093204709887927,
881
+ "grad_norm": 0.5450348854064941,
882
+ "learning_rate": 2.244421177165085e-05,
883
+ "loss": 0.6691,
884
+ "step": 2500
885
  },
886
  {
887
+ "epoch": 0.7149950347567031,
888
+ "grad_norm": 0.5553733706474304,
889
+ "learning_rate": 2.232420570310974e-05,
890
+ "loss": 0.6751,
891
+ "step": 2520
892
  },
893
  {
894
+ "epoch": 0.7206695985246134,
895
+ "grad_norm": 0.5076789259910583,
896
+ "learning_rate": 2.2203580887571423e-05,
897
+ "loss": 0.6739,
898
+ "step": 2540
899
  },
900
  {
901
+ "epoch": 0.7263441622925237,
902
+ "grad_norm": 0.5153952240943909,
903
+ "learning_rate": 2.2082347515389027e-05,
904
+ "loss": 0.6734,
905
+ "step": 2560
906
  },
907
  {
908
+ "epoch": 0.732018726060434,
909
+ "grad_norm": 0.5176730155944824,
910
+ "learning_rate": 2.1960515828326372e-05,
911
+ "loss": 0.6706,
912
+ "step": 2580
913
  },
914
  {
915
+ "epoch": 0.7376932898283445,
916
+ "grad_norm": 0.526030421257019,
917
+ "learning_rate": 2.1838096118692768e-05,
918
+ "loss": 0.6694,
919
+ "step": 2600
920
  },
921
  {
922
+ "epoch": 0.7433678535962548,
923
+ "grad_norm": 0.6030652523040771,
924
+ "learning_rate": 2.1715098728473518e-05,
925
+ "loss": 0.6707,
926
+ "step": 2620
927
  },
928
  {
929
+ "epoch": 0.7490424173641651,
930
+ "grad_norm": 0.6607082486152649,
931
+ "learning_rate": 2.1591534048456225e-05,
932
+ "loss": 0.6668,
933
+ "step": 2640
934
  },
935
  {
936
+ "epoch": 0.7547169811320755,
937
+ "grad_norm": 0.5300272107124329,
938
+ "learning_rate": 2.1467412517352996e-05,
939
+ "loss": 0.6696,
940
+ "step": 2660
941
  },
942
  {
943
+ "epoch": 0.7603915448999858,
944
+ "grad_norm": 0.5344169735908508,
945
+ "learning_rate": 2.1342744620918568e-05,
946
+ "loss": 0.6736,
947
+ "step": 2680
948
  },
949
  {
950
+ "epoch": 0.7660661086678962,
951
+ "grad_norm": 0.5058417916297913,
952
+ "learning_rate": 2.121754089106448e-05,
953
+ "loss": 0.6681,
954
+ "step": 2700
955
  },
956
  {
957
+ "epoch": 0.7717406724358065,
958
+ "grad_norm": 0.5440433621406555,
959
+ "learning_rate": 2.1091811904969344e-05,
960
+ "loss": 0.6702,
961
+ "step": 2720
962
  },
963
  {
964
+ "epoch": 0.7774152362037169,
965
+ "grad_norm": 0.5361486077308655,
966
+ "learning_rate": 2.096556828418528e-05,
967
+ "loss": 0.6686,
968
+ "step": 2740
969
  },
970
  {
971
+ "epoch": 0.7830897999716272,
972
+ "grad_norm": 0.6350403428077698,
973
+ "learning_rate": 2.0838820693740603e-05,
974
+ "loss": 0.6678,
975
+ "step": 2760
976
  },
977
  {
978
+ "epoch": 0.7887643637395375,
979
+ "grad_norm": 0.5326098203659058,
980
+ "learning_rate": 2.0711579841238875e-05,
981
+ "loss": 0.6711,
982
+ "step": 2780
983
  },
984
  {
985
+ "epoch": 0.7944389275074478,
986
+ "grad_norm": 0.540676474571228,
987
+ "learning_rate": 2.058385647595429e-05,
988
+ "loss": 0.6705,
989
+ "step": 2800
990
  },
991
  {
992
+ "epoch": 0.8001134912753582,
993
+ "grad_norm": 0.4930702745914459,
994
+ "learning_rate": 2.045566138792361e-05,
995
+ "loss": 0.6683,
996
+ "step": 2820
997
  },
998
  {
999
+ "epoch": 0.8057880550432686,
1000
+ "grad_norm": 0.5729920268058777,
1001
+ "learning_rate": 2.032700540703459e-05,
1002
+ "loss": 0.6646,
1003
+ "step": 2840
1004
  },
1005
  {
1006
+ "epoch": 0.8114626188111789,
1007
+ "grad_norm": 0.5179927945137024,
1008
+ "learning_rate": 2.0197899402111127e-05,
1009
+ "loss": 0.6632,
1010
+ "step": 2860
1011
  },
1012
  {
1013
+ "epoch": 0.8171371825790892,
1014
+ "grad_norm": 0.5147942900657654,
1015
+ "learning_rate": 2.0068354279995008e-05,
1016
+ "loss": 0.6558,
1017
+ "step": 2880
1018
  },
1019
  {
1020
+ "epoch": 0.8228117463469996,
1021
+ "grad_norm": 0.5044906735420227,
1022
+ "learning_rate": 1.9938380984624533e-05,
1023
+ "loss": 0.6634,
1024
+ "step": 2900
1025
  },
1026
  {
1027
+ "epoch": 0.8284863101149099,
1028
+ "grad_norm": 0.5231923460960388,
1029
+ "learning_rate": 1.9807990496109965e-05,
1030
+ "loss": 0.6698,
1031
+ "step": 2920
1032
  },
1033
  {
1034
+ "epoch": 0.8341608738828202,
1035
+ "grad_norm": 0.5322957634925842,
1036
+ "learning_rate": 1.967719382980594e-05,
1037
+ "loss": 0.6568,
1038
+ "step": 2940
1039
  },
1040
  {
1041
+ "epoch": 0.8398354376507307,
1042
+ "grad_norm": 0.512269139289856,
1043
+ "learning_rate": 1.9546002035380886e-05,
1044
+ "loss": 0.6654,
1045
+ "step": 2960
1046
  },
1047
  {
1048
+ "epoch": 0.845510001418641,
1049
+ "grad_norm": 0.508976399898529,
1050
+ "learning_rate": 1.9414426195883558e-05,
1051
+ "loss": 0.6552,
1052
+ "step": 2980
1053
  },
1054
  {
1055
+ "epoch": 0.8511845651865513,
1056
+ "grad_norm": 0.5061299204826355,
1057
+ "learning_rate": 1.9282477426806723e-05,
1058
+ "loss": 0.6599,
1059
+ "step": 3000
1060
  },
1061
  {
1062
+ "epoch": 0.8568591289544616,
1063
+ "grad_norm": 0.510822057723999,
1064
+ "learning_rate": 1.9150166875148155e-05,
1065
+ "loss": 0.6612,
1066
+ "step": 3020
1067
  },
1068
  {
1069
+ "epoch": 0.862533692722372,
1070
+ "grad_norm": 0.5578708648681641,
1071
+ "learning_rate": 1.9017505718468934e-05,
1072
+ "loss": 0.658,
1073
+ "step": 3040
1074
  },
1075
  {
1076
+ "epoch": 0.8682082564902823,
1077
+ "grad_norm": 0.5130868554115295,
1078
+ "learning_rate": 1.888450516394914e-05,
1079
+ "loss": 0.6541,
1080
+ "step": 3060
1081
  },
1082
  {
1083
+ "epoch": 0.8738828202581926,
1084
+ "grad_norm": 0.5147811770439148,
1085
+ "learning_rate": 1.8751176447441104e-05,
1086
+ "loss": 0.6586,
1087
+ "step": 3080
1088
  },
1089
  {
1090
+ "epoch": 0.879557384026103,
1091
+ "grad_norm": 0.5556140542030334,
1092
+ "learning_rate": 1.861753083252021e-05,
1093
+ "loss": 0.6535,
1094
+ "step": 3100
1095
  },
1096
  {
1097
+ "epoch": 0.8852319477940134,
1098
+ "grad_norm": 0.509611964225769,
1099
+ "learning_rate": 1.8483579609533318e-05,
1100
+ "loss": 0.6537,
1101
+ "step": 3120
1102
  },
1103
  {
1104
+ "epoch": 0.8909065115619237,
1105
+ "grad_norm": 0.5088684558868408,
1106
+ "learning_rate": 1.834933409464499e-05,
1107
+ "loss": 0.6562,
1108
+ "step": 3140
1109
  },
1110
  {
1111
+ "epoch": 0.896581075329834,
1112
+ "grad_norm": 0.48405396938323975,
1113
+ "learning_rate": 1.821480562888148e-05,
1114
+ "loss": 0.6583,
1115
+ "step": 3160
1116
  },
1117
  {
1118
+ "epoch": 0.9022556390977443,
1119
+ "grad_norm": 0.5087782144546509,
1120
+ "learning_rate": 1.808000557717268e-05,
1121
+ "loss": 0.6558,
1122
+ "step": 3180
1123
  },
1124
  {
1125
+ "epoch": 0.9079302028656547,
1126
+ "grad_norm": 0.5303909778594971,
1127
+ "learning_rate": 1.7944945327391957e-05,
1128
+ "loss": 0.6517,
1129
+ "step": 3200
1130
  },
1131
  {
1132
+ "epoch": 0.913604766633565,
1133
+ "grad_norm": 0.5164442658424377,
1134
+ "learning_rate": 1.7809636289394185e-05,
1135
+ "loss": 0.6529,
1136
+ "step": 3220
1137
  },
1138
  {
1139
+ "epoch": 0.9192793304014754,
1140
+ "grad_norm": 0.5162308216094971,
1141
+ "learning_rate": 1.7674089894051774e-05,
1142
+ "loss": 0.6542,
1143
+ "step": 3240
1144
  },
1145
  {
1146
+ "epoch": 0.9249538941693858,
1147
+ "grad_norm": 0.545396625995636,
1148
+ "learning_rate": 1.753831759228903e-05,
1149
+ "loss": 0.6527,
1150
+ "step": 3260
1151
  },
1152
  {
1153
+ "epoch": 0.9306284579372961,
1154
+ "grad_norm": 0.5134595632553101,
1155
+ "learning_rate": 1.740233085411477e-05,
1156
+ "loss": 0.6555,
1157
+ "step": 3280
1158
  },
1159
  {
1160
+ "epoch": 0.9363030217052064,
1161
+ "grad_norm": 0.48815637826919556,
1162
+ "learning_rate": 1.7266141167653353e-05,
1163
+ "loss": 0.6554,
1164
+ "step": 3300
1165
  },
1166
  {
1167
+ "epoch": 0.9419775854731167,
1168
+ "grad_norm": 0.5034410953521729,
1169
+ "learning_rate": 1.7129760038174146e-05,
1170
+ "loss": 0.6514,
1171
+ "step": 3320
1172
  },
1173
  {
1174
+ "epoch": 0.9476521492410271,
1175
+ "grad_norm": 0.5322323441505432,
1176
+ "learning_rate": 1.6993198987119576e-05,
1177
+ "loss": 0.6533,
1178
+ "step": 3340
1179
  },
1180
  {
1181
+ "epoch": 0.9533267130089375,
1182
+ "grad_norm": 0.48363253474235535,
1183
+ "learning_rate": 1.6856469551131805e-05,
1184
+ "loss": 0.6468,
1185
+ "step": 3360
1186
  },
1187
  {
1188
+ "epoch": 0.9590012767768478,
1189
+ "grad_norm": 0.4600164592266083,
1190
+ "learning_rate": 1.67195832810781e-05,
1191
+ "loss": 0.6472,
1192
+ "step": 3380
1193
  },
1194
  {
1195
+ "epoch": 0.9646758405447581,
1196
+ "grad_norm": 0.49600768089294434,
1197
+ "learning_rate": 1.6582551741075033e-05,
1198
+ "loss": 0.6467,
1199
+ "step": 3400
1200
  },
1201
  {
1202
+ "epoch": 0.9703504043126685,
1203
+ "grad_norm": 0.7202423810958862,
1204
+ "learning_rate": 1.6445386507511546e-05,
1205
+ "loss": 0.6502,
1206
+ "step": 3420
1207
  },
1208
  {
1209
+ "epoch": 0.9760249680805788,
1210
+ "grad_norm": 0.502703070640564,
1211
+ "learning_rate": 1.630809916807098e-05,
1212
+ "loss": 0.6424,
1213
+ "step": 3440
1214
  },
1215
  {
1216
+ "epoch": 0.9816995318484891,
1217
+ "grad_norm": 0.49266818165779114,
1218
+ "learning_rate": 1.617070132075214e-05,
1219
+ "loss": 0.6485,
1220
+ "step": 3460
1221
  },
1222
  {
1223
+ "epoch": 0.9873740956163994,
1224
+ "grad_norm": 0.5194821357727051,
1225
+ "learning_rate": 1.6033204572889516e-05,
1226
+ "loss": 0.6499,
1227
+ "step": 3480
1228
  },
1229
  {
1230
+ "epoch": 0.9930486593843099,
1231
+ "grad_norm": 0.49109163880348206,
1232
+ "learning_rate": 1.5895620540172682e-05,
1233
+ "loss": 0.6506,
1234
+ "step": 3500
1235
  },
1236
  {
1237
+ "epoch": 0.9987232231522202,
1238
+ "grad_norm": 0.5099320411682129,
1239
+ "learning_rate": 1.575796084566503e-05,
1240
+ "loss": 0.6466,
1241
+ "step": 3520
1242
  },
1243
  {
1244
+ "epoch": 1.0043977869201306,
1245
+ "grad_norm": 0.5476223230361938,
1246
+ "learning_rate": 1.562023711882182e-05,
1247
+ "loss": 0.5924,
1248
+ "step": 3540
1249
  },
1250
  {
1251
+ "epoch": 1.010072350688041,
1252
+ "grad_norm": 0.4934983551502228,
1253
+ "learning_rate": 1.548246099450776e-05,
1254
+ "loss": 0.5683,
1255
+ "step": 3560
1256
  },
1257
  {
1258
+ "epoch": 1.0157469144559512,
1259
+ "grad_norm": 0.5262681841850281,
1260
+ "learning_rate": 1.534464411201409e-05,
1261
+ "loss": 0.5733,
1262
+ "step": 3580
1263
  },
1264
  {
1265
+ "epoch": 1.0214214782238615,
1266
+ "grad_norm": 0.5271425843238831,
1267
+ "learning_rate": 1.520679811407526e-05,
1268
+ "loss": 0.5697,
1269
+ "step": 3600
1270
+ },
1271
+ {
1272
+ "epoch": 1.0270960419917718,
1273
+ "grad_norm": 0.5124356150627136,
1274
+ "learning_rate": 1.506893464588542e-05,
1275
+ "loss": 0.5653,
1276
+ "step": 3620
1277
+ },
1278
+ {
1279
+ "epoch": 1.0327706057596822,
1280
+ "grad_norm": 0.5131009817123413,
1281
+ "learning_rate": 1.4931065354114584e-05,
1282
+ "loss": 0.5669,
1283
+ "step": 3640
1284
+ },
1285
+ {
1286
+ "epoch": 1.0384451695275925,
1287
+ "grad_norm": 0.5003370046615601,
1288
+ "learning_rate": 1.4793201885924745e-05,
1289
+ "loss": 0.565,
1290
+ "step": 3660
1291
+ },
1292
+ {
1293
+ "epoch": 1.044119733295503,
1294
+ "grad_norm": 0.5440374612808228,
1295
+ "learning_rate": 1.465535588798592e-05,
1296
+ "loss": 0.5708,
1297
+ "step": 3680
1298
+ },
1299
+ {
1300
+ "epoch": 1.0497942970634133,
1301
+ "grad_norm": 0.5212259292602539,
1302
+ "learning_rate": 1.4517539005492237e-05,
1303
+ "loss": 0.57,
1304
+ "step": 3700
1305
+ },
1306
+ {
1307
+ "epoch": 1.0554688608313236,
1308
+ "grad_norm": 0.5004721879959106,
1309
+ "learning_rate": 1.4379762881178182e-05,
1310
+ "loss": 0.5692,
1311
+ "step": 3720
1312
+ },
1313
+ {
1314
+ "epoch": 1.061143424599234,
1315
+ "grad_norm": 0.5253936648368835,
1316
+ "learning_rate": 1.4242039154334973e-05,
1317
+ "loss": 0.5685,
1318
+ "step": 3740
1319
+ },
1320
+ {
1321
+ "epoch": 1.0668179883671443,
1322
+ "grad_norm": 0.5163034200668335,
1323
+ "learning_rate": 1.410437945982732e-05,
1324
+ "loss": 0.5706,
1325
+ "step": 3760
1326
+ },
1327
+ {
1328
+ "epoch": 1.0724925521350546,
1329
+ "grad_norm": 0.49630168080329895,
1330
+ "learning_rate": 1.3966795427110493e-05,
1331
+ "loss": 0.5725,
1332
+ "step": 3780
1333
+ },
1334
+ {
1335
+ "epoch": 1.0781671159029649,
1336
+ "grad_norm": 0.5117852091789246,
1337
+ "learning_rate": 1.3829298679247865e-05,
1338
+ "loss": 0.5646,
1339
+ "step": 3800
1340
+ },
1341
+ {
1342
+ "epoch": 1.0838416796708752,
1343
+ "grad_norm": 0.5082918405532837,
1344
+ "learning_rate": 1.369190083192902e-05,
1345
+ "loss": 0.5705,
1346
+ "step": 3820
1347
+ },
1348
+ {
1349
+ "epoch": 1.0895162434387857,
1350
+ "grad_norm": 0.5319990515708923,
1351
+ "learning_rate": 1.3554613492488453e-05,
1352
+ "loss": 0.5684,
1353
+ "step": 3840
1354
+ },
1355
+ {
1356
+ "epoch": 1.095190807206696,
1357
+ "grad_norm": 0.5344195365905762,
1358
+ "learning_rate": 1.3417448258924971e-05,
1359
+ "loss": 0.5658,
1360
+ "step": 3860
1361
+ },
1362
+ {
1363
+ "epoch": 1.1008653709746063,
1364
+ "grad_norm": 0.507433295249939,
1365
+ "learning_rate": 1.3280416718921902e-05,
1366
+ "loss": 0.5717,
1367
+ "step": 3880
1368
+ },
1369
+ {
1370
+ "epoch": 1.1065399347425167,
1371
+ "grad_norm": 0.5090216398239136,
1372
+ "learning_rate": 1.3143530448868198e-05,
1373
+ "loss": 0.5663,
1374
+ "step": 3900
1375
+ },
1376
+ {
1377
+ "epoch": 1.112214498510427,
1378
+ "grad_norm": 0.512146532535553,
1379
+ "learning_rate": 1.3006801012880425e-05,
1380
+ "loss": 0.5656,
1381
+ "step": 3920
1382
+ },
1383
+ {
1384
+ "epoch": 1.1178890622783373,
1385
+ "grad_norm": 0.5273200869560242,
1386
+ "learning_rate": 1.2870239961825853e-05,
1387
+ "loss": 0.5621,
1388
+ "step": 3940
1389
+ },
1390
+ {
1391
+ "epoch": 1.1235636260462476,
1392
+ "grad_norm": 0.5408139824867249,
1393
+ "learning_rate": 1.2733858832346648e-05,
1394
+ "loss": 0.5744,
1395
+ "step": 3960
1396
+ },
1397
+ {
1398
+ "epoch": 1.1292381898141581,
1399
+ "grad_norm": 0.4986436069011688,
1400
+ "learning_rate": 1.2597669145885231e-05,
1401
+ "loss": 0.5704,
1402
+ "step": 3980
1403
+ },
1404
+ {
1405
+ "epoch": 1.1349127535820684,
1406
+ "grad_norm": 0.5186699628829956,
1407
+ "learning_rate": 1.2461682407710973e-05,
1408
+ "loss": 0.5588,
1409
+ "step": 4000
1410
+ },
1411
+ {
1412
+ "epoch": 1.1405873173499788,
1413
+ "grad_norm": 0.5081115365028381,
1414
+ "learning_rate": 1.2325910105948229e-05,
1415
+ "loss": 0.5667,
1416
+ "step": 4020
1417
+ },
1418
+ {
1419
+ "epoch": 1.146261881117889,
1420
+ "grad_norm": 0.501616358757019,
1421
+ "learning_rate": 1.219036371060582e-05,
1422
+ "loss": 0.5628,
1423
+ "step": 4040
1424
+ },
1425
+ {
1426
+ "epoch": 1.1519364448857994,
1427
+ "grad_norm": 0.5288362503051758,
1428
+ "learning_rate": 1.2055054672608043e-05,
1429
+ "loss": 0.5642,
1430
+ "step": 4060
1431
+ },
1432
+ {
1433
+ "epoch": 1.1576110086537097,
1434
+ "grad_norm": 0.5392152070999146,
1435
+ "learning_rate": 1.1919994422827326e-05,
1436
+ "loss": 0.5606,
1437
+ "step": 4080
1438
+ },
1439
+ {
1440
+ "epoch": 1.16328557242162,
1441
+ "grad_norm": 0.514348030090332,
1442
+ "learning_rate": 1.1785194371118521e-05,
1443
+ "loss": 0.5653,
1444
+ "step": 4100
1445
+ },
1446
+ {
1447
+ "epoch": 1.1689601361895305,
1448
+ "grad_norm": 0.4942004978656769,
1449
+ "learning_rate": 1.1650665905355014e-05,
1450
+ "loss": 0.5622,
1451
+ "step": 4120
1452
+ },
1453
+ {
1454
+ "epoch": 1.1746346999574409,
1455
+ "grad_norm": 0.48802751302719116,
1456
+ "learning_rate": 1.1516420390466685e-05,
1457
+ "loss": 0.5613,
1458
+ "step": 4140
1459
+ },
1460
+ {
1461
+ "epoch": 1.1803092637253512,
1462
+ "grad_norm": 0.5025625228881836,
1463
+ "learning_rate": 1.1382469167479795e-05,
1464
+ "loss": 0.5656,
1465
+ "step": 4160
1466
+ },
1467
+ {
1468
+ "epoch": 1.1859838274932615,
1469
+ "grad_norm": 0.5276467204093933,
1470
+ "learning_rate": 1.1248823552558895e-05,
1471
+ "loss": 0.5639,
1472
+ "step": 4180
1473
+ },
1474
+ {
1475
+ "epoch": 1.1916583912611718,
1476
+ "grad_norm": 0.5035718083381653,
1477
+ "learning_rate": 1.1115494836050861e-05,
1478
+ "loss": 0.5612,
1479
+ "step": 4200
1480
+ },
1481
+ {
1482
+ "epoch": 1.197332955029082,
1483
+ "grad_norm": 0.5080997347831726,
1484
+ "learning_rate": 1.0982494281531069e-05,
1485
+ "loss": 0.5647,
1486
+ "step": 4220
1487
+ },
1488
+ {
1489
+ "epoch": 1.2030075187969924,
1490
+ "grad_norm": 0.505695104598999,
1491
+ "learning_rate": 1.0849833124851846e-05,
1492
+ "loss": 0.5681,
1493
+ "step": 4240
1494
+ },
1495
+ {
1496
+ "epoch": 1.2086820825649027,
1497
+ "grad_norm": 0.48905614018440247,
1498
+ "learning_rate": 1.0717522573193281e-05,
1499
+ "loss": 0.561,
1500
+ "step": 4260
1501
+ },
1502
+ {
1503
+ "epoch": 1.2143566463328133,
1504
+ "grad_norm": 0.49127668142318726,
1505
+ "learning_rate": 1.0585573804116448e-05,
1506
+ "loss": 0.5639,
1507
+ "step": 4280
1508
+ },
1509
+ {
1510
+ "epoch": 1.2200312101007236,
1511
+ "grad_norm": 0.5206524729728699,
1512
+ "learning_rate": 1.0453997964619112e-05,
1513
+ "loss": 0.5594,
1514
+ "step": 4300
1515
+ },
1516
+ {
1517
+ "epoch": 1.2257057738686339,
1518
+ "grad_norm": 0.48683062195777893,
1519
+ "learning_rate": 1.0322806170194061e-05,
1520
+ "loss": 0.5622,
1521
+ "step": 4320
1522
+ },
1523
+ {
1524
+ "epoch": 1.2313803376365442,
1525
+ "grad_norm": 0.532207190990448,
1526
+ "learning_rate": 1.0192009503890037e-05,
1527
+ "loss": 0.5581,
1528
+ "step": 4340
1529
+ },
1530
+ {
1531
+ "epoch": 1.2370549014044545,
1532
+ "grad_norm": 0.49200239777565,
1533
+ "learning_rate": 1.0061619015375473e-05,
1534
+ "loss": 0.5594,
1535
+ "step": 4360
1536
+ },
1537
+ {
1538
+ "epoch": 1.2427294651723648,
1539
+ "grad_norm": 0.504898190498352,
1540
+ "learning_rate": 9.931645720004995e-06,
1541
+ "loss": 0.5622,
1542
+ "step": 4380
1543
+ },
1544
+ {
1545
+ "epoch": 1.2484040289402751,
1546
+ "grad_norm": 0.5061923861503601,
1547
+ "learning_rate": 9.802100597888877e-06,
1548
+ "loss": 0.5572,
1549
+ "step": 4400
1550
+ },
1551
+ {
1552
+ "epoch": 1.2540785927081854,
1553
+ "grad_norm": 0.4961055815219879,
1554
+ "learning_rate": 9.672994592965409e-06,
1555
+ "loss": 0.5609,
1556
+ "step": 4420
1557
+ },
1558
+ {
1559
+ "epoch": 1.259753156476096,
1560
+ "grad_norm": 0.4930592477321625,
1561
+ "learning_rate": 9.544338612076396e-06,
1562
+ "loss": 0.5637,
1563
+ "step": 4440
1564
+ },
1565
+ {
1566
+ "epoch": 1.2654277202440063,
1567
+ "grad_norm": 0.4978179335594177,
1568
+ "learning_rate": 9.41614352404571e-06,
1569
+ "loss": 0.5615,
1570
+ "step": 4460
1571
+ },
1572
+ {
1573
+ "epoch": 1.2711022840119166,
1574
+ "grad_norm": 0.5112114548683167,
1575
+ "learning_rate": 9.288420158761127e-06,
1576
+ "loss": 0.558,
1577
+ "step": 4480
1578
+ },
1579
+ {
1580
+ "epoch": 1.276776847779827,
1581
+ "grad_norm": 0.5114573240280151,
1582
+ "learning_rate": 9.161179306259401e-06,
1583
+ "loss": 0.5561,
1584
+ "step": 4500
1585
+ },
1586
+ {
1587
+ "epoch": 1.2824514115477372,
1588
+ "grad_norm": 0.5023430585861206,
1589
+ "learning_rate": 9.034431715814726e-06,
1590
+ "loss": 0.5558,
1591
+ "step": 4520
1592
+ },
1593
+ {
1594
+ "epoch": 1.2881259753156475,
1595
+ "grad_norm": 0.503487765789032,
1596
+ "learning_rate": 8.908188095030655e-06,
1597
+ "loss": 0.5607,
1598
+ "step": 4540
1599
+ },
1600
+ {
1601
+ "epoch": 1.2938005390835579,
1602
+ "grad_norm": 0.5188455581665039,
1603
+ "learning_rate": 8.78245910893552e-06,
1604
+ "loss": 0.5639,
1605
+ "step": 4560
1606
+ },
1607
+ {
1608
+ "epoch": 1.2994751028514684,
1609
+ "grad_norm": 0.5216081738471985,
1610
+ "learning_rate": 8.657255379081438e-06,
1611
+ "loss": 0.5584,
1612
+ "step": 4580
1613
+ },
1614
+ {
1615
+ "epoch": 1.3051496666193787,
1616
+ "grad_norm": 0.5024508833885193,
1617
+ "learning_rate": 8.532587482647013e-06,
1618
+ "loss": 0.5604,
1619
+ "step": 4600
1620
+ },
1621
+ {
1622
+ "epoch": 1.310824230387289,
1623
+ "grad_norm": 0.5100445747375488,
1624
+ "learning_rate": 8.408465951543779e-06,
1625
+ "loss": 0.5596,
1626
+ "step": 4620
1627
+ },
1628
+ {
1629
+ "epoch": 1.3164987941551993,
1630
+ "grad_norm": 0.5005710124969482,
1631
+ "learning_rate": 8.284901271526481e-06,
1632
+ "loss": 0.5591,
1633
+ "step": 4640
1634
+ },
1635
+ {
1636
+ "epoch": 1.3221733579231096,
1637
+ "grad_norm": 0.5151055455207825,
1638
+ "learning_rate": 8.161903881307231e-06,
1639
+ "loss": 0.5462,
1640
+ "step": 4660
1641
+ },
1642
+ {
1643
+ "epoch": 1.32784792169102,
1644
+ "grad_norm": 0.4919968545436859,
1645
+ "learning_rate": 8.039484171673628e-06,
1646
+ "loss": 0.5523,
1647
+ "step": 4680
1648
+ },
1649
+ {
1650
+ "epoch": 1.3335224854589303,
1651
+ "grad_norm": 0.5007758140563965,
1652
+ "learning_rate": 7.917652484610975e-06,
1653
+ "loss": 0.5545,
1654
+ "step": 4700
1655
+ },
1656
+ {
1657
+ "epoch": 1.3391970492268408,
1658
+ "grad_norm": 0.4885912537574768,
1659
+ "learning_rate": 7.796419112428583e-06,
1660
+ "loss": 0.5582,
1661
+ "step": 4720
1662
+ },
1663
+ {
1664
+ "epoch": 1.344871612994751,
1665
+ "grad_norm": 0.4874049127101898,
1666
+ "learning_rate": 7.675794296890265e-06,
1667
+ "loss": 0.5505,
1668
+ "step": 4740
1669
+ },
1670
+ {
1671
+ "epoch": 1.3505461767626614,
1672
+ "grad_norm": 0.46998655796051025,
1673
+ "learning_rate": 7.555788228349143e-06,
1674
+ "loss": 0.554,
1675
+ "step": 4760
1676
+ },
1677
+ {
1678
+ "epoch": 1.3562207405305717,
1679
+ "grad_norm": 0.4996753931045532,
1680
+ "learning_rate": 7.436411044886753e-06,
1681
+ "loss": 0.5513,
1682
+ "step": 4780
1683
+ },
1684
+ {
1685
+ "epoch": 1.361895304298482,
1686
+ "grad_norm": 0.502571165561676,
1687
+ "learning_rate": 7.31767283145657e-06,
1688
+ "loss": 0.5547,
1689
+ "step": 4800
1690
+ },
1691
+ {
1692
+ "epoch": 1.3675698680663924,
1693
+ "grad_norm": 0.48792627453804016,
1694
+ "learning_rate": 7.199583619032052e-06,
1695
+ "loss": 0.5551,
1696
+ "step": 4820
1697
+ },
1698
+ {
1699
+ "epoch": 1.3732444318343027,
1700
+ "grad_norm": 0.48799988627433777,
1701
+ "learning_rate": 7.082153383759222e-06,
1702
+ "loss": 0.5524,
1703
+ "step": 4840
1704
+ },
1705
+ {
1706
+ "epoch": 1.3789189956022132,
1707
+ "grad_norm": 0.4976406991481781,
1708
+ "learning_rate": 6.9653920461138755e-06,
1709
+ "loss": 0.5548,
1710
+ "step": 4860
1711
+ },
1712
+ {
1713
+ "epoch": 1.3845935593701233,
1714
+ "grad_norm": 0.5006715655326843,
1715
+ "learning_rate": 6.849309470063529e-06,
1716
+ "loss": 0.5544,
1717
+ "step": 4880
1718
+ },
1719
+ {
1720
+ "epoch": 1.3902681231380338,
1721
+ "grad_norm": 0.4864628314971924,
1722
+ "learning_rate": 6.7339154622340754e-06,
1723
+ "loss": 0.5483,
1724
+ "step": 4900
1725
+ },
1726
+ {
1727
+ "epoch": 1.3959426869059441,
1728
+ "grad_norm": 0.48580724000930786,
1729
+ "learning_rate": 6.619219771081361e-06,
1730
+ "loss": 0.5544,
1731
+ "step": 4920
1732
+ },
1733
+ {
1734
+ "epoch": 1.4016172506738545,
1735
+ "grad_norm": 0.5042415857315063,
1736
+ "learning_rate": 6.505232086067607e-06,
1737
+ "loss": 0.5504,
1738
+ "step": 4940
1739
+ },
1740
+ {
1741
+ "epoch": 1.4072918144417648,
1742
+ "grad_norm": 0.4970082640647888,
1743
+ "learning_rate": 6.391962036842863e-06,
1744
+ "loss": 0.547,
1745
+ "step": 4960
1746
+ },
1747
+ {
1748
+ "epoch": 1.412966378209675,
1749
+ "grad_norm": 0.47866857051849365,
1750
+ "learning_rate": 6.279419192431494e-06,
1751
+ "loss": 0.5548,
1752
+ "step": 4980
1753
+ },
1754
+ {
1755
+ "epoch": 1.4186409419775854,
1756
+ "grad_norm": 0.4664076566696167,
1757
+ "learning_rate": 6.167613060423789e-06,
1758
+ "loss": 0.5454,
1759
+ "step": 5000
1760
+ },
1761
+ {
1762
+ "epoch": 1.4243155057454957,
1763
+ "grad_norm": 0.49711087346076965,
1764
+ "learning_rate": 6.0565530861727685e-06,
1765
+ "loss": 0.5519,
1766
+ "step": 5020
1767
+ },
1768
+ {
1769
+ "epoch": 1.4299900695134062,
1770
+ "grad_norm": 0.46965324878692627,
1771
+ "learning_rate": 5.946248651996244e-06,
1772
+ "loss": 0.5519,
1773
+ "step": 5040
1774
+ },
1775
+ {
1776
+ "epoch": 1.4356646332813165,
1777
+ "grad_norm": 0.505743145942688,
1778
+ "learning_rate": 5.836709076384188e-06,
1779
+ "loss": 0.5482,
1780
+ "step": 5060
1781
+ },
1782
+ {
1783
+ "epoch": 1.4413391970492269,
1784
+ "grad_norm": 0.5078002214431763,
1785
+ "learning_rate": 5.727943613211521e-06,
1786
+ "loss": 0.5575,
1787
+ "step": 5080
1788
+ },
1789
+ {
1790
+ "epoch": 1.4470137608171372,
1791
+ "grad_norm": 0.48647207021713257,
1792
+ "learning_rate": 5.619961450956347e-06,
1793
+ "loss": 0.5461,
1794
+ "step": 5100
1795
+ },
1796
+ {
1797
+ "epoch": 1.4526883245850475,
1798
+ "grad_norm": 0.4711668789386749,
1799
+ "learning_rate": 5.5127717119237084e-06,
1800
+ "loss": 0.5472,
1801
+ "step": 5120
1802
+ },
1803
+ {
1804
+ "epoch": 1.4583628883529578,
1805
+ "grad_norm": 0.518395721912384,
1806
+ "learning_rate": 5.406383451474948e-06,
1807
+ "loss": 0.5483,
1808
+ "step": 5140
1809
+ },
1810
+ {
1811
+ "epoch": 1.464037452120868,
1812
+ "grad_norm": 0.4849320948123932,
1813
+ "learning_rate": 5.300805657262706e-06,
1814
+ "loss": 0.5459,
1815
+ "step": 5160
1816
+ },
1817
+ {
1818
+ "epoch": 1.4697120158887786,
1819
+ "grad_norm": 0.501943826675415,
1820
+ "learning_rate": 5.1960472484716374e-06,
1821
+ "loss": 0.5482,
1822
+ "step": 5180
1823
+ },
1824
+ {
1825
+ "epoch": 1.475386579656689,
1826
+ "grad_norm": 0.48699691891670227,
1827
+ "learning_rate": 5.092117075064931e-06,
1828
+ "loss": 0.5522,
1829
+ "step": 5200
1830
+ },
1831
+ {
1832
+ "epoch": 1.4810611434245993,
1833
+ "grad_norm": 0.48894861340522766,
1834
+ "learning_rate": 4.989023917036667e-06,
1835
+ "loss": 0.5502,
1836
+ "step": 5220
1837
+ },
1838
+ {
1839
+ "epoch": 1.4867357071925096,
1840
+ "grad_norm": 0.49131521582603455,
1841
+ "learning_rate": 4.886776483670077e-06,
1842
+ "loss": 0.5466,
1843
+ "step": 5240
1844
+ },
1845
+ {
1846
+ "epoch": 1.49241027096042,
1847
+ "grad_norm": 0.47139400243759155,
1848
+ "learning_rate": 4.78538341280181e-06,
1849
+ "loss": 0.5473,
1850
+ "step": 5260
1851
+ },
1852
+ {
1853
+ "epoch": 1.4980848347283302,
1854
+ "grad_norm": 0.49604731798171997,
1855
+ "learning_rate": 4.684853270092173e-06,
1856
+ "loss": 0.5498,
1857
+ "step": 5280
1858
+ },
1859
+ {
1860
+ "epoch": 1.5037593984962405,
1861
+ "grad_norm": 0.4864351749420166,
1862
+ "learning_rate": 4.585194548301545e-06,
1863
+ "loss": 0.5448,
1864
+ "step": 5300
1865
+ },
1866
+ {
1867
+ "epoch": 1.509433962264151,
1868
+ "grad_norm": 0.48130905628204346,
1869
+ "learning_rate": 4.486415666572874e-06,
1870
+ "loss": 0.5469,
1871
+ "step": 5320
1872
+ },
1873
+ {
1874
+ "epoch": 1.5151085260320611,
1875
+ "grad_norm": 0.4783124625682831,
1876
+ "learning_rate": 4.388524969720458e-06,
1877
+ "loss": 0.546,
1878
+ "step": 5340
1879
+ },
1880
+ {
1881
+ "epoch": 1.5207830897999717,
1882
+ "grad_norm": 0.4969868063926697,
1883
+ "learning_rate": 4.2915307275249585e-06,
1884
+ "loss": 0.5453,
1885
+ "step": 5360
1886
+ },
1887
+ {
1888
+ "epoch": 1.526457653567882,
1889
+ "grad_norm": 0.4832542836666107,
1890
+ "learning_rate": 4.195441134034799e-06,
1891
+ "loss": 0.5463,
1892
+ "step": 5380
1893
+ },
1894
+ {
1895
+ "epoch": 1.5321322173357923,
1896
+ "grad_norm": 0.4712090790271759,
1897
+ "learning_rate": 4.10026430687389e-06,
1898
+ "loss": 0.5449,
1899
+ "step": 5400
1900
+ },
1901
+ {
1902
+ "epoch": 1.5378067811037026,
1903
+ "grad_norm": 0.4822421967983246,
1904
+ "learning_rate": 4.0060082865559035e-06,
1905
+ "loss": 0.5465,
1906
+ "step": 5420
1907
+ },
1908
+ {
1909
+ "epoch": 1.543481344871613,
1910
+ "grad_norm": 0.4809670150279999,
1911
+ "learning_rate": 3.912681035804971e-06,
1912
+ "loss": 0.5406,
1913
+ "step": 5440
1914
+ },
1915
+ {
1916
+ "epoch": 1.5491559086395235,
1917
+ "grad_norm": 0.4631410539150238,
1918
+ "learning_rate": 3.820290438883018e-06,
1919
+ "loss": 0.5461,
1920
+ "step": 5460
1921
+ },
1922
+ {
1923
+ "epoch": 1.5548304724074336,
1924
+ "grad_norm": 0.46498140692710876,
1925
+ "learning_rate": 3.728844300923694e-06,
1926
+ "loss": 0.5419,
1927
+ "step": 5480
1928
+ },
1929
+ {
1930
+ "epoch": 1.560505036175344,
1931
+ "grad_norm": 0.4786704480648041,
1932
+ "learning_rate": 3.6383503472730116e-06,
1933
+ "loss": 0.5476,
1934
+ "step": 5500
1935
+ },
1936
+ {
1937
+ "epoch": 1.5661795999432544,
1938
+ "grad_norm": 0.4655323624610901,
1939
+ "learning_rate": 3.548816222836688e-06,
1940
+ "loss": 0.5406,
1941
+ "step": 5520
1942
+ },
1943
+ {
1944
+ "epoch": 1.5718541637111647,
1945
+ "grad_norm": 0.46424925327301025,
1946
+ "learning_rate": 3.460249491434319e-06,
1947
+ "loss": 0.5415,
1948
+ "step": 5540
1949
+ },
1950
+ {
1951
+ "epoch": 1.577528727479075,
1952
+ "grad_norm": 0.45783787965774536,
1953
+ "learning_rate": 3.3726576351603985e-06,
1954
+ "loss": 0.5503,
1955
+ "step": 5560
1956
+ },
1957
+ {
1958
+ "epoch": 1.5832032912469853,
1959
+ "grad_norm": 0.49086692929267883,
1960
+ "learning_rate": 3.2860480537522103e-06,
1961
+ "loss": 0.543,
1962
+ "step": 5580
1963
+ },
1964
+ {
1965
+ "epoch": 1.5888778550148959,
1966
+ "grad_norm": 0.48474520444869995,
1967
+ "learning_rate": 3.2004280639647122e-06,
1968
+ "loss": 0.539,
1969
+ "step": 5600
1970
+ },
1971
+ {
1972
+ "epoch": 1.594552418782806,
1973
+ "grad_norm": 0.5037649869918823,
1974
+ "learning_rate": 3.115804898952434e-06,
1975
+ "loss": 0.5415,
1976
+ "step": 5620
1977
+ },
1978
+ {
1979
+ "epoch": 1.6002269825507165,
1980
+ "grad_norm": 0.4954313337802887,
1981
+ "learning_rate": 3.032185707658389e-06,
1982
+ "loss": 0.5487,
1983
+ "step": 5640
1984
+ },
1985
+ {
1986
+ "epoch": 1.6059015463186268,
1987
+ "grad_norm": 0.4597771465778351,
1988
+ "learning_rate": 2.949577554210157e-06,
1989
+ "loss": 0.5445,
1990
+ "step": 5660
1991
+ },
1992
+ {
1993
+ "epoch": 1.6115761100865371,
1994
+ "grad_norm": 0.4839852750301361,
1995
+ "learning_rate": 2.8679874173231137e-06,
1996
+ "loss": 0.5499,
1997
+ "step": 5680
1998
+ },
1999
+ {
2000
+ "epoch": 1.6172506738544474,
2001
+ "grad_norm": 0.4653310179710388,
2002
+ "learning_rate": 2.787422189710844e-06,
2003
+ "loss": 0.5453,
2004
+ "step": 5700
2005
+ },
2006
+ {
2007
+ "epoch": 1.6229252376223577,
2008
+ "grad_norm": 0.485579252243042,
2009
+ "learning_rate": 2.7078886775028693e-06,
2010
+ "loss": 0.5383,
2011
+ "step": 5720
2012
+ },
2013
+ {
2014
+ "epoch": 1.6285998013902683,
2015
+ "grad_norm": 0.4727838337421417,
2016
+ "learning_rate": 2.629393599669667e-06,
2017
+ "loss": 0.5421,
2018
+ "step": 5740
2019
+ },
2020
+ {
2021
+ "epoch": 1.6342743651581784,
2022
+ "grad_norm": 0.45239365100860596,
2023
+ "learning_rate": 2.5519435874550434e-06,
2024
+ "loss": 0.5357,
2025
+ "step": 5760
2026
+ },
2027
+ {
2028
+ "epoch": 1.639948928926089,
2029
+ "grad_norm": 0.4669874310493469,
2030
+ "learning_rate": 2.475545183815926e-06,
2031
+ "loss": 0.5385,
2032
+ "step": 5780
2033
+ },
2034
+ {
2035
+ "epoch": 1.645623492693999,
2036
+ "grad_norm": 0.4859563410282135,
2037
+ "learning_rate": 2.400204842869637e-06,
2038
+ "loss": 0.5446,
2039
+ "step": 5800
2040
+ },
2041
+ {
2042
+ "epoch": 1.6512980564619095,
2043
+ "grad_norm": 0.4492729902267456,
2044
+ "learning_rate": 2.3259289293486246e-06,
2045
+ "loss": 0.5418,
2046
+ "step": 5820
2047
+ },
2048
+ {
2049
+ "epoch": 1.6569726202298198,
2050
+ "grad_norm": 0.46383896470069885,
2051
+ "learning_rate": 2.252723718062787e-06,
2052
+ "loss": 0.5401,
2053
+ "step": 5840
2054
+ },
2055
+ {
2056
+ "epoch": 1.6626471839977301,
2057
+ "grad_norm": 0.48168492317199707,
2058
+ "learning_rate": 2.1805953933693835e-06,
2059
+ "loss": 0.5423,
2060
+ "step": 5860
2061
+ },
2062
+ {
2063
+ "epoch": 1.6683217477656405,
2064
+ "grad_norm": 0.46742239594459534,
2065
+ "learning_rate": 2.109550048650563e-06,
2066
+ "loss": 0.542,
2067
+ "step": 5880
2068
+ },
2069
+ {
2070
+ "epoch": 1.6739963115335508,
2071
+ "grad_norm": 0.46751725673675537,
2072
+ "learning_rate": 2.0395936857986125e-06,
2073
+ "loss": 0.5402,
2074
+ "step": 5900
2075
+ },
2076
+ {
2077
+ "epoch": 1.6796708753014613,
2078
+ "grad_norm": 0.49627310037612915,
2079
+ "learning_rate": 1.970732214708908e-06,
2080
+ "loss": 0.5461,
2081
+ "step": 5920
2082
+ },
2083
+ {
2084
+ "epoch": 1.6853454390693714,
2085
+ "grad_norm": 0.46826520562171936,
2086
+ "learning_rate": 1.9029714527806652e-06,
2087
+ "loss": 0.5385,
2088
+ "step": 5940
2089
+ },
2090
+ {
2091
+ "epoch": 1.691020002837282,
2092
+ "grad_norm": 0.4701858162879944,
2093
+ "learning_rate": 1.8363171244254606e-06,
2094
+ "loss": 0.5376,
2095
+ "step": 5960
2096
+ },
2097
+ {
2098
+ "epoch": 1.6966945666051922,
2099
+ "grad_norm": 0.4635229706764221,
2100
+ "learning_rate": 1.7707748605836632e-06,
2101
+ "loss": 0.5378,
2102
+ "step": 5980
2103
+ },
2104
+ {
2105
+ "epoch": 1.7023691303731026,
2106
+ "grad_norm": 0.4729613661766052,
2107
+ "learning_rate": 1.7063501982487135e-06,
2108
+ "loss": 0.5437,
2109
+ "step": 6000
2110
+ },
2111
+ {
2112
+ "epoch": 1.7080436941410129,
2113
+ "grad_norm": 0.4672451913356781,
2114
+ "learning_rate": 1.6430485799993673e-06,
2115
+ "loss": 0.5428,
2116
+ "step": 6020
2117
+ },
2118
+ {
2119
+ "epoch": 1.7137182579089232,
2120
+ "grad_norm": 0.46772390604019165,
2121
+ "learning_rate": 1.5808753535399022e-06,
2122
+ "loss": 0.5392,
2123
+ "step": 6040
2124
+ },
2125
+ {
2126
+ "epoch": 1.7193928216768337,
2127
+ "grad_norm": 0.46337825059890747,
2128
+ "learning_rate": 1.5198357712483629e-06,
2129
+ "loss": 0.5413,
2130
+ "step": 6060
2131
+ },
2132
+ {
2133
+ "epoch": 1.7250673854447438,
2134
+ "grad_norm": 0.48103076219558716,
2135
+ "learning_rate": 1.459934989732818e-06,
2136
+ "loss": 0.5416,
2137
+ "step": 6080
2138
+ },
+ {
+ "epoch": 1.7307419492126543,
+ "grad_norm": 0.45769959688186646,
+ "learning_rate": 1.4011780693957492e-06,
+ "loss": 0.5436,
+ "step": 6100
+ },
+ {
+ "epoch": 1.7364165129805647,
+ "grad_norm": 0.4552821218967438,
+ "learning_rate": 1.3435699740065377e-06,
+ "loss": 0.5425,
+ "step": 6120
+ },
+ {
+ "epoch": 1.742091076748475,
+ "grad_norm": 0.48623600602149963,
+ "learning_rate": 1.2871155702821324e-06,
+ "loss": 0.5427,
+ "step": 6140
+ },
+ {
+ "epoch": 1.7477656405163853,
+ "grad_norm": 0.5024483799934387,
+ "learning_rate": 1.231819627475911e-06,
+ "loss": 0.5384,
+ "step": 6160
+ },
+ {
+ "epoch": 1.7534402042842956,
+ "grad_norm": 0.4556623101234436,
+ "learning_rate": 1.1776868169747702e-06,
+ "loss": 0.5393,
+ "step": 6180
+ },
+ {
+ "epoch": 1.7591147680522061,
+ "grad_norm": 0.4748471677303314,
+ "learning_rate": 1.1247217119044951e-06,
+ "loss": 0.5385,
+ "step": 6200
+ },
+ {
+ "epoch": 1.7647893318201162,
+ "grad_norm": 0.4622340500354767,
+ "learning_rate": 1.07292878674342e-06,
+ "loss": 0.5377,
+ "step": 6220
+ },
+ {
+ "epoch": 1.7704638955880267,
+ "grad_norm": 0.4581329822540283,
+ "learning_rate": 1.0223124169444236e-06,
+ "loss": 0.5366,
+ "step": 6240
+ },
+ {
+ "epoch": 1.776138459355937,
+ "grad_norm": 0.4667391777038574,
+ "learning_rate": 9.72876878565287e-07,
+ "loss": 0.539,
+ "step": 6260
+ },
+ {
+ "epoch": 1.7818130231238474,
+ "grad_norm": 0.4563803970813751,
+ "learning_rate": 9.246263479074663e-07,
+ "loss": 0.5403,
+ "step": 6280
+ },
+ {
+ "epoch": 1.7874875868917577,
+ "grad_norm": 0.44948819279670715,
+ "learning_rate": 8.775649011632703e-07,
+ "loss": 0.5392,
+ "step": 6300
+ },
+ {
+ "epoch": 1.793162150659668,
+ "grad_norm": 0.4829549193382263,
+ "learning_rate": 8.316965140715071e-07,
+ "loss": 0.5373,
+ "step": 6320
+ },
+ {
+ "epoch": 1.7988367144275785,
+ "grad_norm": 0.4718981683254242,
+ "learning_rate": 7.870250615816182e-07,
+ "loss": 0.5383,
+ "step": 6340
+ },
+ {
+ "epoch": 1.8045112781954886,
+ "grad_norm": 0.4641667306423187,
+ "learning_rate": 7.435543175263166e-07,
+ "loss": 0.543,
+ "step": 6360
+ },
+ {
+ "epoch": 1.8101858419633992,
+ "grad_norm": 0.45884087681770325,
+ "learning_rate": 7.012879543027801e-07,
+ "loss": 0.538,
+ "step": 6380
+ },
+ {
+ "epoch": 1.8158604057313092,
+ "grad_norm": 0.4888609051704407,
+ "learning_rate": 6.602295425624033e-07,
+ "loss": 0.5366,
+ "step": 6400
+ },
+ {
+ "epoch": 1.8215349694992198,
+ "grad_norm": 0.46243107318878174,
+ "learning_rate": 6.20382550909157e-07,
+ "loss": 0.5365,
+ "step": 6420
+ },
+ {
+ "epoch": 1.82720953326713,
+ "grad_norm": 0.46520647406578064,
+ "learning_rate": 5.817503456065559e-07,
+ "loss": 0.5339,
+ "step": 6440
+ },
+ {
+ "epoch": 1.8328840970350404,
+ "grad_norm": 0.47549664974212646,
+ "learning_rate": 5.443361902932792e-07,
+ "loss": 0.5361,
+ "step": 6460
+ },
+ {
+ "epoch": 1.838558660802951,
+ "grad_norm": 0.4677965044975281,
+ "learning_rate": 5.081432457074614e-07,
+ "loss": 0.5394,
+ "step": 6480
+ },
+ {
+ "epoch": 1.844233224570861,
+ "grad_norm": 0.46250638365745544,
+ "learning_rate": 4.7317456941966597e-07,
+ "loss": 0.5388,
+ "step": 6500
+ },
+ {
+ "epoch": 1.8499077883387716,
+ "grad_norm": 0.4758864641189575,
+ "learning_rate": 4.3943311557459177e-07,
+ "loss": 0.534,
+ "step": 6520
+ },
+ {
+ "epoch": 1.8555823521066817,
+ "grad_norm": 0.4370381832122803,
+ "learning_rate": 4.069217346415027e-07,
+ "loss": 0.5339,
+ "step": 6540
+ },
+ {
+ "epoch": 1.8612569158745922,
+ "grad_norm": 0.4617324769496918,
+ "learning_rate": 3.756431731734272e-07,
+ "loss": 0.5396,
+ "step": 6560
+ },
+ {
+ "epoch": 1.8669314796425025,
+ "grad_norm": 0.4532717168331146,
+ "learning_rate": 3.4560007357511856e-07,
+ "loss": 0.5393,
+ "step": 6580
+ },
+ {
+ "epoch": 1.8726060434104128,
+ "grad_norm": 0.46486184000968933,
+ "learning_rate": 3.16794973879837e-07,
+ "loss": 0.5367,
+ "step": 6600
+ },
+ {
+ "epoch": 1.8782806071783231,
+ "grad_norm": 0.44514200091362,
+ "learning_rate": 2.8923030753492783e-07,
+ "loss": 0.5384,
+ "step": 6620
+ },
+ {
+ "epoch": 1.8839551709462334,
+ "grad_norm": 0.4737865924835205,
+ "learning_rate": 2.6290840319625255e-07,
+ "loss": 0.5355,
+ "step": 6640
+ },
+ {
+ "epoch": 1.889629734714144,
+ "grad_norm": 0.45271801948547363,
+ "learning_rate": 2.378314845314561e-07,
+ "loss": 0.5451,
+ "step": 6660
+ },
+ {
+ "epoch": 1.895304298482054,
+ "grad_norm": 0.46050384640693665,
+ "learning_rate": 2.14001670032124e-07,
+ "loss": 0.5347,
+ "step": 6680
+ },
+ {
+ "epoch": 1.9009788622499646,
+ "grad_norm": 0.4726841151714325,
+ "learning_rate": 1.9142097283479876e-07,
+ "loss": 0.5428,
+ "step": 6700
+ },
+ {
+ "epoch": 1.906653426017875,
+ "grad_norm": 0.4662003815174103,
+ "learning_rate": 1.700913005509208e-07,
+ "loss": 0.5407,
+ "step": 6720
+ },
+ {
+ "epoch": 1.9123279897857852,
+ "grad_norm": 0.44422999024391174,
+ "learning_rate": 1.500144551056709e-07,
+ "loss": 0.535,
+ "step": 6740
+ },
+ {
+ "epoch": 1.9180025535536955,
+ "grad_norm": 0.4599597752094269,
+ "learning_rate": 1.3119213258574015e-07,
+ "loss": 0.5376,
+ "step": 6760
+ },
+ {
+ "epoch": 1.9236771173216058,
+ "grad_norm": 0.4735456705093384,
+ "learning_rate": 1.1362592309605291e-07,
+ "loss": 0.5392,
+ "step": 6780
+ },
+ {
+ "epoch": 1.9293516810895164,
+ "grad_norm": 0.4692912995815277,
+ "learning_rate": 9.731731062542604e-08,
+ "loss": 0.5398,
+ "step": 6800
  }
  ],
+ "logging_steps": 20,
+ "max_steps": 7048,
  "num_input_tokens_seen": 0,
+ "num_train_epochs": 2,
  "save_steps": 200,
  "stateful_callbacks": {
  "TrainerControl": {
 
  "attributes": {}
  }
  },
+ "total_flos": 1.5124467391135325e+20,
+ "train_batch_size": 1,
  "trial_name": null,
  "trial_params": null
  }
training_args.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:a6bf16ea130bda159d1af2ee62d236c7ae097ea41c8408d8221e7b326b65872b
- size 6456
+ oid sha256:ffd93f25c50f75fbd7f7b6ad5a315acf357ca57e88203e0285f40efaac4f4e34
+ size 6520