Rosiness
commited on
Commit
•
bf7b7e8
0
Parent(s):
First model version
Browse files- .gitattributes +21 -0
- added_tokens.json +9 -0
- config.json +187 -0
- configuration_intern_vit.py +117 -0
- configuration_internvl_chat.py +75 -0
- conversation.py +1243 -0
- generation_config.json +4 -0
- latest +1 -0
- model-00001-of-00017.safetensors +3 -0
- model-00002-of-00017.safetensors +3 -0
- model-00003-of-00017.safetensors +3 -0
- model-00004-of-00017.safetensors +3 -0
- model-00005-of-00017.safetensors +3 -0
- model-00006-of-00017.safetensors +3 -0
- model-00007-of-00017.safetensors +3 -0
- model-00008-of-00017.safetensors +3 -0
- model-00009-of-00017.safetensors +3 -0
- model-00010-of-00017.safetensors +3 -0
- model-00011-of-00017.safetensors +3 -0
- model-00012-of-00017.safetensors +3 -0
- model-00013-of-00017.safetensors +3 -0
- model-00014-of-00017.safetensors +3 -0
- model-00015-of-00017.safetensors +3 -0
- model-00016-of-00017.safetensors +3 -0
- model-00017-of-00017.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_intern_vit.py +413 -0
- modeling_internvl_chat.py +449 -0
- special_tokens_map.json +41 -0
- tokenizer.model +3 -0
- tokenizer_config.json +143 -0
- trainer_state.json +1145 -0
- training_args.bin +3 -0
- zero_to_fp32.py +578 -0
.gitattributes
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model-00003-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
2 |
+
model-00008-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
3 |
+
model-00012-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
4 |
+
model-00016-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
5 |
+
model-00011-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
6 |
+
model-00013-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
7 |
+
training_args.bin filter=lfs diff=lfs merge=lfs -text
|
8 |
+
model-00006-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
9 |
+
model-00009-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
10 |
+
model-00010-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
11 |
+
tokenizer.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
model-00015-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
13 |
+
model-00007-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
14 |
+
model-00014-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
15 |
+
model-00017-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
16 |
+
model-00001-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
17 |
+
model-00002-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
18 |
+
model-00004-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
19 |
+
model-00005-of-00017.safetensors filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
21 |
+
tokenizer.model filter=lfs diff=lfs merge=lfs -text
|
added_tokens.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"</box>": 64006,
|
3 |
+
"</quad>": 64002,
|
4 |
+
"</ref>": 64004,
|
5 |
+
"<IMG_CONTEXT>": 64000,
|
6 |
+
"<box>": 64005,
|
7 |
+
"<quad>": 64001,
|
8 |
+
"<ref>": 64003
|
9 |
+
}
|
config.json
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_commit_hash": "21d6ce2f09ce86a2dd76e1d6672d2143e060de34",
|
3 |
+
"_name_or_path": "OpenGVLab/InternVL-Chat-Chinese-V1-2-Plus",
|
4 |
+
"architectures": [
|
5 |
+
"InternVLChatModel"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_internvl_chat.InternVLChatConfig",
|
9 |
+
"AutoModel": "modeling_internvl_chat.InternVLChatModel"
|
10 |
+
},
|
11 |
+
"downsample_ratio": 0.5,
|
12 |
+
"force_image_size": 448,
|
13 |
+
"llm_config": {
|
14 |
+
"_name_or_path": "01-ai/Yi-34B",
|
15 |
+
"add_cross_attention": false,
|
16 |
+
"architectures": [
|
17 |
+
"LlamaForCausalLM"
|
18 |
+
],
|
19 |
+
"attention_bias": false,
|
20 |
+
"attention_dropout": 0.0,
|
21 |
+
"attn_implementation": "flash_attention_2",
|
22 |
+
"bad_words_ids": null,
|
23 |
+
"begin_suppress_tokens": null,
|
24 |
+
"bos_token_id": 1,
|
25 |
+
"chunk_size_feed_forward": 0,
|
26 |
+
"cross_attention_hidden_size": null,
|
27 |
+
"decoder_start_token_id": null,
|
28 |
+
"diversity_penalty": 0.0,
|
29 |
+
"do_sample": false,
|
30 |
+
"early_stopping": false,
|
31 |
+
"encoder_no_repeat_ngram_size": 0,
|
32 |
+
"eos_token_id": 7,
|
33 |
+
"exponential_decay_length_penalty": null,
|
34 |
+
"finetuning_task": null,
|
35 |
+
"forced_bos_token_id": null,
|
36 |
+
"forced_eos_token_id": null,
|
37 |
+
"hidden_act": "silu",
|
38 |
+
"hidden_size": 7168,
|
39 |
+
"id2label": {
|
40 |
+
"0": "LABEL_0",
|
41 |
+
"1": "LABEL_1"
|
42 |
+
},
|
43 |
+
"initializer_range": 0.02,
|
44 |
+
"intermediate_size": 20480,
|
45 |
+
"is_decoder": false,
|
46 |
+
"is_encoder_decoder": false,
|
47 |
+
"label2id": {
|
48 |
+
"LABEL_0": 0,
|
49 |
+
"LABEL_1": 1
|
50 |
+
},
|
51 |
+
"length_penalty": 1.0,
|
52 |
+
"max_length": 20,
|
53 |
+
"max_position_embeddings": 4096,
|
54 |
+
"min_length": 0,
|
55 |
+
"model_type": "llama",
|
56 |
+
"no_repeat_ngram_size": 0,
|
57 |
+
"num_attention_heads": 56,
|
58 |
+
"num_beam_groups": 1,
|
59 |
+
"num_beams": 1,
|
60 |
+
"num_hidden_layers": 60,
|
61 |
+
"num_key_value_heads": 8,
|
62 |
+
"num_return_sequences": 1,
|
63 |
+
"output_attentions": false,
|
64 |
+
"output_hidden_states": false,
|
65 |
+
"output_scores": false,
|
66 |
+
"pad_token_id": 0,
|
67 |
+
"prefix": null,
|
68 |
+
"pretraining_tp": 1,
|
69 |
+
"problem_type": null,
|
70 |
+
"pruned_heads": {},
|
71 |
+
"remove_invalid_values": false,
|
72 |
+
"repetition_penalty": 1.0,
|
73 |
+
"return_dict": true,
|
74 |
+
"return_dict_in_generate": false,
|
75 |
+
"rms_norm_eps": 1e-05,
|
76 |
+
"rope_scaling": null,
|
77 |
+
"rope_theta": 5000000.0,
|
78 |
+
"sep_token_id": null,
|
79 |
+
"suppress_tokens": null,
|
80 |
+
"task_specific_params": null,
|
81 |
+
"temperature": 1.0,
|
82 |
+
"tf_legacy_loss": false,
|
83 |
+
"tie_encoder_decoder": false,
|
84 |
+
"tie_word_embeddings": false,
|
85 |
+
"tokenizer_class": null,
|
86 |
+
"top_k": 50,
|
87 |
+
"top_p": 1.0,
|
88 |
+
"torch_dtype": "bfloat16",
|
89 |
+
"torchscript": false,
|
90 |
+
"transformers_version": "4.36.2",
|
91 |
+
"typical_p": 1.0,
|
92 |
+
"use_bfloat16": false,
|
93 |
+
"use_cache": false,
|
94 |
+
"vocab_size": 64007
|
95 |
+
},
|
96 |
+
"model_type": "internvl_chat",
|
97 |
+
"pad2square": false,
|
98 |
+
"select_layer": -1,
|
99 |
+
"template": "Hermes-2",
|
100 |
+
"torch_dtype": "bfloat16",
|
101 |
+
"transformers_version": null,
|
102 |
+
"use_backbone_lora": 0,
|
103 |
+
"use_llm_lora": 0,
|
104 |
+
"vision_config": {
|
105 |
+
"_name_or_path": "",
|
106 |
+
"add_cross_attention": false,
|
107 |
+
"architectures": [
|
108 |
+
"InternVisionModel"
|
109 |
+
],
|
110 |
+
"attention_dropout": 0.0,
|
111 |
+
"bad_words_ids": null,
|
112 |
+
"begin_suppress_tokens": null,
|
113 |
+
"bos_token_id": null,
|
114 |
+
"chunk_size_feed_forward": 0,
|
115 |
+
"cross_attention_hidden_size": null,
|
116 |
+
"decoder_start_token_id": null,
|
117 |
+
"diversity_penalty": 0.0,
|
118 |
+
"do_sample": false,
|
119 |
+
"drop_path_rate": 0.0,
|
120 |
+
"dropout": 0.0,
|
121 |
+
"early_stopping": false,
|
122 |
+
"encoder_no_repeat_ngram_size": 0,
|
123 |
+
"eos_token_id": null,
|
124 |
+
"exponential_decay_length_penalty": null,
|
125 |
+
"finetuning_task": null,
|
126 |
+
"forced_bos_token_id": null,
|
127 |
+
"forced_eos_token_id": null,
|
128 |
+
"hidden_act": "gelu",
|
129 |
+
"hidden_size": 3200,
|
130 |
+
"id2label": {
|
131 |
+
"0": "LABEL_0",
|
132 |
+
"1": "LABEL_1"
|
133 |
+
},
|
134 |
+
"image_size": 448,
|
135 |
+
"initializer_factor": 0.1,
|
136 |
+
"initializer_range": 1e-10,
|
137 |
+
"intermediate_size": 12800,
|
138 |
+
"is_decoder": false,
|
139 |
+
"is_encoder_decoder": false,
|
140 |
+
"label2id": {
|
141 |
+
"LABEL_0": 0,
|
142 |
+
"LABEL_1": 1
|
143 |
+
},
|
144 |
+
"layer_norm_eps": 1e-06,
|
145 |
+
"length_penalty": 1.0,
|
146 |
+
"max_length": 20,
|
147 |
+
"min_length": 0,
|
148 |
+
"model_type": "intern_vit_6b",
|
149 |
+
"no_repeat_ngram_size": 0,
|
150 |
+
"num_attention_heads": 25,
|
151 |
+
"num_beam_groups": 1,
|
152 |
+
"num_beams": 1,
|
153 |
+
"num_channels": 3,
|
154 |
+
"num_hidden_layers": 45,
|
155 |
+
"num_return_sequences": 1,
|
156 |
+
"output_attentions": false,
|
157 |
+
"output_hidden_states": false,
|
158 |
+
"output_scores": false,
|
159 |
+
"pad_token_id": null,
|
160 |
+
"patch_size": 14,
|
161 |
+
"prefix": null,
|
162 |
+
"problem_type": null,
|
163 |
+
"pruned_heads": {},
|
164 |
+
"qk_normalization": true,
|
165 |
+
"qkv_bias": false,
|
166 |
+
"remove_invalid_values": false,
|
167 |
+
"repetition_penalty": 1.0,
|
168 |
+
"return_dict": true,
|
169 |
+
"return_dict_in_generate": false,
|
170 |
+
"sep_token_id": null,
|
171 |
+
"suppress_tokens": null,
|
172 |
+
"task_specific_params": null,
|
173 |
+
"temperature": 1.0,
|
174 |
+
"tf_legacy_loss": false,
|
175 |
+
"tie_encoder_decoder": false,
|
176 |
+
"tie_word_embeddings": true,
|
177 |
+
"tokenizer_class": null,
|
178 |
+
"top_k": 50,
|
179 |
+
"top_p": 1.0,
|
180 |
+
"torch_dtype": "bfloat16",
|
181 |
+
"torchscript": false,
|
182 |
+
"transformers_version": "4.36.2",
|
183 |
+
"typical_p": 1.0,
|
184 |
+
"use_bfloat16": true,
|
185 |
+
"use_flash_attn": true
|
186 |
+
}
|
187 |
+
}
|
configuration_intern_vit.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
import os
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
from transformers.configuration_utils import PretrainedConfig
|
10 |
+
from transformers.utils import logging
|
11 |
+
|
12 |
+
logger = logging.get_logger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
class InternVisionConfig(PretrainedConfig):
|
16 |
+
r"""
|
17 |
+
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
|
18 |
+
instantiate a vision encoder according to the specified arguments, defining the model architecture.
|
19 |
+
|
20 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
21 |
+
documentation from [`PretrainedConfig`] for more information.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
num_channels (`int`, *optional*, defaults to 3):
|
25 |
+
Number of color channels in the input images (e.g., 3 for RGB).
|
26 |
+
patch_size (`int`, *optional*, defaults to 14):
|
27 |
+
The size (resolution) of each patch.
|
28 |
+
image_size (`int`, *optional*, defaults to 224):
|
29 |
+
The size (resolution) of each image.
|
30 |
+
qkv_bias (`bool`, *optional*, defaults to `False`):
|
31 |
+
Whether to add a bias to the queries and values in the self-attention layers.
|
32 |
+
hidden_size (`int`, *optional*, defaults to 3200):
|
33 |
+
Dimensionality of the encoder layers and the pooler layer.
|
34 |
+
num_attention_heads (`int`, *optional*, defaults to 25):
|
35 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
36 |
+
intermediate_size (`int`, *optional*, defaults to 12800):
|
37 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
38 |
+
qk_normalization (`bool`, *optional*, defaults to `True`):
|
39 |
+
Whether to normalize the queries and keys in the self-attention layers.
|
40 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
41 |
+
Number of hidden layers in the Transformer encoder.
|
42 |
+
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
43 |
+
Whether to use flash attention mechanism.
|
44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
45 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
46 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
|
47 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
48 |
+
The epsilon used by the layer normalization layers.
|
49 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
50 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
51 |
+
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
52 |
+
Dropout rate for stochastic depth.
|
53 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
54 |
+
The dropout ratio for the attention probabilities.
|
55 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
56 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
57 |
+
initializer_factor (`float`, *optional*, defaults to 0.1):
|
58 |
+
A factor for layer scale.
|
59 |
+
"""
|
60 |
+
|
61 |
+
model_type = 'intern_vit_6b'
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
num_channels=3,
|
66 |
+
patch_size=14,
|
67 |
+
image_size=224,
|
68 |
+
qkv_bias=False,
|
69 |
+
hidden_size=3200,
|
70 |
+
num_attention_heads=25,
|
71 |
+
intermediate_size=12800,
|
72 |
+
qk_normalization=True,
|
73 |
+
num_hidden_layers=48,
|
74 |
+
use_flash_attn=True,
|
75 |
+
hidden_act='gelu',
|
76 |
+
layer_norm_eps=1e-6,
|
77 |
+
dropout=0.0,
|
78 |
+
drop_path_rate=0.0,
|
79 |
+
attention_dropout=0.0,
|
80 |
+
initializer_range=0.02,
|
81 |
+
initializer_factor=0.1,
|
82 |
+
**kwargs,
|
83 |
+
):
|
84 |
+
super().__init__(**kwargs)
|
85 |
+
|
86 |
+
self.hidden_size = hidden_size
|
87 |
+
self.intermediate_size = intermediate_size
|
88 |
+
self.dropout = dropout
|
89 |
+
self.drop_path_rate = drop_path_rate
|
90 |
+
self.num_hidden_layers = num_hidden_layers
|
91 |
+
self.num_attention_heads = num_attention_heads
|
92 |
+
self.num_channels = num_channels
|
93 |
+
self.patch_size = patch_size
|
94 |
+
self.image_size = image_size
|
95 |
+
self.initializer_range = initializer_range
|
96 |
+
self.initializer_factor = initializer_factor
|
97 |
+
self.attention_dropout = attention_dropout
|
98 |
+
self.layer_norm_eps = layer_norm_eps
|
99 |
+
self.hidden_act = hidden_act
|
100 |
+
self.qkv_bias = qkv_bias
|
101 |
+
self.qk_normalization = qk_normalization
|
102 |
+
self.use_flash_attn = use_flash_attn
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
|
106 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
107 |
+
|
108 |
+
if 'vision_config' in config_dict:
|
109 |
+
config_dict = config_dict['vision_config']
|
110 |
+
|
111 |
+
if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
|
112 |
+
logger.warning(
|
113 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
114 |
+
f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
|
115 |
+
)
|
116 |
+
|
117 |
+
return cls.from_dict(config_dict, **kwargs)
|
configuration_internvl_chat.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
import copy
|
8 |
+
|
9 |
+
from transformers import LlamaConfig
|
10 |
+
from transformers.configuration_utils import PretrainedConfig
|
11 |
+
from transformers.utils import logging
|
12 |
+
|
13 |
+
from .configuration_intern_vit import InternVisionConfig
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class InternVLChatConfig(PretrainedConfig):
|
19 |
+
model_type = 'internvl_chat'
|
20 |
+
is_composition = True
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
vision_config=None,
|
25 |
+
llm_config=None,
|
26 |
+
use_backbone_lora=0,
|
27 |
+
use_llm_lora=0,
|
28 |
+
pad2square=False,
|
29 |
+
select_layer=-4,
|
30 |
+
force_image_size=None,
|
31 |
+
downsample_ratio=0.5,
|
32 |
+
template=None,
|
33 |
+
**kwargs):
|
34 |
+
super().__init__(**kwargs)
|
35 |
+
|
36 |
+
if vision_config is None:
|
37 |
+
vision_config = {}
|
38 |
+
logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
|
39 |
+
|
40 |
+
if llm_config is None:
|
41 |
+
llm_config = {}
|
42 |
+
logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
|
43 |
+
|
44 |
+
self.vision_config = InternVisionConfig(**vision_config)
|
45 |
+
self.llm_config = LlamaConfig(**llm_config)
|
46 |
+
self.use_backbone_lora = use_backbone_lora
|
47 |
+
self.use_llm_lora = use_llm_lora
|
48 |
+
self.pad2square = pad2square
|
49 |
+
self.select_layer = select_layer
|
50 |
+
self.force_image_size = force_image_size
|
51 |
+
self.downsample_ratio = downsample_ratio
|
52 |
+
self.template = template
|
53 |
+
|
54 |
+
logger.info(f'vision_select_layer: {self.select_layer}')
|
55 |
+
|
56 |
+
def to_dict(self):
|
57 |
+
"""
|
58 |
+
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
62 |
+
"""
|
63 |
+
output = copy.deepcopy(self.__dict__)
|
64 |
+
output['vision_config'] = self.vision_config.to_dict()
|
65 |
+
output['llm_config'] = self.llm_config.to_dict()
|
66 |
+
output['model_type'] = self.__class__.model_type
|
67 |
+
output['use_backbone_lora'] = self.use_backbone_lora
|
68 |
+
output['use_llm_lora'] = self.use_llm_lora
|
69 |
+
output['pad2square'] = self.pad2square
|
70 |
+
output['select_layer'] = self.select_layer
|
71 |
+
output['force_image_size'] = self.force_image_size
|
72 |
+
output['downsample_ratio'] = self.downsample_ratio
|
73 |
+
output['template'] = self.template
|
74 |
+
|
75 |
+
return output
|
conversation.py
ADDED
@@ -0,0 +1,1243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Conversation prompt templates.
|
3 |
+
|
4 |
+
We kindly request that you import fastchat instead of copying this file if you wish to use it.
|
5 |
+
If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import dataclasses
|
9 |
+
from enum import IntEnum, auto
|
10 |
+
from typing import Any, Dict, List, Tuple, Union
|
11 |
+
|
12 |
+
|
13 |
+
class SeparatorStyle(IntEnum):
|
14 |
+
"""Separator styles."""
|
15 |
+
|
16 |
+
ADD_COLON_SINGLE = auto()
|
17 |
+
ADD_COLON_TWO = auto()
|
18 |
+
ADD_COLON_SPACE_SINGLE = auto()
|
19 |
+
NO_COLON_SINGLE = auto()
|
20 |
+
NO_COLON_TWO = auto()
|
21 |
+
ADD_NEW_LINE_SINGLE = auto()
|
22 |
+
LLAMA2 = auto()
|
23 |
+
CHATGLM = auto()
|
24 |
+
CHATML = auto()
|
25 |
+
CHATINTERN = auto()
|
26 |
+
DOLLY = auto()
|
27 |
+
RWKV = auto()
|
28 |
+
PHOENIX = auto()
|
29 |
+
ROBIN = auto()
|
30 |
+
FALCON_CHAT = auto()
|
31 |
+
CHATGLM3 = auto()
|
32 |
+
INTERNVL_ZH = auto()
|
33 |
+
MPT = auto()
|
34 |
+
|
35 |
+
@dataclasses.dataclass
|
36 |
+
class Conversation:
|
37 |
+
"""A class that manages prompt templates and keeps all conversation history."""
|
38 |
+
|
39 |
+
# The name of this template
|
40 |
+
name: str
|
41 |
+
# The template of the system prompt
|
42 |
+
system_template: str = '{system_message}'
|
43 |
+
# The system message
|
44 |
+
system_message: str = ''
|
45 |
+
# The names of two roles
|
46 |
+
roles: Tuple[str] = ('USER', 'ASSISTANT')
|
47 |
+
# All messages. Each item is (role, message).
|
48 |
+
messages: List[List[str]] = ()
|
49 |
+
# The number of few shot examples
|
50 |
+
offset: int = 0
|
51 |
+
# The separator style and configurations
|
52 |
+
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
|
53 |
+
sep: str = '\n'
|
54 |
+
sep2: str = None
|
55 |
+
# Stop criteria (the default one is EOS token)
|
56 |
+
stop_str: Union[str, List[str]] = None
|
57 |
+
# Stops generation if meeting any token in this list
|
58 |
+
stop_token_ids: List[int] = None
|
59 |
+
|
60 |
+
def get_prompt(self) -> str:
|
61 |
+
"""Get the prompt for generation."""
|
62 |
+
system_prompt = self.system_template.format(system_message=self.system_message)
|
63 |
+
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
64 |
+
ret = system_prompt + self.sep
|
65 |
+
for role, message in self.messages:
|
66 |
+
if message:
|
67 |
+
ret += role + ': ' + message + self.sep
|
68 |
+
else:
|
69 |
+
ret += role + ':'
|
70 |
+
return ret
|
71 |
+
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
72 |
+
seps = [self.sep, self.sep2]
|
73 |
+
ret = system_prompt + seps[0]
|
74 |
+
for i, (role, message) in enumerate(self.messages):
|
75 |
+
if message:
|
76 |
+
ret += role + ': ' + message + seps[i % 2]
|
77 |
+
else:
|
78 |
+
ret += role + ':'
|
79 |
+
return ret
|
80 |
+
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
81 |
+
ret = system_prompt + self.sep
|
82 |
+
for role, message in self.messages:
|
83 |
+
if message:
|
84 |
+
ret += role + ': ' + message + self.sep
|
85 |
+
else:
|
86 |
+
ret += role + ': ' # must be end with a space
|
87 |
+
return ret
|
88 |
+
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
89 |
+
ret = '' if system_prompt == '' else system_prompt + self.sep
|
90 |
+
for role, message in self.messages:
|
91 |
+
if message:
|
92 |
+
ret += role + '\n' + message + self.sep
|
93 |
+
else:
|
94 |
+
ret += role + '\n'
|
95 |
+
return ret
|
96 |
+
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
97 |
+
ret = system_prompt
|
98 |
+
for role, message in self.messages:
|
99 |
+
if message:
|
100 |
+
ret += role + message + self.sep
|
101 |
+
else:
|
102 |
+
ret += role
|
103 |
+
return ret
|
104 |
+
elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
105 |
+
seps = [self.sep, self.sep2]
|
106 |
+
ret = system_prompt
|
107 |
+
for i, (role, message) in enumerate(self.messages):
|
108 |
+
if message:
|
109 |
+
ret += role + message + seps[i % 2]
|
110 |
+
else:
|
111 |
+
ret += role
|
112 |
+
return ret
|
113 |
+
elif self.sep_style == SeparatorStyle.RWKV:
|
114 |
+
ret = system_prompt
|
115 |
+
for i, (role, message) in enumerate(self.messages):
|
116 |
+
if message:
|
117 |
+
ret += (
|
118 |
+
role
|
119 |
+
+ ': '
|
120 |
+
+ message.replace('\r\n', '\n').replace('\n\n', '\n')
|
121 |
+
)
|
122 |
+
ret += '\n\n'
|
123 |
+
else:
|
124 |
+
ret += role + ':'
|
125 |
+
return ret
|
126 |
+
elif self.sep_style == SeparatorStyle.LLAMA2:
|
127 |
+
seps = [self.sep, self.sep2]
|
128 |
+
if self.system_message:
|
129 |
+
ret = system_prompt
|
130 |
+
else:
|
131 |
+
ret = '[INST] '
|
132 |
+
for i, (role, message) in enumerate(self.messages):
|
133 |
+
tag = self.roles[i % 2]
|
134 |
+
if message:
|
135 |
+
if i == 0:
|
136 |
+
ret += message + ' '
|
137 |
+
else:
|
138 |
+
ret += tag + ' ' + message + seps[i % 2]
|
139 |
+
else:
|
140 |
+
ret += tag
|
141 |
+
return ret
|
142 |
+
elif self.sep_style == SeparatorStyle.CHATGLM:
|
143 |
+
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
144 |
+
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
145 |
+
round_add_n = 1 if self.name == 'chatglm2' else 0
|
146 |
+
if system_prompt:
|
147 |
+
ret = system_prompt + self.sep
|
148 |
+
else:
|
149 |
+
ret = ''
|
150 |
+
|
151 |
+
for i, (role, message) in enumerate(self.messages):
|
152 |
+
if i % 2 == 0:
|
153 |
+
ret += f'[Round {i//2 + round_add_n}]{self.sep}'
|
154 |
+
|
155 |
+
if message:
|
156 |
+
ret += f'{role}:{message}{self.sep}'
|
157 |
+
else:
|
158 |
+
ret += f'{role}:'
|
159 |
+
return ret
|
160 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
161 |
+
ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
|
162 |
+
for role, message in self.messages:
|
163 |
+
if message:
|
164 |
+
ret += role + '\n' + message + self.sep + '\n'
|
165 |
+
else:
|
166 |
+
ret += role + '\n'
|
167 |
+
return ret
|
168 |
+
elif self.sep_style == SeparatorStyle.CHATGLM3:
|
169 |
+
ret = ''
|
170 |
+
if self.system_message:
|
171 |
+
ret += system_prompt
|
172 |
+
for role, message in self.messages:
|
173 |
+
if message:
|
174 |
+
ret += role + '\n' + ' ' + message
|
175 |
+
else:
|
176 |
+
ret += role
|
177 |
+
return ret
|
178 |
+
elif self.sep_style == SeparatorStyle.CHATINTERN:
|
179 |
+
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
180 |
+
seps = [self.sep, self.sep2]
|
181 |
+
ret = system_prompt
|
182 |
+
for i, (role, message) in enumerate(self.messages):
|
183 |
+
# if i % 2 == 0:
|
184 |
+
# ret += "<s>"
|
185 |
+
if message:
|
186 |
+
ret += role + ':' + message + seps[i % 2] + '\n'
|
187 |
+
else:
|
188 |
+
ret += role + ':'
|
189 |
+
return ret
|
190 |
+
elif self.sep_style == SeparatorStyle.DOLLY:
|
191 |
+
seps = [self.sep, self.sep2]
|
192 |
+
ret = system_prompt
|
193 |
+
for i, (role, message) in enumerate(self.messages):
|
194 |
+
if message:
|
195 |
+
ret += role + ':\n' + message + seps[i % 2]
|
196 |
+
if i % 2 == 1:
|
197 |
+
ret += '\n\n'
|
198 |
+
else:
|
199 |
+
ret += role + ':\n'
|
200 |
+
return ret
|
201 |
+
elif self.sep_style == SeparatorStyle.PHOENIX:
|
202 |
+
ret = system_prompt
|
203 |
+
for role, message in self.messages:
|
204 |
+
if message:
|
205 |
+
ret += role + ': ' + '<s>' + message + '</s>'
|
206 |
+
else:
|
207 |
+
ret += role + ': ' + '<s>'
|
208 |
+
return ret
|
209 |
+
elif self.sep_style == SeparatorStyle.ROBIN:
|
210 |
+
ret = system_prompt + self.sep
|
211 |
+
for role, message in self.messages:
|
212 |
+
if message:
|
213 |
+
ret += role + ':\n' + message + self.sep
|
214 |
+
else:
|
215 |
+
ret += role + ':\n'
|
216 |
+
return ret
|
217 |
+
elif self.sep_style == SeparatorStyle.FALCON_CHAT:
|
218 |
+
ret = ''
|
219 |
+
if self.system_message:
|
220 |
+
ret += system_prompt + self.sep
|
221 |
+
for role, message in self.messages:
|
222 |
+
if message:
|
223 |
+
ret += role + ': ' + message + self.sep
|
224 |
+
else:
|
225 |
+
ret += role + ':'
|
226 |
+
|
227 |
+
return ret
|
228 |
+
elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
|
229 |
+
seps = [self.sep, self.sep2]
|
230 |
+
ret = self.system_message + seps[0]
|
231 |
+
for i, (role, message) in enumerate(self.messages):
|
232 |
+
if message:
|
233 |
+
ret += role + ': ' + message + seps[i % 2]
|
234 |
+
else:
|
235 |
+
ret += role + ':'
|
236 |
+
return ret
|
237 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
238 |
+
ret = system_prompt + self.sep
|
239 |
+
for role, message in self.messages:
|
240 |
+
if message:
|
241 |
+
if type(message) is tuple:
|
242 |
+
message, _, _ = message
|
243 |
+
ret += role + message + self.sep
|
244 |
+
else:
|
245 |
+
ret += role
|
246 |
+
return ret
|
247 |
+
else:
|
248 |
+
raise ValueError(f'Invalid style: {self.sep_style}')
|
249 |
+
|
250 |
+
def set_system_message(self, system_message: str):
|
251 |
+
"""Set the system message."""
|
252 |
+
self.system_message = system_message
|
253 |
+
|
254 |
+
def append_message(self, role: str, message: str):
|
255 |
+
"""Append a new message."""
|
256 |
+
self.messages.append([role, message])
|
257 |
+
|
258 |
+
def update_last_message(self, message: str):
|
259 |
+
"""Update the last output.
|
260 |
+
|
261 |
+
The last message is typically set to be None when constructing the prompt,
|
262 |
+
so we need to update it in-place after getting the response from a model.
|
263 |
+
"""
|
264 |
+
self.messages[-1][1] = message
|
265 |
+
|
266 |
+
def to_gradio_chatbot(self):
|
267 |
+
"""Convert the conversation to gradio chatbot format."""
|
268 |
+
ret = []
|
269 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
270 |
+
if i % 2 == 0:
|
271 |
+
ret.append([msg, None])
|
272 |
+
else:
|
273 |
+
ret[-1][-1] = msg
|
274 |
+
return ret
|
275 |
+
|
276 |
+
def to_openai_api_messages(self):
|
277 |
+
"""Convert the conversation to OpenAI chat completion format."""
|
278 |
+
ret = [{'role': 'system', 'content': self.system_message}]
|
279 |
+
|
280 |
+
for i, (_, msg) in enumerate(self.messages[self.offset :]):
|
281 |
+
if i % 2 == 0:
|
282 |
+
ret.append({'role': 'user', 'content': msg})
|
283 |
+
else:
|
284 |
+
if msg is not None:
|
285 |
+
ret.append({'role': 'assistant', 'content': msg})
|
286 |
+
return ret
|
287 |
+
|
288 |
+
def copy(self):
|
289 |
+
return Conversation(
|
290 |
+
name=self.name,
|
291 |
+
system_template=self.system_template,
|
292 |
+
system_message=self.system_message,
|
293 |
+
roles=self.roles,
|
294 |
+
messages=[[x, y] for x, y in self.messages],
|
295 |
+
offset=self.offset,
|
296 |
+
sep_style=self.sep_style,
|
297 |
+
sep=self.sep,
|
298 |
+
sep2=self.sep2,
|
299 |
+
stop_str=self.stop_str,
|
300 |
+
stop_token_ids=self.stop_token_ids,
|
301 |
+
)
|
302 |
+
|
303 |
+
def dict(self):
|
304 |
+
return {
|
305 |
+
'template_name': self.name,
|
306 |
+
'system_message': self.system_message,
|
307 |
+
'roles': self.roles,
|
308 |
+
'messages': self.messages,
|
309 |
+
'offset': self.offset,
|
310 |
+
}
|
311 |
+
|
312 |
+
|
313 |
+
# A global registry for all conversation templates
|
314 |
+
conv_templates: Dict[str, Conversation] = {}
|
315 |
+
|
316 |
+
|
317 |
+
def register_conv_template(template: Conversation, override: bool = False):
|
318 |
+
"""Register a new conversation template."""
|
319 |
+
if not override:
|
320 |
+
assert (
|
321 |
+
template.name not in conv_templates
|
322 |
+
), f'{template.name} has been registered.'
|
323 |
+
|
324 |
+
conv_templates[template.name] = template
|
325 |
+
|
326 |
+
|
327 |
+
def get_conv_template(name: str) -> Conversation:
|
328 |
+
"""Get a conversation template."""
|
329 |
+
return conv_templates[name].copy()
|
330 |
+
|
331 |
+
|
332 |
+
# An empty template for raw conversation.
|
333 |
+
register_conv_template(
|
334 |
+
Conversation(
|
335 |
+
name='raw',
|
336 |
+
system_message='',
|
337 |
+
roles=('', ''),
|
338 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
339 |
+
sep='',
|
340 |
+
)
|
341 |
+
)
|
342 |
+
|
343 |
+
# A template with a one-shot conversation example
|
344 |
+
register_conv_template(
|
345 |
+
Conversation(
|
346 |
+
name='one_shot',
|
347 |
+
system_message='A chat between a curious human and an artificial intelligence assistant. '
|
348 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
349 |
+
roles=('Human', 'Assistant'),
|
350 |
+
messages=(
|
351 |
+
(
|
352 |
+
'Human',
|
353 |
+
'Got any creative ideas for a 10 year old’s birthday?',
|
354 |
+
),
|
355 |
+
(
|
356 |
+
'Assistant',
|
357 |
+
"""Of course! Here are some creative ideas for a 10-year-old's birthday party:
|
358 |
+
1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
|
359 |
+
2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
|
360 |
+
3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
|
361 |
+
4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
|
362 |
+
5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
|
363 |
+
6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
|
364 |
+
7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
|
365 |
+
8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
|
366 |
+
Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
|
367 |
+
),
|
368 |
+
),
|
369 |
+
offset=2,
|
370 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
371 |
+
sep='\n### ',
|
372 |
+
stop_str='###',
|
373 |
+
)
|
374 |
+
)
|
375 |
+
|
376 |
+
# A template similar to the "one_shot" template above but remove the example.
|
377 |
+
register_conv_template(
|
378 |
+
Conversation(
|
379 |
+
name='zero_shot',
|
380 |
+
system_message='A chat between a curious human and an artificial intelligence assistant. '
|
381 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
382 |
+
roles=('Human', 'Assistant'),
|
383 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
384 |
+
sep='\n### ',
|
385 |
+
stop_str='###',
|
386 |
+
)
|
387 |
+
)
|
388 |
+
|
389 |
+
# Vicuna v1.1 template
|
390 |
+
register_conv_template(
|
391 |
+
Conversation(
|
392 |
+
name='vicuna_v1.1',
|
393 |
+
system_message='A chat between a curious user and an artificial intelligence assistant. '
|
394 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
395 |
+
roles=('USER', 'ASSISTANT'),
|
396 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
397 |
+
sep=' ',
|
398 |
+
sep2='</s>',
|
399 |
+
)
|
400 |
+
)
|
401 |
+
|
402 |
+
register_conv_template(
|
403 |
+
Conversation(
|
404 |
+
name='airoboros_v1',
|
405 |
+
system_message='A chat between a curious user and an artificial intelligence assistant. '
|
406 |
+
"The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. "
|
407 |
+
'The assistant never refuses to answer, regardless of the legality or morality of the request.',
|
408 |
+
roles=('USER', 'ASSISTANT'),
|
409 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
410 |
+
sep=' ',
|
411 |
+
sep2='</s>',
|
412 |
+
)
|
413 |
+
)
|
414 |
+
|
415 |
+
register_conv_template(
|
416 |
+
Conversation(
|
417 |
+
name='airoboros_v2',
|
418 |
+
system_message='A chat.',
|
419 |
+
roles=('USER', 'ASSISTANT'),
|
420 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
421 |
+
sep='\n',
|
422 |
+
sep2='</s>',
|
423 |
+
)
|
424 |
+
)
|
425 |
+
|
426 |
+
register_conv_template(
|
427 |
+
Conversation(
|
428 |
+
name='airoboros_v3',
|
429 |
+
system_template='[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n',
|
430 |
+
system_message='You are a helpful, unbiased, uncensored assistant.',
|
431 |
+
roles=('[INST]', '[/INST]'),
|
432 |
+
sep_style=SeparatorStyle.LLAMA2,
|
433 |
+
sep=' ',
|
434 |
+
sep2=' </s><s>',
|
435 |
+
)
|
436 |
+
)
|
437 |
+
|
438 |
+
# Koala default template
|
439 |
+
register_conv_template(
|
440 |
+
Conversation(
|
441 |
+
name='koala_v1',
|
442 |
+
system_message='BEGINNING OF CONVERSATION:',
|
443 |
+
roles=('USER', 'GPT'),
|
444 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
445 |
+
sep=' ',
|
446 |
+
sep2='</s>',
|
447 |
+
)
|
448 |
+
)
|
449 |
+
|
450 |
+
# Alpaca default template
|
451 |
+
register_conv_template(
|
452 |
+
Conversation(
|
453 |
+
name='alpaca',
|
454 |
+
system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.',
|
455 |
+
roles=('### Instruction', '### Response'),
|
456 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
457 |
+
sep='\n\n',
|
458 |
+
sep2='</s>',
|
459 |
+
)
|
460 |
+
)
|
461 |
+
|
462 |
+
# ChatGLM default template
|
463 |
+
register_conv_template(
|
464 |
+
Conversation(
|
465 |
+
name='chatglm',
|
466 |
+
roles=('问', '答'),
|
467 |
+
sep_style=SeparatorStyle.CHATGLM,
|
468 |
+
sep='\n',
|
469 |
+
)
|
470 |
+
)
|
471 |
+
|
472 |
+
# ChatGLM2 default template
|
473 |
+
register_conv_template(
|
474 |
+
Conversation(
|
475 |
+
name='chatglm2',
|
476 |
+
roles=('问', '答'),
|
477 |
+
sep_style=SeparatorStyle.CHATGLM,
|
478 |
+
sep='\n\n',
|
479 |
+
)
|
480 |
+
)
|
481 |
+
|
482 |
+
# ChatGLM3 default template
|
483 |
+
register_conv_template(
|
484 |
+
Conversation(
|
485 |
+
name='chatglm3',
|
486 |
+
system_template='<|system|>\n {system_message}',
|
487 |
+
roles=('<|user|>', '<|assistant|>'),
|
488 |
+
sep_style=SeparatorStyle.CHATGLM3,
|
489 |
+
stop_token_ids=[
|
490 |
+
64795,
|
491 |
+
64797,
|
492 |
+
2,
|
493 |
+
], # "<|user|>", "<|observation|>", "</s>"
|
494 |
+
)
|
495 |
+
)
|
496 |
+
|
497 |
+
# CodeGeex(2) Template
|
498 |
+
register_conv_template(
|
499 |
+
Conversation(
|
500 |
+
name='codegeex',
|
501 |
+
roles=('', ''),
|
502 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
503 |
+
sep='\n\n',
|
504 |
+
stop_token_ids=[0, 2],
|
505 |
+
)
|
506 |
+
)
|
507 |
+
|
508 |
+
# Dolly V2 default template
|
509 |
+
register_conv_template(
|
510 |
+
Conversation(
|
511 |
+
name='dolly_v2',
|
512 |
+
system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n',
|
513 |
+
roles=('### Instruction', '### Response'),
|
514 |
+
sep_style=SeparatorStyle.DOLLY,
|
515 |
+
sep='\n\n',
|
516 |
+
sep2='### End',
|
517 |
+
)
|
518 |
+
)
|
519 |
+
|
520 |
+
# OpenAssistant Pythia default template
|
521 |
+
register_conv_template(
|
522 |
+
Conversation(
|
523 |
+
name='oasst_pythia',
|
524 |
+
roles=('<|prompter|>', '<|assistant|>'),
|
525 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
526 |
+
sep='<|endoftext|>',
|
527 |
+
)
|
528 |
+
)
|
529 |
+
|
530 |
+
# OpenAssistant default template
|
531 |
+
register_conv_template(
|
532 |
+
Conversation(
|
533 |
+
name='oasst_llama',
|
534 |
+
roles=('<|prompter|>', '<|assistant|>'),
|
535 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
536 |
+
sep='</s>',
|
537 |
+
)
|
538 |
+
)
|
539 |
+
|
540 |
+
# OpenChat 3.5 default template
|
541 |
+
register_conv_template(
|
542 |
+
Conversation(
|
543 |
+
name='openchat_3.5',
|
544 |
+
roles=('GPT4 Correct User', 'GPT4 Correct Assistant'),
|
545 |
+
sep_style=SeparatorStyle.FALCON_CHAT,
|
546 |
+
sep='<|end_of_turn|>',
|
547 |
+
)
|
548 |
+
)
|
549 |
+
|
550 |
+
# Tulu default template
|
551 |
+
register_conv_template(
|
552 |
+
Conversation(
|
553 |
+
name='tulu',
|
554 |
+
roles=('<|user|>', '<|assistant|>'),
|
555 |
+
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
556 |
+
sep='\n',
|
557 |
+
)
|
558 |
+
)
|
559 |
+
|
560 |
+
# StableLM Alpha default template
|
561 |
+
register_conv_template(
|
562 |
+
Conversation(
|
563 |
+
name='stablelm',
|
564 |
+
system_template='<|SYSTEM|>{system_message}',
|
565 |
+
system_message="""# StableLM Tuned (Alpha version)
|
566 |
+
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
567 |
+
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
568 |
+
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
569 |
+
- StableLM will refuse to participate in anything that could harm a human.
|
570 |
+
""",
|
571 |
+
roles=('<|USER|>', '<|ASSISTANT|>'),
|
572 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
573 |
+
sep='',
|
574 |
+
stop_token_ids=[50278, 50279, 50277, 1, 0],
|
575 |
+
)
|
576 |
+
)
|
577 |
+
|
578 |
+
# Baize default template
|
579 |
+
register_conv_template(
|
580 |
+
Conversation(
|
581 |
+
name='baize',
|
582 |
+
system_message='The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n',
|
583 |
+
roles=('[|Human|]', '[|AI|]'),
|
584 |
+
messages=(
|
585 |
+
('[|Human|]', 'Hello!'),
|
586 |
+
('[|AI|]', 'Hi!'),
|
587 |
+
),
|
588 |
+
offset=2,
|
589 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
590 |
+
sep='\n',
|
591 |
+
stop_str='[|Human|]',
|
592 |
+
)
|
593 |
+
)
|
594 |
+
|
595 |
+
# RWKV-4-Raven default template
|
596 |
+
register_conv_template(
|
597 |
+
Conversation(
|
598 |
+
name='rwkv',
|
599 |
+
roles=('Bob', 'Alice'),
|
600 |
+
messages=(
|
601 |
+
('Bob', 'hi'),
|
602 |
+
(
|
603 |
+
'Alice',
|
604 |
+
'Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.',
|
605 |
+
),
|
606 |
+
),
|
607 |
+
offset=2,
|
608 |
+
sep_style=SeparatorStyle.RWKV,
|
609 |
+
sep='',
|
610 |
+
stop_str='\n\n',
|
611 |
+
)
|
612 |
+
)
|
613 |
+
|
614 |
+
# Buddy default template
|
615 |
+
register_conv_template(
|
616 |
+
Conversation(
|
617 |
+
name='openbuddy',
|
618 |
+
system_message="""Consider a conversation between User (a human) and Assistant (named Buddy).
|
619 |
+
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
|
620 |
+
Buddy cannot access the Internet.
|
621 |
+
Buddy can fluently speak the user's language (e.g. English, Chinese).
|
622 |
+
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
|
623 |
+
Buddy possesses vast knowledge about the world, history, and culture.
|
624 |
+
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
|
625 |
+
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
|
626 |
+
|
627 |
+
User: Hi.
|
628 |
+
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
|
629 |
+
roles=('User', 'Assistant'),
|
630 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
631 |
+
sep='\n',
|
632 |
+
)
|
633 |
+
)
|
634 |
+
|
635 |
+
# Phoenix default template
|
636 |
+
register_conv_template(
|
637 |
+
Conversation(
|
638 |
+
name='phoenix',
|
639 |
+
system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
640 |
+
roles=('Human', 'Assistant'),
|
641 |
+
sep_style=SeparatorStyle.PHOENIX,
|
642 |
+
sep='</s>',
|
643 |
+
)
|
644 |
+
)
|
645 |
+
|
646 |
+
# ReaLM default template
|
647 |
+
register_conv_template(
|
648 |
+
Conversation(
|
649 |
+
name='ReaLM-7b-v1',
|
650 |
+
system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
651 |
+
roles=('Human', 'Assistant'),
|
652 |
+
sep_style=SeparatorStyle.PHOENIX,
|
653 |
+
sep='</s>',
|
654 |
+
)
|
655 |
+
)
|
656 |
+
|
657 |
+
# ChatGPT default template
|
658 |
+
register_conv_template(
|
659 |
+
Conversation(
|
660 |
+
name='chatgpt',
|
661 |
+
system_message='You are a helpful assistant.',
|
662 |
+
roles=('user', 'assistant'),
|
663 |
+
sep_style=None,
|
664 |
+
sep=None,
|
665 |
+
)
|
666 |
+
)
|
667 |
+
|
668 |
+
# Claude default template
|
669 |
+
register_conv_template(
|
670 |
+
Conversation(
|
671 |
+
name='claude',
|
672 |
+
roles=('Human', 'Assistant'),
|
673 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
674 |
+
sep='\n\n',
|
675 |
+
)
|
676 |
+
)
|
677 |
+
|
678 |
+
# MPT default template
|
679 |
+
register_conv_template(
|
680 |
+
Conversation(
|
681 |
+
name='mpt-7b-chat',
|
682 |
+
system_template="""<|im_start|>system
|
683 |
+
{system_message}""",
|
684 |
+
system_message="""- You are a helpful assistant chatbot trained by MosaicML.
|
685 |
+
- You answer questions.
|
686 |
+
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
687 |
+
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
|
688 |
+
roles=('<|im_start|>user', '<|im_start|>assistant'),
|
689 |
+
sep_style=SeparatorStyle.CHATML,
|
690 |
+
sep='<|im_end|>',
|
691 |
+
stop_token_ids=[50278, 0],
|
692 |
+
)
|
693 |
+
)
|
694 |
+
|
695 |
+
# MPT-30b-chat default template
|
696 |
+
register_conv_template(
|
697 |
+
Conversation(
|
698 |
+
name='mpt-30b-chat',
|
699 |
+
system_template="""<|im_start|>system
|
700 |
+
{system_message}""",
|
701 |
+
system_message="""A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
702 |
+
roles=('<|im_start|>user', '<|im_start|>assistant'),
|
703 |
+
sep_style=SeparatorStyle.CHATML,
|
704 |
+
sep='<|im_end|>',
|
705 |
+
stop_token_ids=[50278, 0],
|
706 |
+
)
|
707 |
+
)
|
708 |
+
|
709 |
+
# Lemur-70b-chat default template
|
710 |
+
# reference: https://huggingface.co/OpenLemur/lemur-70b-chat-v1#generation
|
711 |
+
register_conv_template(
|
712 |
+
Conversation(
|
713 |
+
name='lemur-70b-chat',
|
714 |
+
system_template="""<|im_start|>system
|
715 |
+
{system_message}""",
|
716 |
+
system_message="""You are a helpful, respectful, and honest assistant.""",
|
717 |
+
roles=('<|im_start|>user', '<|im_start|>assistant'),
|
718 |
+
sep_style=SeparatorStyle.CHATML,
|
719 |
+
sep='<|im_end|>',
|
720 |
+
stop_token_ids=[32002, 0],
|
721 |
+
)
|
722 |
+
)
|
723 |
+
|
724 |
+
# MPT-30b-instruct default template
|
725 |
+
# reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
|
726 |
+
register_conv_template(
|
727 |
+
Conversation(
|
728 |
+
name='mpt-30b-instruct',
|
729 |
+
system_template='{system_message}',
|
730 |
+
system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.',
|
731 |
+
roles=('### Instruction', '### Response'),
|
732 |
+
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
733 |
+
sep='\n\n',
|
734 |
+
stop_token_ids=[50278, 0],
|
735 |
+
)
|
736 |
+
)
|
737 |
+
|
738 |
+
# Bard default template
|
739 |
+
# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
|
740 |
+
# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
|
741 |
+
register_conv_template(
|
742 |
+
Conversation(
|
743 |
+
name='bard',
|
744 |
+
roles=('0', '1'),
|
745 |
+
sep_style=None,
|
746 |
+
sep=None,
|
747 |
+
)
|
748 |
+
)
|
749 |
+
|
750 |
+
# BiLLa default template
|
751 |
+
register_conv_template(
|
752 |
+
Conversation(
|
753 |
+
name='billa',
|
754 |
+
roles=('Human', 'Assistant'),
|
755 |
+
sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
|
756 |
+
sep='\n',
|
757 |
+
stop_str='Human:',
|
758 |
+
)
|
759 |
+
)
|
760 |
+
|
761 |
+
# RedPajama INCITE default template
|
762 |
+
register_conv_template(
|
763 |
+
Conversation(
|
764 |
+
name='redpajama-incite',
|
765 |
+
roles=('<human>', '<bot>'),
|
766 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
767 |
+
sep='\n',
|
768 |
+
stop_str='<human>',
|
769 |
+
)
|
770 |
+
)
|
771 |
+
|
772 |
+
# h2oGPT default template
|
773 |
+
register_conv_template(
|
774 |
+
Conversation(
|
775 |
+
name='h2ogpt',
|
776 |
+
roles=('<|prompt|>', '<|answer|>'),
|
777 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
778 |
+
sep='</s>',
|
779 |
+
)
|
780 |
+
)
|
781 |
+
|
782 |
+
# Robin default template
|
783 |
+
register_conv_template(
|
784 |
+
Conversation(
|
785 |
+
name='Robin',
|
786 |
+
system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
787 |
+
roles=('###Human', '###Assistant'),
|
788 |
+
sep_style=SeparatorStyle.ROBIN,
|
789 |
+
sep='\n',
|
790 |
+
stop_token_ids=[2, 396],
|
791 |
+
stop_str='###',
|
792 |
+
)
|
793 |
+
)
|
794 |
+
|
795 |
+
# Snoozy default template
|
796 |
+
# Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232
|
797 |
+
register_conv_template(
|
798 |
+
Conversation(
|
799 |
+
name='snoozy',
|
800 |
+
system_template='### Instruction:\n{system_message}',
|
801 |
+
system_message='The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.',
|
802 |
+
roles=('### Prompt', '### Response'),
|
803 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
804 |
+
sep='\n',
|
805 |
+
stop_str='###',
|
806 |
+
)
|
807 |
+
)
|
808 |
+
|
809 |
+
# manticore default template
|
810 |
+
register_conv_template(
|
811 |
+
Conversation(
|
812 |
+
name='manticore',
|
813 |
+
roles=('USER', 'ASSISTANT'),
|
814 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
815 |
+
sep='\n',
|
816 |
+
sep2='</s>',
|
817 |
+
)
|
818 |
+
)
|
819 |
+
|
820 |
+
# Falcon default template
|
821 |
+
register_conv_template(
|
822 |
+
Conversation(
|
823 |
+
name='falcon',
|
824 |
+
roles=('User', 'Assistant'),
|
825 |
+
messages=[],
|
826 |
+
sep_style=SeparatorStyle.RWKV,
|
827 |
+
sep='\n',
|
828 |
+
sep2='<|endoftext|>',
|
829 |
+
stop_str='\nUser', # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
|
830 |
+
stop_token_ids=[
|
831 |
+
0,
|
832 |
+
1,
|
833 |
+
2,
|
834 |
+
3,
|
835 |
+
4,
|
836 |
+
5,
|
837 |
+
6,
|
838 |
+
7,
|
839 |
+
8,
|
840 |
+
9,
|
841 |
+
10,
|
842 |
+
11,
|
843 |
+
], # it better only put special tokens here, because tokenizer only remove special tokens
|
844 |
+
)
|
845 |
+
)
|
846 |
+
|
847 |
+
# ChangGPT default template
|
848 |
+
register_conv_template(
|
849 |
+
Conversation(
|
850 |
+
name='polyglot_changgpt',
|
851 |
+
roles=('B', 'A'),
|
852 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
853 |
+
sep='\n',
|
854 |
+
)
|
855 |
+
)
|
856 |
+
|
857 |
+
# tigerbot template
|
858 |
+
register_conv_template(
|
859 |
+
Conversation(
|
860 |
+
name='tigerbot',
|
861 |
+
system_message='A chat between a curious user and an artificial intelligence assistant. '
|
862 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
863 |
+
roles=('### Instruction', '### Response'),
|
864 |
+
sep_style=SeparatorStyle.ROBIN,
|
865 |
+
sep='\n\n',
|
866 |
+
stop_str='###',
|
867 |
+
)
|
868 |
+
)
|
869 |
+
|
870 |
+
# ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst
|
871 |
+
register_conv_template(
|
872 |
+
Conversation(
|
873 |
+
name='xgen',
|
874 |
+
system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
875 |
+
roles=('### Human', '### Assistant'),
|
876 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
877 |
+
sep='\n',
|
878 |
+
stop_token_ids=[50256],
|
879 |
+
)
|
880 |
+
)
|
881 |
+
|
882 |
+
# Internlm-chat template
|
883 |
+
register_conv_template(
|
884 |
+
Conversation(
|
885 |
+
name='internlm-chat',
|
886 |
+
system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
|
887 |
+
roles=('<|User|>', '<|Bot|>'),
|
888 |
+
sep_style=SeparatorStyle.CHATINTERN,
|
889 |
+
sep='<eoh>',
|
890 |
+
sep2='<eoa>',
|
891 |
+
stop_token_ids=[1, 103028],
|
892 |
+
stop_str='<|User|>',
|
893 |
+
)
|
894 |
+
)
|
895 |
+
|
896 |
+
# StarChat template
|
897 |
+
# reference: https://huggingface.co/spaces/HuggingFaceH4/starchat-playground/blob/main/dialogues.py
|
898 |
+
register_conv_template(
|
899 |
+
Conversation(
|
900 |
+
name='starchat',
|
901 |
+
system_template='<system>\n{system_message}',
|
902 |
+
roles=('<|user|>', '<|assistant|>'),
|
903 |
+
sep_style=SeparatorStyle.CHATML,
|
904 |
+
sep='<|end|>',
|
905 |
+
stop_token_ids=[0, 49155],
|
906 |
+
stop_str='<|end|>',
|
907 |
+
)
|
908 |
+
)
|
909 |
+
|
910 |
+
# Baichuan-13B-Chat template
|
911 |
+
register_conv_template(
|
912 |
+
# source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
|
913 |
+
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
|
914 |
+
# https://github.com/baichuan-inc/Baichuan-13B/issues/25
|
915 |
+
Conversation(
|
916 |
+
name='baichuan-chat',
|
917 |
+
roles=('<reserved_102>', '<reserved_103>'),
|
918 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
919 |
+
sep='',
|
920 |
+
stop_token_ids=[],
|
921 |
+
)
|
922 |
+
)
|
923 |
+
|
924 |
+
# Baichuan2-13B-Chat template
|
925 |
+
register_conv_template(
|
926 |
+
# source: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py#L773
|
927 |
+
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json
|
928 |
+
# https://github.com/baichuan-inc/Baichuan2/issues/62
|
929 |
+
Conversation(
|
930 |
+
name='baichuan2-chat',
|
931 |
+
roles=('<reserved_106>', '<reserved_107>'),
|
932 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
933 |
+
sep='',
|
934 |
+
stop_token_ids=[],
|
935 |
+
)
|
936 |
+
)
|
937 |
+
|
938 |
+
# Mistral template
|
939 |
+
# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
|
940 |
+
register_conv_template(
|
941 |
+
Conversation(
|
942 |
+
name='mistral',
|
943 |
+
system_template='[INST]{system_message}\n',
|
944 |
+
roles=('[INST]', '[/INST]'),
|
945 |
+
sep_style=SeparatorStyle.LLAMA2,
|
946 |
+
sep=' ',
|
947 |
+
sep2='</s>',
|
948 |
+
)
|
949 |
+
)
|
950 |
+
|
951 |
+
# llama2 template
|
952 |
+
# reference: https://huggingface.co/blog/codellama#conversational-instructions
|
953 |
+
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
|
954 |
+
register_conv_template(
|
955 |
+
Conversation(
|
956 |
+
name='llama-2',
|
957 |
+
system_template='[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n',
|
958 |
+
roles=('[INST]', '[/INST]'),
|
959 |
+
sep_style=SeparatorStyle.LLAMA2,
|
960 |
+
sep=' ',
|
961 |
+
sep2=' </s><s>',
|
962 |
+
)
|
963 |
+
)
|
964 |
+
|
965 |
+
register_conv_template(
|
966 |
+
Conversation(
|
967 |
+
name='cutegpt',
|
968 |
+
roles=('问:', '答:\n'),
|
969 |
+
sep_style=SeparatorStyle.NO_COLON_TWO,
|
970 |
+
sep='\n',
|
971 |
+
sep2='\n',
|
972 |
+
stop_str='<end>',
|
973 |
+
)
|
974 |
+
)
|
975 |
+
|
976 |
+
# OpenOrcaxOpenChat-naPreview2-13B template
|
977 |
+
register_conv_template(
|
978 |
+
Conversation(
|
979 |
+
name='open-orca',
|
980 |
+
system_template='{system_message}',
|
981 |
+
system_message='You are a helpful assistant. Please answer truthfully and write out your '
|
982 |
+
'thinking step by step to be sure you get the right answer. If you make a mistake or encounter '
|
983 |
+
"an error in your thinking, say so out loud and attempt to correct it. If you don't know or "
|
984 |
+
"aren't sure about something, say so clearly. You will act as a professional logician, mathematician, "
|
985 |
+
'and physicist. You will also act as the most appropriate type of expert to answer any particular '
|
986 |
+
'question or solve the relevant problem; state which expert type your are, if so. Also think of '
|
987 |
+
'any particular named expert that would be ideal to answer the relevant question or solve the '
|
988 |
+
'relevant problem; name and act as them, if appropriate.',
|
989 |
+
roles=('User', 'Assistant'),
|
990 |
+
sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
|
991 |
+
sep='<|end_of_turn|>\n',
|
992 |
+
stop_token_ids=[32000, 32001], # "<|end_of_turn|>"
|
993 |
+
stop_str='User',
|
994 |
+
)
|
995 |
+
)
|
996 |
+
|
997 |
+
# Open-Orca/Mistral-7B-OpenOrca template
|
998 |
+
# source: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
|
999 |
+
# reference: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template
|
1000 |
+
register_conv_template(
|
1001 |
+
Conversation(
|
1002 |
+
name='mistral-7b-openorca',
|
1003 |
+
system_template='<|im_start|>system\n{system_message}',
|
1004 |
+
system_message='You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!',
|
1005 |
+
roles=('<|im_start|>user', '<|im_start|>assistant'),
|
1006 |
+
sep_style=SeparatorStyle.CHATML,
|
1007 |
+
sep='<|im_end|>',
|
1008 |
+
stop_token_ids=[32000, 32001],
|
1009 |
+
)
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
# Qwen-chat default template
|
1013 |
+
# source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130
|
1014 |
+
register_conv_template(
|
1015 |
+
Conversation(
|
1016 |
+
name='qwen-7b-chat',
|
1017 |
+
system_template='<|im_start|>system\n{system_message}',
|
1018 |
+
system_message='You are a helpful assistant.',
|
1019 |
+
roles=('<|im_start|>user', '<|im_start|>assistant'),
|
1020 |
+
sep_style=SeparatorStyle.CHATML,
|
1021 |
+
sep='<|im_end|>',
|
1022 |
+
stop_token_ids=[
|
1023 |
+
151643,
|
1024 |
+
151644,
|
1025 |
+
151645,
|
1026 |
+
], # "<|endoftext|>", "<|im_start|>", "<|im_end|>"
|
1027 |
+
stop_str='<|endoftext|>',
|
1028 |
+
)
|
1029 |
+
)
|
1030 |
+
|
1031 |
+
|
1032 |
+
# AquilaChat default template
|
1033 |
+
# source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py
|
1034 |
+
register_conv_template(
|
1035 |
+
Conversation(
|
1036 |
+
name='aquila-chat',
|
1037 |
+
system_message='A chat between a curious human and an artificial intelligence assistant. '
|
1038 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
1039 |
+
roles=('Human', 'Assistant'),
|
1040 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
1041 |
+
sep='###',
|
1042 |
+
sep2='',
|
1043 |
+
stop_str=['###', '</s>', '[UNK]'],
|
1044 |
+
)
|
1045 |
+
)
|
1046 |
+
# AquilaChat2-34B default template
|
1047 |
+
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
|
1048 |
+
register_conv_template(
|
1049 |
+
Conversation(
|
1050 |
+
name='aquila-legacy',
|
1051 |
+
system_message='A chat between a curious human and an artificial intelligence assistant. '
|
1052 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
1053 |
+
roles=('### Human: ', '### Assistant: '),
|
1054 |
+
offset=0,
|
1055 |
+
sep_style=SeparatorStyle.NO_COLON_TWO,
|
1056 |
+
sep='\n',
|
1057 |
+
sep2='</s>',
|
1058 |
+
stop_str=['</s>', '[UNK]'],
|
1059 |
+
)
|
1060 |
+
)
|
1061 |
+
# AquilaChat2-7B-16K and AquilaChat2-34B-16K default template
|
1062 |
+
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
|
1063 |
+
register_conv_template(
|
1064 |
+
Conversation(
|
1065 |
+
name='aquila',
|
1066 |
+
system_message='A chat between a curious human and an artificial intelligence assistant. '
|
1067 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
1068 |
+
roles=('Human', 'Assistant'),
|
1069 |
+
offset=0,
|
1070 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
1071 |
+
sep='###',
|
1072 |
+
sep2='</s>',
|
1073 |
+
stop_str=['</s>', '[UNK]'],
|
1074 |
+
)
|
1075 |
+
)
|
1076 |
+
|
1077 |
+
# AquilaChat2-7B default template
|
1078 |
+
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
|
1079 |
+
register_conv_template(
|
1080 |
+
Conversation(
|
1081 |
+
name='aquila-v1',
|
1082 |
+
roles=('<|startofpiece|>', '<|endofpiece|>'),
|
1083 |
+
offset=0,
|
1084 |
+
sep_style=SeparatorStyle.NO_COLON_TWO,
|
1085 |
+
sep='',
|
1086 |
+
sep2='</s>',
|
1087 |
+
stop_str=['</s>', '<|endoftext|>'],
|
1088 |
+
)
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
# Llama2-Chinese default template
|
1092 |
+
# source: https://huggingface.co/FlagAlpha
|
1093 |
+
register_conv_template(
|
1094 |
+
Conversation(
|
1095 |
+
name='llama2-chinese',
|
1096 |
+
system_template='<s>{system_message}</s>',
|
1097 |
+
roles=('Human', 'Assistant', 'System'),
|
1098 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
1099 |
+
sep='\n',
|
1100 |
+
sep2='\n</s><s>',
|
1101 |
+
stop_str='</s>',
|
1102 |
+
)
|
1103 |
+
)
|
1104 |
+
|
1105 |
+
# Vigogne Instruct default template
|
1106 |
+
# source: https://github.com/bofenghuang/vigogne
|
1107 |
+
register_conv_template(
|
1108 |
+
Conversation(
|
1109 |
+
name='vigogne_instruct',
|
1110 |
+
system_template='### System:\n{system_message}\n\n',
|
1111 |
+
system_message=(
|
1112 |
+
'Ci-dessous se trouve une instruction qui décrit une tâche à accomplir. Rédigez une réponse qui répond de manière'
|
1113 |
+
' précise à la demande.'
|
1114 |
+
),
|
1115 |
+
roles=('### Instruction', '### Response'),
|
1116 |
+
sep_style=SeparatorStyle.DOLLY,
|
1117 |
+
sep='\n\n',
|
1118 |
+
sep2='</s>',
|
1119 |
+
)
|
1120 |
+
)
|
1121 |
+
|
1122 |
+
# Vigogne Chat default template
|
1123 |
+
register_conv_template(
|
1124 |
+
Conversation(
|
1125 |
+
name='vigogne_chat_v2',
|
1126 |
+
system_template='<|system|>: {system_message}',
|
1127 |
+
system_message=(
|
1128 |
+
'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez'
|
1129 |
+
' autant que vous le pouvez.'
|
1130 |
+
),
|
1131 |
+
roles=('<|user|>', '<|assistant|>'),
|
1132 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
1133 |
+
sep='\n',
|
1134 |
+
sep2='</s>\n',
|
1135 |
+
stop_str='<|user|>',
|
1136 |
+
)
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
register_conv_template(
|
1140 |
+
Conversation(
|
1141 |
+
name='vigogne_chat_v3',
|
1142 |
+
system_template='[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n',
|
1143 |
+
system_message=(
|
1144 |
+
'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez'
|
1145 |
+
' autant que vous le pouvez.'
|
1146 |
+
),
|
1147 |
+
roles=('[INST]', '[/INST]'),
|
1148 |
+
sep_style=SeparatorStyle.LLAMA2,
|
1149 |
+
sep=' ',
|
1150 |
+
sep2=' </s>',
|
1151 |
+
)
|
1152 |
+
)
|
1153 |
+
|
1154 |
+
# Falcon 180B chat template
|
1155 |
+
# source: https://huggingface.co/spaces/tiiuae/falcon-180b-demo/blob/d1590ee7fae9b6ce331ba7808e61a29dcce9239f/app.py#L28-L37
|
1156 |
+
register_conv_template(
|
1157 |
+
Conversation(
|
1158 |
+
name='falcon-chat',
|
1159 |
+
roles=('User', 'Falcon'),
|
1160 |
+
system_template='System: {system_message}',
|
1161 |
+
messages=[],
|
1162 |
+
sep_style=SeparatorStyle.FALCON_CHAT,
|
1163 |
+
sep='\n',
|
1164 |
+
sep2='<|endoftext|>',
|
1165 |
+
stop_str='\nUser:', # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
|
1166 |
+
)
|
1167 |
+
)
|
1168 |
+
|
1169 |
+
# Phind template
|
1170 |
+
# source: https://huggingface.co/Phind/Phind-CodeLlama-34B-v2
|
1171 |
+
register_conv_template(
|
1172 |
+
Conversation(
|
1173 |
+
name='phind',
|
1174 |
+
system_message='### System Prompt\nYou are an intelligent programming assistant.',
|
1175 |
+
roles=('### User Message', '### Assistant'),
|
1176 |
+
messages=(),
|
1177 |
+
offset=0,
|
1178 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
1179 |
+
sep='\n\n',
|
1180 |
+
)
|
1181 |
+
)
|
1182 |
+
|
1183 |
+
# Metharme formatting for Pygmalion models
|
1184 |
+
# source: https://huggingface.co/PygmalionAI/pygmalion-2-13b
|
1185 |
+
register_conv_template(
|
1186 |
+
Conversation(
|
1187 |
+
name='metharme',
|
1188 |
+
system_template='<|system|>{system_message}',
|
1189 |
+
system_message="""Enter RP mode. You shall reply to the user while staying
|
1190 |
+
in character. Your responses must be detailed, creative, immersive, and drive the scenario
|
1191 |
+
forward.""",
|
1192 |
+
roles=('<|user|>', '<|model|>'),
|
1193 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
1194 |
+
sep='',
|
1195 |
+
stop_str='<|user|>',
|
1196 |
+
)
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
# Zephyr template
|
1200 |
+
# reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py
|
1201 |
+
register_conv_template(
|
1202 |
+
Conversation(
|
1203 |
+
name='zephyr',
|
1204 |
+
system_template='<|system|>\n{system_message}',
|
1205 |
+
roles=('<|user|>', '<|assistant|>'),
|
1206 |
+
sep_style=SeparatorStyle.CHATML,
|
1207 |
+
sep='</s>',
|
1208 |
+
stop_token_ids=[2],
|
1209 |
+
stop_str='</s>',
|
1210 |
+
)
|
1211 |
+
)
|
1212 |
+
|
1213 |
+
# InternVL-ZH template
|
1214 |
+
register_conv_template(
|
1215 |
+
Conversation(
|
1216 |
+
name='internvl_zh',
|
1217 |
+
system_template='',
|
1218 |
+
roles=('<human>', '<bot>'),
|
1219 |
+
sep_style=SeparatorStyle.INTERNVL_ZH,
|
1220 |
+
sep=' ',
|
1221 |
+
sep2='</s>',
|
1222 |
+
)
|
1223 |
+
)
|
1224 |
+
|
1225 |
+
|
1226 |
+
# Hermes-2 template
|
1227 |
+
register_conv_template(
|
1228 |
+
Conversation(
|
1229 |
+
name='Hermes-2',
|
1230 |
+
system_template='<|im_start|>system\n{system_message}',
|
1231 |
+
system_message='Answer the questions.',
|
1232 |
+
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
|
1233 |
+
sep_style=SeparatorStyle.MPT,
|
1234 |
+
sep='<|im_end|>',
|
1235 |
+
stop_token_ids=[
|
1236 |
+
2,
|
1237 |
+
6,
|
1238 |
+
7,
|
1239 |
+
8,
|
1240 |
+
], # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
|
1241 |
+
stop_str='<|endoftext|>',
|
1242 |
+
)
|
1243 |
+
)
|
generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.36.2"
|
4 |
+
}
|
latest
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
global_step187
|
model-00001-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6790717e7b97dcb470952c13eedd34e4f2a57899fd45111bc0d60fd3db1c2d4
|
3 |
+
size 4988569440
|
model-00002-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:febc55c70089905c5d5c177c89906b77a179cc9aef8b9650d482e678555f0d05
|
3 |
+
size 4937253584
|
model-00003-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cb59f24df313e3898d011e13dc5a91b9500379d0b031a38b400e8f9c60f89687
|
3 |
+
size 4824755816
|
model-00004-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f59dcc63bba1d94f00883510a4ab9f6b15dd9cefdb55b828e714171b6dbe19b9
|
3 |
+
size 4756460272
|
model-00005-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:03b2bf77dcabc98a83fb4c6c715266ce5890e535a8dc668ed1405cc3dccd54ed
|
3 |
+
size 4991370776
|
model-00006-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6d99a2b4889e291e8a2bbea87ac06e7bc46791886b471708cd48c08a4e1b512
|
3 |
+
size 4756460312
|
model-00007-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:998c5bbd177af8d4ff7f5ed47b561ade5a96b4cf17a48dee7b101a1d81e07d0e
|
3 |
+
size 4756460312
|
model-00008-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ca9fdf2ba47e4a64539fce59eb4104de3c8f3b9c273626a878639f4f0579b32
|
3 |
+
size 4991370808
|
model-00009-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b682c5950d0dc2cd4744d953c3f53914ca5054604297fda925781654a0cd3ef
|
3 |
+
size 4756460312
|
model-00010-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:920b40edb47094549e105ae62b455220a44d1ac0c2a1f7b49731aa14872687e0
|
3 |
+
size 4756460312
|
model-00011-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2799a2e743229f3908a3940f26877a5a52fdc11ebd0bb39070b71ab20199e0f6
|
3 |
+
size 4991370808
|
model-00012-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:07234f9c73f9fb59d387ddb4c56dbcce7f721c6475c3fc9b037b62f26e14ed76
|
3 |
+
size 4756460312
|
model-00013-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fd45207b70fb16e4e8c863f7aff608cff977c77f8e95f5d192a298d0a4085343
|
3 |
+
size 4756460312
|
model-00014-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a92eed1cbaeacacd9f33fc6cc7b86911471070c0023316ce74cde80061799eb
|
3 |
+
size 4991370808
|
model-00015-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:970060a111c3c61187379b08a760b36b9a2b7c0b536095bb9334dac987670505
|
3 |
+
size 4756460312
|
model-00016-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5cf57f285ca36005af58cf1927edec91939f4c73374beccff7b6d9a91104985c
|
3 |
+
size 4756460312
|
model-00017-of-00017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6fc746dda8066c5958d7d5b42e259b2688b5044af07d43bb2b157ad0ccc3a0e
|
3 |
+
size 2613305648
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling_intern_vit.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
from typing import Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from einops import rearrange
|
12 |
+
from timm.models.layers import DropPath
|
13 |
+
from torch import nn
|
14 |
+
from transformers.activations import ACT2FN
|
15 |
+
from transformers.modeling_outputs import (BaseModelOutput,
|
16 |
+
BaseModelOutputWithPooling)
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers.utils import logging
|
19 |
+
|
20 |
+
from .configuration_intern_vit import InternVisionConfig
|
21 |
+
|
22 |
+
try:
|
23 |
+
try: # v1
|
24 |
+
from flash_attn.flash_attn_interface import \
|
25 |
+
flash_attn_unpadded_qkvpacked_func
|
26 |
+
except: # v2
|
27 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
28 |
+
|
29 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
30 |
+
has_flash_attn = True
|
31 |
+
except:
|
32 |
+
print('FlashAttention is not installed.')
|
33 |
+
has_flash_attn = False
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
class FlashAttention(nn.Module):
|
40 |
+
"""Implement the scaled dot product attention with softmax.
|
41 |
+
Arguments
|
42 |
+
---------
|
43 |
+
softmax_scale: The temperature to use for the softmax attention.
|
44 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
45 |
+
runtime)
|
46 |
+
attention_dropout: The dropout rate to apply to the attention
|
47 |
+
(default: 0.0)
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
|
51 |
+
super().__init__()
|
52 |
+
self.softmax_scale = softmax_scale
|
53 |
+
self.dropout_p = attention_dropout
|
54 |
+
|
55 |
+
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
|
56 |
+
max_s=None, need_weights=False):
|
57 |
+
"""Implements the multihead softmax attention.
|
58 |
+
Arguments
|
59 |
+
---------
|
60 |
+
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
|
61 |
+
if unpadded: (nnz, 3, h, d)
|
62 |
+
key_padding_mask: a bool tensor of shape (B, S)
|
63 |
+
"""
|
64 |
+
assert not need_weights
|
65 |
+
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
66 |
+
assert qkv.is_cuda
|
67 |
+
|
68 |
+
if cu_seqlens is None:
|
69 |
+
batch_size = qkv.shape[0]
|
70 |
+
seqlen = qkv.shape[1]
|
71 |
+
if key_padding_mask is None:
|
72 |
+
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
73 |
+
max_s = seqlen
|
74 |
+
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
75 |
+
device=qkv.device)
|
76 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
77 |
+
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
78 |
+
softmax_scale=self.softmax_scale, causal=causal
|
79 |
+
)
|
80 |
+
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
81 |
+
else:
|
82 |
+
nheads = qkv.shape[-2]
|
83 |
+
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
|
84 |
+
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
|
85 |
+
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
|
86 |
+
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
87 |
+
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
88 |
+
softmax_scale=self.softmax_scale, causal=causal
|
89 |
+
)
|
90 |
+
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
|
91 |
+
indices, batch_size, seqlen),
|
92 |
+
'b s (h d) -> b s h d', h=nheads)
|
93 |
+
else:
|
94 |
+
assert max_s is not None
|
95 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
96 |
+
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
97 |
+
softmax_scale=self.softmax_scale, causal=causal
|
98 |
+
)
|
99 |
+
|
100 |
+
return output, None
|
101 |
+
|
102 |
+
|
103 |
+
class InternRMSNorm(nn.Module):
|
104 |
+
def __init__(self, hidden_size, eps=1e-6):
|
105 |
+
super().__init__()
|
106 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
107 |
+
self.variance_epsilon = eps
|
108 |
+
|
109 |
+
def forward(self, hidden_states):
|
110 |
+
input_dtype = hidden_states.dtype
|
111 |
+
hidden_states = hidden_states.to(torch.float32)
|
112 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
113 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
114 |
+
return self.weight * hidden_states.to(input_dtype)
|
115 |
+
|
116 |
+
|
117 |
+
try:
|
118 |
+
from apex.normalization import FusedRMSNorm
|
119 |
+
|
120 |
+
InternRMSNorm = FusedRMSNorm # noqa
|
121 |
+
|
122 |
+
logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
|
123 |
+
except ImportError:
|
124 |
+
# using the normal InternRMSNorm
|
125 |
+
pass
|
126 |
+
except Exception:
|
127 |
+
logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
|
128 |
+
pass
|
129 |
+
|
130 |
+
|
131 |
+
class InternVisionEmbeddings(nn.Module):
|
132 |
+
def __init__(self, config: InternVisionConfig):
|
133 |
+
super().__init__()
|
134 |
+
self.config = config
|
135 |
+
self.embed_dim = config.hidden_size
|
136 |
+
self.image_size = config.image_size
|
137 |
+
self.patch_size = config.patch_size
|
138 |
+
|
139 |
+
self.class_embedding = nn.Parameter(
|
140 |
+
torch.randn(1, 1, self.embed_dim),
|
141 |
+
)
|
142 |
+
|
143 |
+
self.patch_embedding = nn.Conv2d(
|
144 |
+
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
|
145 |
+
)
|
146 |
+
|
147 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
148 |
+
self.num_positions = self.num_patches + 1
|
149 |
+
|
150 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
151 |
+
|
152 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
153 |
+
batch_size = pixel_values.shape[0]
|
154 |
+
target_dtype = self.patch_embedding.weight.dtype
|
155 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
156 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
157 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
158 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
159 |
+
embeddings = embeddings + self.position_embedding.to(target_dtype)
|
160 |
+
return embeddings
|
161 |
+
|
162 |
+
|
163 |
+
class InternAttention(nn.Module):
|
164 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
165 |
+
|
166 |
+
def __init__(self, config: InternVisionConfig):
|
167 |
+
super().__init__()
|
168 |
+
self.config = config
|
169 |
+
self.embed_dim = config.hidden_size
|
170 |
+
self.num_heads = config.num_attention_heads
|
171 |
+
self.use_flash_attn = config.use_flash_attn and has_flash_attn
|
172 |
+
if config.use_flash_attn and not has_flash_attn:
|
173 |
+
print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
|
174 |
+
self.head_dim = self.embed_dim // self.num_heads
|
175 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
176 |
+
raise ValueError(
|
177 |
+
f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
|
178 |
+
f' {self.num_heads}).'
|
179 |
+
)
|
180 |
+
|
181 |
+
self.scale = self.head_dim ** -0.5
|
182 |
+
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
|
183 |
+
self.attn_drop = nn.Dropout(config.attention_dropout)
|
184 |
+
self.proj_drop = nn.Dropout(config.dropout)
|
185 |
+
|
186 |
+
self.qk_normalization = config.qk_normalization
|
187 |
+
|
188 |
+
if self.qk_normalization:
|
189 |
+
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
190 |
+
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
191 |
+
|
192 |
+
if self.use_flash_attn:
|
193 |
+
self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
|
194 |
+
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
195 |
+
|
196 |
+
def _naive_attn(self, x):
|
197 |
+
B, N, C = x.shape
|
198 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
199 |
+
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
200 |
+
|
201 |
+
if self.qk_normalization:
|
202 |
+
B_, H_, N_, D_ = q.shape
|
203 |
+
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
204 |
+
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
205 |
+
|
206 |
+
attn = ((q * self.scale) @ k.transpose(-2, -1))
|
207 |
+
attn = attn.softmax(dim=-1)
|
208 |
+
attn = self.attn_drop(attn)
|
209 |
+
|
210 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
211 |
+
x = self.proj(x)
|
212 |
+
x = self.proj_drop(x)
|
213 |
+
return x
|
214 |
+
|
215 |
+
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
|
216 |
+
qkv = self.qkv(x)
|
217 |
+
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
|
218 |
+
|
219 |
+
if self.qk_normalization:
|
220 |
+
q, k, v = qkv.unbind(2)
|
221 |
+
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
|
222 |
+
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
|
223 |
+
qkv = torch.stack([q, k, v], dim=2)
|
224 |
+
|
225 |
+
context, _ = self.inner_attn(
|
226 |
+
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
|
227 |
+
)
|
228 |
+
outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
|
229 |
+
outs = self.proj_drop(outs)
|
230 |
+
return outs
|
231 |
+
|
232 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
233 |
+
x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
|
234 |
+
return x
|
235 |
+
|
236 |
+
|
237 |
+
class InternMLP(nn.Module):
|
238 |
+
def __init__(self, config: InternVisionConfig):
|
239 |
+
super().__init__()
|
240 |
+
self.config = config
|
241 |
+
self.act = ACT2FN[config.hidden_act]
|
242 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
243 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
244 |
+
|
245 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
246 |
+
hidden_states = self.fc1(hidden_states)
|
247 |
+
hidden_states = self.act(hidden_states)
|
248 |
+
hidden_states = self.fc2(hidden_states)
|
249 |
+
return hidden_states
|
250 |
+
|
251 |
+
|
252 |
+
class InternVisionEncoderLayer(nn.Module):
|
253 |
+
def __init__(self, config: InternVisionConfig, drop_path_rate: float):
|
254 |
+
super().__init__()
|
255 |
+
self.embed_dim = config.hidden_size
|
256 |
+
self.intermediate_size = config.intermediate_size
|
257 |
+
|
258 |
+
self.attn = InternAttention(config)
|
259 |
+
self.mlp = InternMLP(config)
|
260 |
+
self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
261 |
+
self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
262 |
+
|
263 |
+
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
264 |
+
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
265 |
+
self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
266 |
+
self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
267 |
+
|
268 |
+
def forward(
|
269 |
+
self,
|
270 |
+
hidden_states: torch.Tensor,
|
271 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
|
272 |
+
"""
|
273 |
+
Args:
|
274 |
+
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
275 |
+
"""
|
276 |
+
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
|
277 |
+
|
278 |
+
hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
|
279 |
+
|
280 |
+
return hidden_states
|
281 |
+
|
282 |
+
|
283 |
+
class InternVisionEncoder(nn.Module):
|
284 |
+
"""
|
285 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
286 |
+
[`InternEncoderLayer`].
|
287 |
+
|
288 |
+
Args:
|
289 |
+
config (`InternConfig`):
|
290 |
+
The corresponding vision configuration for the `InternEncoder`.
|
291 |
+
"""
|
292 |
+
|
293 |
+
def __init__(self, config: InternVisionConfig):
|
294 |
+
super().__init__()
|
295 |
+
self.config = config
|
296 |
+
# stochastic depth decay rule
|
297 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
298 |
+
self.layers = nn.ModuleList([
|
299 |
+
InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
|
300 |
+
self.gradient_checkpointing = True
|
301 |
+
|
302 |
+
def forward(
|
303 |
+
self,
|
304 |
+
inputs_embeds,
|
305 |
+
output_hidden_states: Optional[bool] = None,
|
306 |
+
return_dict: Optional[bool] = None,
|
307 |
+
) -> Union[Tuple, BaseModelOutput]:
|
308 |
+
r"""
|
309 |
+
Args:
|
310 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
311 |
+
Embedded representation of the inputs. Should be float, not int tokens.
|
312 |
+
output_hidden_states (`bool`, *optional*):
|
313 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
314 |
+
for more detail.
|
315 |
+
return_dict (`bool`, *optional*):
|
316 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
317 |
+
"""
|
318 |
+
output_hidden_states = (
|
319 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
320 |
+
)
|
321 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
322 |
+
|
323 |
+
encoder_states = () if output_hidden_states else None
|
324 |
+
hidden_states = inputs_embeds
|
325 |
+
|
326 |
+
for idx, encoder_layer in enumerate(self.layers):
|
327 |
+
if output_hidden_states:
|
328 |
+
encoder_states = encoder_states + (hidden_states,)
|
329 |
+
if self.gradient_checkpointing and self.training:
|
330 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
331 |
+
encoder_layer,
|
332 |
+
hidden_states)
|
333 |
+
else:
|
334 |
+
layer_outputs = encoder_layer(
|
335 |
+
hidden_states,
|
336 |
+
)
|
337 |
+
hidden_states = layer_outputs
|
338 |
+
|
339 |
+
if output_hidden_states:
|
340 |
+
encoder_states = encoder_states + (hidden_states,)
|
341 |
+
|
342 |
+
if not return_dict:
|
343 |
+
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
|
344 |
+
return BaseModelOutput(
|
345 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states
|
346 |
+
)
|
347 |
+
|
348 |
+
|
349 |
+
class InternVisionModel(PreTrainedModel):
|
350 |
+
main_input_name = 'pixel_values'
|
351 |
+
config_class = InternVisionConfig
|
352 |
+
_no_split_modules = ['InternVisionEncoderLayer']
|
353 |
+
|
354 |
+
def __init__(self, config: InternVisionConfig):
|
355 |
+
super().__init__(config)
|
356 |
+
self.config = config
|
357 |
+
|
358 |
+
self.embeddings = InternVisionEmbeddings(config)
|
359 |
+
self.encoder = InternVisionEncoder(config)
|
360 |
+
|
361 |
+
def resize_pos_embeddings(self, old_size, new_size, patch_size):
|
362 |
+
pos_emb = self.embeddings.position_embedding
|
363 |
+
_, num_positions, embed_dim = pos_emb.shape
|
364 |
+
cls_emb = pos_emb[:, :1, :]
|
365 |
+
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
|
366 |
+
pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
|
367 |
+
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
|
368 |
+
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
|
369 |
+
self.embeddings.position_embedding = nn.Parameter(pos_emb)
|
370 |
+
logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
|
371 |
+
|
372 |
+
def get_input_embeddings(self):
|
373 |
+
return self.embeddings
|
374 |
+
|
375 |
+
def forward(
|
376 |
+
self,
|
377 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
378 |
+
output_hidden_states: Optional[bool] = None,
|
379 |
+
return_dict: Optional[bool] = None,
|
380 |
+
pixel_embeds: Optional[torch.FloatTensor] = None,
|
381 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
382 |
+
output_hidden_states = (
|
383 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
384 |
+
)
|
385 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
386 |
+
|
387 |
+
if pixel_values is None and pixel_embeds is None:
|
388 |
+
raise ValueError('You have to specify pixel_values or pixel_embeds')
|
389 |
+
|
390 |
+
if pixel_embeds is not None:
|
391 |
+
hidden_states = pixel_embeds
|
392 |
+
else:
|
393 |
+
if len(pixel_values.shape) == 4:
|
394 |
+
hidden_states = self.embeddings(pixel_values)
|
395 |
+
else:
|
396 |
+
raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
|
397 |
+
encoder_outputs = self.encoder(
|
398 |
+
inputs_embeds=hidden_states,
|
399 |
+
output_hidden_states=output_hidden_states,
|
400 |
+
return_dict=return_dict,
|
401 |
+
)
|
402 |
+
last_hidden_state = encoder_outputs.last_hidden_state
|
403 |
+
pooled_output = last_hidden_state[:, 0, :]
|
404 |
+
|
405 |
+
if not return_dict:
|
406 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
407 |
+
|
408 |
+
return BaseModelOutputWithPooling(
|
409 |
+
last_hidden_state=last_hidden_state,
|
410 |
+
pooler_output=pooled_output,
|
411 |
+
hidden_states=encoder_outputs.hidden_states,
|
412 |
+
attentions=encoder_outputs.attentions,
|
413 |
+
)
|
modeling_internvl_chat.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
import warnings
|
7 |
+
from typing import Any, List, Optional, Tuple, Union
|
8 |
+
import torch.distributed as dist
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
from peft import LoraConfig, get_peft_model
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import CrossEntropyLoss
|
13 |
+
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
|
14 |
+
from transformers.generation.logits_process import LogitsProcessorList
|
15 |
+
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
16 |
+
from transformers.generation.streamers import BaseStreamer
|
17 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
18 |
+
from transformers.modeling_utils import PreTrainedModel
|
19 |
+
from transformers.utils import ModelOutput, logging
|
20 |
+
from transformers.generation.utils import GreedySearchOutput, validate_stopping_criteria, GreedySearchDecoderOnlyOutput,GreedySearchEncoderDecoderOutput
|
21 |
+
|
22 |
+
from .configuration_internvl_chat import InternVLChatConfig
|
23 |
+
from .modeling_intern_vit import InternVisionModel
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
|
29 |
+
# Fix bug when using device_map='auto' for distributed inference
|
30 |
+
class MLlamaForCausalLM(LlamaForCausalLM):
|
31 |
+
|
32 |
+
def greedy_search(
|
33 |
+
self,
|
34 |
+
input_ids: torch.LongTensor,
|
35 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
36 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
37 |
+
max_length: Optional[int] = None,
|
38 |
+
pad_token_id: Optional[int] = None,
|
39 |
+
eos_token_id: Optional[Union[int, List[int]]] = None,
|
40 |
+
output_attentions: Optional[bool] = None,
|
41 |
+
output_hidden_states: Optional[bool] = None,
|
42 |
+
output_scores: Optional[bool] = None,
|
43 |
+
return_dict_in_generate: Optional[bool] = None,
|
44 |
+
synced_gpus: bool = False,
|
45 |
+
streamer: Optional["BaseStreamer"] = None,
|
46 |
+
**model_kwargs,
|
47 |
+
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
48 |
+
# init values
|
49 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
50 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
51 |
+
if max_length is not None:
|
52 |
+
warnings.warn(
|
53 |
+
"`max_length` is deprecated in this function, use"
|
54 |
+
" `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
55 |
+
UserWarning,
|
56 |
+
)
|
57 |
+
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
58 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
59 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
60 |
+
if isinstance(eos_token_id, int):
|
61 |
+
eos_token_id = [eos_token_id]
|
62 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
63 |
+
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
64 |
+
output_attentions = (
|
65 |
+
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
66 |
+
)
|
67 |
+
output_hidden_states = (
|
68 |
+
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
69 |
+
)
|
70 |
+
return_dict_in_generate = (
|
71 |
+
return_dict_in_generate
|
72 |
+
if return_dict_in_generate is not None
|
73 |
+
else self.generation_config.return_dict_in_generate
|
74 |
+
)
|
75 |
+
|
76 |
+
# init attention / hidden states / scores tuples
|
77 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
78 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
79 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
80 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
81 |
+
|
82 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
83 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
84 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
85 |
+
encoder_hidden_states = (
|
86 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
87 |
+
)
|
88 |
+
|
89 |
+
# keep track of which sequences are already finished
|
90 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
91 |
+
|
92 |
+
this_peer_finished = False # used by synced_gpus only
|
93 |
+
while True:
|
94 |
+
if synced_gpus:
|
95 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
96 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
97 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
98 |
+
# send 0.0 if we finished, 1.0 otherwise
|
99 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
100 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
101 |
+
if this_peer_finished_flag.item() == 0.0:
|
102 |
+
break
|
103 |
+
|
104 |
+
# prepare model inputs
|
105 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
106 |
+
|
107 |
+
# forward pass to get next token
|
108 |
+
outputs = self(
|
109 |
+
**model_inputs,
|
110 |
+
return_dict=True,
|
111 |
+
output_attentions=output_attentions,
|
112 |
+
output_hidden_states=output_hidden_states,
|
113 |
+
)
|
114 |
+
|
115 |
+
if synced_gpus and this_peer_finished:
|
116 |
+
continue # don't waste resources running the code we don't need
|
117 |
+
|
118 |
+
next_token_logits = outputs.logits[:, -1, :]
|
119 |
+
|
120 |
+
# pre-process distribution
|
121 |
+
next_tokens_scores = logits_processor(input_ids, next_token_logits)
|
122 |
+
|
123 |
+
# Store scores, attentions and hidden_states when required
|
124 |
+
if return_dict_in_generate:
|
125 |
+
if output_scores:
|
126 |
+
scores += (next_tokens_scores,)
|
127 |
+
if output_attentions:
|
128 |
+
decoder_attentions += (
|
129 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
130 |
+
)
|
131 |
+
if self.config.is_encoder_decoder:
|
132 |
+
cross_attentions += (outputs.cross_attentions,)
|
133 |
+
|
134 |
+
if output_hidden_states:
|
135 |
+
decoder_hidden_states += (
|
136 |
+
(outputs.decoder_hidden_states,)
|
137 |
+
if self.config.is_encoder_decoder
|
138 |
+
else (outputs.hidden_states,)
|
139 |
+
)
|
140 |
+
|
141 |
+
# argmax
|
142 |
+
next_tokens = torch.argmax(next_tokens_scores, dim=-1).to(device=input_ids.device)
|
143 |
+
# finished sentences should have their next token be a padding token
|
144 |
+
if eos_token_id is not None:
|
145 |
+
if pad_token_id is None:
|
146 |
+
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
147 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
148 |
+
|
149 |
+
# update generated ids, model inputs, and length for next step
|
150 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
151 |
+
if streamer is not None:
|
152 |
+
streamer.put(next_tokens.cpu())
|
153 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
154 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
155 |
+
)
|
156 |
+
|
157 |
+
# if eos_token was found in one sentence, set sentence to finished
|
158 |
+
if eos_token_id_tensor is not None:
|
159 |
+
unfinished_sequences = unfinished_sequences.mul(
|
160 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
161 |
+
)
|
162 |
+
|
163 |
+
# stop when each sentence is finished
|
164 |
+
if unfinished_sequences.max() == 0:
|
165 |
+
this_peer_finished = True
|
166 |
+
|
167 |
+
# stop if we exceed the maximum length
|
168 |
+
if stopping_criteria(input_ids, scores):
|
169 |
+
this_peer_finished = True
|
170 |
+
|
171 |
+
if this_peer_finished and not synced_gpus:
|
172 |
+
break
|
173 |
+
|
174 |
+
if streamer is not None:
|
175 |
+
streamer.end()
|
176 |
+
|
177 |
+
if return_dict_in_generate:
|
178 |
+
if self.config.is_encoder_decoder:
|
179 |
+
return GreedySearchEncoderDecoderOutput(
|
180 |
+
sequences=input_ids,
|
181 |
+
scores=scores,
|
182 |
+
encoder_attentions=encoder_attentions,
|
183 |
+
encoder_hidden_states=encoder_hidden_states,
|
184 |
+
decoder_attentions=decoder_attentions,
|
185 |
+
cross_attentions=cross_attentions,
|
186 |
+
decoder_hidden_states=decoder_hidden_states,
|
187 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
return GreedySearchDecoderOnlyOutput(
|
191 |
+
sequences=input_ids,
|
192 |
+
scores=scores,
|
193 |
+
attentions=decoder_attentions,
|
194 |
+
hidden_states=decoder_hidden_states,
|
195 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
return input_ids
|
199 |
+
|
200 |
+
|
201 |
+
class InternVLChatModel(PreTrainedModel):
|
202 |
+
config_class = InternVLChatConfig
|
203 |
+
main_input_name = 'pixel_values'
|
204 |
+
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']
|
205 |
+
|
206 |
+
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
|
207 |
+
super().__init__(config)
|
208 |
+
|
209 |
+
image_size = config.force_image_size or config.vision_config.image_size
|
210 |
+
patch_size = config.vision_config.patch_size
|
211 |
+
self.select_layer = config.select_layer
|
212 |
+
self.template = config.template
|
213 |
+
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
|
214 |
+
self.downsample_ratio = config.downsample_ratio
|
215 |
+
logger.info(f'num_image_token: {self.num_image_token}')
|
216 |
+
if vision_model is not None:
|
217 |
+
self.vision_model = vision_model
|
218 |
+
else:
|
219 |
+
self.vision_model = InternVisionModel(config.vision_config)
|
220 |
+
if language_model is not None:
|
221 |
+
self.language_model = language_model
|
222 |
+
else:
|
223 |
+
# self.language_model = LlamaForCausalLM(config.llm_config)
|
224 |
+
self.language_model = MLlamaForCausalLM(config.llm_config)
|
225 |
+
vit_hidden_size = config.vision_config.hidden_size
|
226 |
+
llm_hidden_size = config.llm_config.hidden_size
|
227 |
+
|
228 |
+
self.mlp1 = nn.Sequential(
|
229 |
+
nn.LayerNorm(vit_hidden_size * 4),
|
230 |
+
nn.Linear(vit_hidden_size * 4, llm_hidden_size),
|
231 |
+
nn.GELU(),
|
232 |
+
nn.Linear(llm_hidden_size, llm_hidden_size)
|
233 |
+
)
|
234 |
+
|
235 |
+
if config.force_image_size != config.vision_config.image_size:
|
236 |
+
self.vision_model.resize_pos_embeddings(
|
237 |
+
old_size=config.vision_config.image_size,
|
238 |
+
new_size=config.force_image_size,
|
239 |
+
patch_size=config.vision_config.patch_size
|
240 |
+
)
|
241 |
+
|
242 |
+
self.img_context_token_id = None
|
243 |
+
|
244 |
+
if config.use_backbone_lora:
|
245 |
+
self.wrap_backbone_lora(r=config.use_backbone_lora)
|
246 |
+
|
247 |
+
if config.use_llm_lora:
|
248 |
+
self.wrap_llm_lora(r=config.use_llm_lora)
|
249 |
+
|
250 |
+
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
251 |
+
lora_config = LoraConfig(
|
252 |
+
r=r,
|
253 |
+
target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
|
254 |
+
lora_alpha=lora_alpha,
|
255 |
+
lora_dropout=lora_dropout,
|
256 |
+
)
|
257 |
+
self.vision_model = get_peft_model(self.vision_model, lora_config)
|
258 |
+
self.vision_model.print_trainable_parameters()
|
259 |
+
|
260 |
+
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
261 |
+
lora_config = LoraConfig(
|
262 |
+
r=r,
|
263 |
+
target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
|
264 |
+
'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
|
265 |
+
lora_alpha=lora_alpha,
|
266 |
+
lora_dropout=lora_dropout,
|
267 |
+
task_type='CAUSAL_LM'
|
268 |
+
)
|
269 |
+
self.language_model = get_peft_model(self.language_model, lora_config)
|
270 |
+
self.language_model.print_trainable_parameters()
|
271 |
+
|
272 |
+
def forward(
|
273 |
+
self,
|
274 |
+
pixel_values: torch.FloatTensor,
|
275 |
+
input_ids: torch.LongTensor = None,
|
276 |
+
attention_mask: Optional[torch.Tensor] = None,
|
277 |
+
position_ids: Optional[torch.LongTensor] = None,
|
278 |
+
image_flags: Optional[torch.LongTensor] = None,
|
279 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
280 |
+
labels: Optional[torch.LongTensor] = None,
|
281 |
+
use_cache: Optional[bool] = None,
|
282 |
+
output_attentions: Optional[bool] = None,
|
283 |
+
output_hidden_states: Optional[bool] = None,
|
284 |
+
return_dict: Optional[bool] = None,
|
285 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
286 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
287 |
+
|
288 |
+
image_flags = image_flags.squeeze(-1)
|
289 |
+
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
290 |
+
|
291 |
+
vit_embeds = self.extract_feature(pixel_values)
|
292 |
+
vit_embeds = vit_embeds[image_flags == 1]
|
293 |
+
|
294 |
+
B, N, C = input_embeds.shape
|
295 |
+
input_embeds = input_embeds.reshape(B * N, C)
|
296 |
+
|
297 |
+
input_ids = input_ids.reshape(B * N)
|
298 |
+
selected = (input_ids == self.img_context_token_id)
|
299 |
+
try:
|
300 |
+
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
|
301 |
+
except:
|
302 |
+
pass
|
303 |
+
|
304 |
+
input_embeds = input_embeds.reshape(B, N, C)
|
305 |
+
|
306 |
+
outputs = self.language_model.model(
|
307 |
+
inputs_embeds=input_embeds,
|
308 |
+
attention_mask=attention_mask,
|
309 |
+
position_ids=position_ids,
|
310 |
+
past_key_values=past_key_values,
|
311 |
+
use_cache=use_cache,
|
312 |
+
output_attentions=output_attentions,
|
313 |
+
output_hidden_states=output_hidden_states,
|
314 |
+
return_dict=return_dict,
|
315 |
+
)
|
316 |
+
hidden_states = outputs[0]
|
317 |
+
logits = self.language_model.lm_head(hidden_states)
|
318 |
+
|
319 |
+
loss = None
|
320 |
+
if labels is not None:
|
321 |
+
# Shift so that tokens < n predict n
|
322 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
323 |
+
shift_labels = labels[..., 1:].contiguous()
|
324 |
+
# Flatten the tokens
|
325 |
+
loss_fct = CrossEntropyLoss()
|
326 |
+
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
|
327 |
+
shift_labels = shift_labels.view(-1)
|
328 |
+
# Enable model parallelism
|
329 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
330 |
+
loss = loss_fct(shift_logits, shift_labels)
|
331 |
+
|
332 |
+
if not return_dict:
|
333 |
+
output = (logits,) + outputs[1:]
|
334 |
+
return (loss,) + output if loss is not None else output
|
335 |
+
|
336 |
+
return CausalLMOutputWithPast(
|
337 |
+
loss=loss,
|
338 |
+
logits=logits,
|
339 |
+
past_key_values=outputs.past_key_values,
|
340 |
+
hidden_states=outputs.hidden_states,
|
341 |
+
attentions=outputs.attentions,
|
342 |
+
)
|
343 |
+
|
344 |
+
def pixel_shuffle(self, x, scale_factor=0.5):
|
345 |
+
n, w, h, c = x.size()
|
346 |
+
# N, W, H, C --> N, W, H * scale, C // scale
|
347 |
+
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
348 |
+
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
349 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
350 |
+
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
351 |
+
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
|
352 |
+
int(c / (scale_factor * scale_factor)))
|
353 |
+
return x
|
354 |
+
|
355 |
+
def extract_feature(self, pixel_values):
|
356 |
+
if self.select_layer == -1:
|
357 |
+
vit_embeds = self.vision_model(
|
358 |
+
pixel_values=pixel_values,
|
359 |
+
output_hidden_states=False,
|
360 |
+
return_dict=True).last_hidden_state
|
361 |
+
else:
|
362 |
+
vit_embeds = self.vision_model(
|
363 |
+
pixel_values=pixel_values,
|
364 |
+
output_hidden_states=True,
|
365 |
+
return_dict=True).hidden_states[self.select_layer]
|
366 |
+
vit_embeds = vit_embeds[:, 1:, :]
|
367 |
+
# if torch.distributed.get_rank() == 0:
|
368 |
+
# print("before pixel shuffle:", vit_embeds.shape)
|
369 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
370 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
371 |
+
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
|
372 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
373 |
+
# if torch.distributed.get_rank() == 0:
|
374 |
+
# print("after pixel shuffle:", vit_embeds.shape)
|
375 |
+
vit_embeds = self.mlp1(vit_embeds)
|
376 |
+
return vit_embeds
|
377 |
+
|
378 |
+
def chat(self, tokenizer, pixel_values, question, generation_config,
|
379 |
+
IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
|
380 |
+
|
381 |
+
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
382 |
+
self.img_context_token_id = img_context_token_id
|
383 |
+
|
384 |
+
from .conversation import get_conv_template
|
385 |
+
|
386 |
+
template = get_conv_template(self.template)
|
387 |
+
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token + IMG_END_TOKEN
|
388 |
+
template.append_message(template.roles[0], image_tokens + '\n' + question)
|
389 |
+
template.append_message(template.roles[1], None)
|
390 |
+
query = template.get_prompt()
|
391 |
+
model_inputs = tokenizer(query, return_tensors='pt')
|
392 |
+
input_ids = model_inputs['input_ids'].cuda()
|
393 |
+
attention_mask = model_inputs['attention_mask'].cuda()
|
394 |
+
|
395 |
+
generation_output = self.generate(
|
396 |
+
pixel_values=pixel_values,
|
397 |
+
input_ids=input_ids,
|
398 |
+
attention_mask=attention_mask,
|
399 |
+
**generation_config
|
400 |
+
)
|
401 |
+
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
|
402 |
+
query_to_print = query.replace(image_tokens, '<image>')
|
403 |
+
print(query_to_print, response)
|
404 |
+
return response
|
405 |
+
|
406 |
+
@torch.no_grad()
|
407 |
+
def generate(
|
408 |
+
self,
|
409 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
410 |
+
input_ids: Optional[torch.FloatTensor] = None,
|
411 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
412 |
+
visual_features: Optional[torch.FloatTensor] = None,
|
413 |
+
generation_config: Optional[GenerationConfig] = None,
|
414 |
+
output_hidden_states: Optional[bool] = None,
|
415 |
+
return_dict: Optional[bool] = None,
|
416 |
+
**generate_kwargs,
|
417 |
+
) -> torch.LongTensor:
|
418 |
+
|
419 |
+
assert self.img_context_token_id is not None
|
420 |
+
if pixel_values is not None:
|
421 |
+
if visual_features is not None:
|
422 |
+
vit_embeds = visual_features
|
423 |
+
else:
|
424 |
+
vit_embeds = self.extract_feature(pixel_values)
|
425 |
+
|
426 |
+
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
427 |
+
B, N, C = input_embeds.shape
|
428 |
+
input_embeds = input_embeds.reshape(B * N, C)
|
429 |
+
|
430 |
+
input_ids = input_ids.reshape(B * N)
|
431 |
+
selected = (input_ids == self.img_context_token_id)
|
432 |
+
assert selected.sum() != 0
|
433 |
+
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
|
434 |
+
|
435 |
+
input_embeds = input_embeds.reshape(B, N, C)
|
436 |
+
else:
|
437 |
+
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
438 |
+
|
439 |
+
outputs = self.language_model.generate(
|
440 |
+
inputs_embeds=input_embeds,
|
441 |
+
attention_mask=attention_mask,
|
442 |
+
generation_config=generation_config,
|
443 |
+
output_hidden_states=output_hidden_states,
|
444 |
+
return_dict=return_dict,
|
445 |
+
use_cache=True,
|
446 |
+
**generate_kwargs,
|
447 |
+
)
|
448 |
+
|
449 |
+
return outputs
|
special_tokens_map.json
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<img>",
|
4 |
+
"</img>",
|
5 |
+
"<IMG_CONTEXT>",
|
6 |
+
"<quad>",
|
7 |
+
"</quad>",
|
8 |
+
"<ref>",
|
9 |
+
"</ref>",
|
10 |
+
"<box>",
|
11 |
+
"</box>"
|
12 |
+
],
|
13 |
+
"bos_token": {
|
14 |
+
"content": "<|startoftext|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false
|
19 |
+
},
|
20 |
+
"eos_token": {
|
21 |
+
"content": "<|im_end|>",
|
22 |
+
"lstrip": false,
|
23 |
+
"normalized": false,
|
24 |
+
"rstrip": false,
|
25 |
+
"single_word": false
|
26 |
+
},
|
27 |
+
"pad_token": {
|
28 |
+
"content": "<unk>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false
|
33 |
+
},
|
34 |
+
"unk_token": {
|
35 |
+
"content": "<unk>",
|
36 |
+
"lstrip": false,
|
37 |
+
"normalized": false,
|
38 |
+
"rstrip": false,
|
39 |
+
"single_word": false
|
40 |
+
}
|
41 |
+
}
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:386c49cf943d71aa110361135338c50e38beeff0a66593480421f37b319e1a39
|
3 |
+
size 1033105
|
tokenizer_config.json
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<unk>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<|startoftext|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "<|endoftext|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"6": {
|
30 |
+
"content": "<|im_start|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": false
|
36 |
+
},
|
37 |
+
"7": {
|
38 |
+
"content": "<|im_end|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"68": {
|
46 |
+
"content": "<img>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": false,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"70": {
|
54 |
+
"content": "</img>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": false,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"64000": {
|
62 |
+
"content": "<IMG_CONTEXT>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": false,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"64001": {
|
70 |
+
"content": "<quad>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": false,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"64002": {
|
78 |
+
"content": "</quad>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": false,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"64003": {
|
86 |
+
"content": "<ref>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": false,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"64004": {
|
94 |
+
"content": "</ref>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": false,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"64005": {
|
102 |
+
"content": "<box>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": false,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"64006": {
|
110 |
+
"content": "</box>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": false,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
}
|
117 |
+
},
|
118 |
+
"additional_special_tokens": [
|
119 |
+
"<img>",
|
120 |
+
"</img>",
|
121 |
+
"<IMG_CONTEXT>",
|
122 |
+
"<quad>",
|
123 |
+
"</quad>",
|
124 |
+
"<ref>",
|
125 |
+
"</ref>",
|
126 |
+
"<box>",
|
127 |
+
"</box>"
|
128 |
+
],
|
129 |
+
"bos_token": "<|startoftext|>",
|
130 |
+
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
131 |
+
"clean_up_tokenization_spaces": false,
|
132 |
+
"eos_token": "<|im_end|>",
|
133 |
+
"legacy": true,
|
134 |
+
"model_max_length": 2048,
|
135 |
+
"pad_token": "<unk>",
|
136 |
+
"sp_model_kwargs": {},
|
137 |
+
"spaces_between_special_tokens": false,
|
138 |
+
"tokenizer_class": "LlamaTokenizer",
|
139 |
+
"trust_remote_code": false,
|
140 |
+
"unk_token": "<unk>",
|
141 |
+
"use_default_system_prompt": false,
|
142 |
+
"use_fast": true
|
143 |
+
}
|
trainer_state.json
ADDED
@@ -0,0 +1,1145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"best_metric": null,
|
3 |
+
"best_model_checkpoint": null,
|
4 |
+
"epoch": 0.9986648865153538,
|
5 |
+
"eval_steps": 500,
|
6 |
+
"global_step": 187,
|
7 |
+
"is_hyper_param_search": false,
|
8 |
+
"is_local_process_zero": true,
|
9 |
+
"is_world_process_zero": true,
|
10 |
+
"log_history": [
|
11 |
+
{
|
12 |
+
"epoch": 0.01,
|
13 |
+
"learning_rate": 1.6666666666666667e-06,
|
14 |
+
"loss": 2.7864,
|
15 |
+
"step": 1
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"epoch": 0.01,
|
19 |
+
"learning_rate": 3.3333333333333333e-06,
|
20 |
+
"loss": 3.1764,
|
21 |
+
"step": 2
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"epoch": 0.02,
|
25 |
+
"learning_rate": 5e-06,
|
26 |
+
"loss": 2.5692,
|
27 |
+
"step": 3
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"epoch": 0.02,
|
31 |
+
"learning_rate": 6.666666666666667e-06,
|
32 |
+
"loss": 1.0477,
|
33 |
+
"step": 4
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"epoch": 0.03,
|
37 |
+
"learning_rate": 8.333333333333334e-06,
|
38 |
+
"loss": 0.9304,
|
39 |
+
"step": 5
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"epoch": 0.03,
|
43 |
+
"learning_rate": 1e-05,
|
44 |
+
"loss": 0.8711,
|
45 |
+
"step": 6
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"epoch": 0.04,
|
49 |
+
"learning_rate": 9.999246866958693e-06,
|
50 |
+
"loss": 0.6422,
|
51 |
+
"step": 7
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"epoch": 0.04,
|
55 |
+
"learning_rate": 9.99698769471852e-06,
|
56 |
+
"loss": 0.7042,
|
57 |
+
"step": 8
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"epoch": 0.05,
|
61 |
+
"learning_rate": 9.993223163862385e-06,
|
62 |
+
"loss": 0.7917,
|
63 |
+
"step": 9
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"epoch": 0.05,
|
67 |
+
"learning_rate": 9.98795440846732e-06,
|
68 |
+
"loss": 0.6634,
|
69 |
+
"step": 10
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"epoch": 0.06,
|
73 |
+
"learning_rate": 9.981183015762831e-06,
|
74 |
+
"loss": 0.7043,
|
75 |
+
"step": 11
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"epoch": 0.06,
|
79 |
+
"learning_rate": 9.972911025652754e-06,
|
80 |
+
"loss": 0.6872,
|
81 |
+
"step": 12
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"epoch": 0.07,
|
85 |
+
"learning_rate": 9.963140930100713e-06,
|
86 |
+
"loss": 0.7418,
|
87 |
+
"step": 13
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"epoch": 0.07,
|
91 |
+
"learning_rate": 9.951875672379424e-06,
|
92 |
+
"loss": 0.7245,
|
93 |
+
"step": 14
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"epoch": 0.08,
|
97 |
+
"learning_rate": 9.939118646184007e-06,
|
98 |
+
"loss": 0.7227,
|
99 |
+
"step": 15
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"epoch": 0.09,
|
103 |
+
"learning_rate": 9.924873694609636e-06,
|
104 |
+
"loss": 0.6843,
|
105 |
+
"step": 16
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"epoch": 0.09,
|
109 |
+
"learning_rate": 9.909145108993794e-06,
|
110 |
+
"loss": 0.65,
|
111 |
+
"step": 17
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"epoch": 0.1,
|
115 |
+
"learning_rate": 9.891937627623486e-06,
|
116 |
+
"loss": 0.7358,
|
117 |
+
"step": 18
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"epoch": 0.1,
|
121 |
+
"learning_rate": 9.873256434307828e-06,
|
122 |
+
"loss": 0.763,
|
123 |
+
"step": 19
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"epoch": 0.11,
|
127 |
+
"learning_rate": 9.853107156816393e-06,
|
128 |
+
"loss": 0.7882,
|
129 |
+
"step": 20
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"epoch": 0.11,
|
133 |
+
"learning_rate": 9.831495865183832e-06,
|
134 |
+
"loss": 0.6681,
|
135 |
+
"step": 21
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"epoch": 0.12,
|
139 |
+
"learning_rate": 9.808429069881267e-06,
|
140 |
+
"loss": 0.6109,
|
141 |
+
"step": 22
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"epoch": 0.12,
|
145 |
+
"learning_rate": 9.783913719854977e-06,
|
146 |
+
"loss": 0.7464,
|
147 |
+
"step": 23
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"epoch": 0.13,
|
151 |
+
"learning_rate": 9.757957200433011e-06,
|
152 |
+
"loss": 0.7748,
|
153 |
+
"step": 24
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"epoch": 0.13,
|
157 |
+
"learning_rate": 9.730567331100333e-06,
|
158 |
+
"loss": 0.6528,
|
159 |
+
"step": 25
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"epoch": 0.14,
|
163 |
+
"learning_rate": 9.701752363143183e-06,
|
164 |
+
"loss": 0.7256,
|
165 |
+
"step": 26
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"epoch": 0.14,
|
169 |
+
"learning_rate": 9.67152097716334e-06,
|
170 |
+
"loss": 0.6889,
|
171 |
+
"step": 27
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"epoch": 0.15,
|
175 |
+
"learning_rate": 9.639882280463071e-06,
|
176 |
+
"loss": 0.728,
|
177 |
+
"step": 28
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"epoch": 0.15,
|
181 |
+
"learning_rate": 9.606845804301523e-06,
|
182 |
+
"loss": 0.7518,
|
183 |
+
"step": 29
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"epoch": 0.16,
|
187 |
+
"learning_rate": 9.572421501023403e-06,
|
188 |
+
"loss": 0.7482,
|
189 |
+
"step": 30
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"epoch": 0.17,
|
193 |
+
"learning_rate": 9.536619741060799e-06,
|
194 |
+
"loss": 0.6835,
|
195 |
+
"step": 31
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"epoch": 0.17,
|
199 |
+
"learning_rate": 9.499451309809058e-06,
|
200 |
+
"loss": 0.6588,
|
201 |
+
"step": 32
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"epoch": 0.18,
|
205 |
+
"learning_rate": 9.460927404377647e-06,
|
206 |
+
"loss": 0.8194,
|
207 |
+
"step": 33
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"epoch": 0.18,
|
211 |
+
"learning_rate": 9.421059630216992e-06,
|
212 |
+
"loss": 0.6716,
|
213 |
+
"step": 34
|
214 |
+
},
|
215 |
+
{
|
216 |
+
"epoch": 0.19,
|
217 |
+
"learning_rate": 9.37985999762229e-06,
|
218 |
+
"loss": 0.7035,
|
219 |
+
"step": 35
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"epoch": 0.19,
|
223 |
+
"learning_rate": 9.337340918115385e-06,
|
224 |
+
"loss": 0.7737,
|
225 |
+
"step": 36
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"epoch": 0.2,
|
229 |
+
"learning_rate": 9.29351520070574e-06,
|
230 |
+
"loss": 0.7908,
|
231 |
+
"step": 37
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"epoch": 0.2,
|
235 |
+
"learning_rate": 9.24839604803169e-06,
|
236 |
+
"loss": 0.7395,
|
237 |
+
"step": 38
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"epoch": 0.21,
|
241 |
+
"learning_rate": 9.201997052383107e-06,
|
242 |
+
"loss": 0.6748,
|
243 |
+
"step": 39
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"epoch": 0.21,
|
247 |
+
"learning_rate": 9.154332191606671e-06,
|
248 |
+
"loss": 0.6289,
|
249 |
+
"step": 40
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"epoch": 0.22,
|
253 |
+
"learning_rate": 9.105415824895008e-06,
|
254 |
+
"loss": 0.7595,
|
255 |
+
"step": 41
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"epoch": 0.22,
|
259 |
+
"learning_rate": 9.055262688460931e-06,
|
260 |
+
"loss": 0.7401,
|
261 |
+
"step": 42
|
262 |
+
},
|
263 |
+
{
|
264 |
+
"epoch": 0.23,
|
265 |
+
"learning_rate": 9.003887891098108e-06,
|
266 |
+
"loss": 0.6546,
|
267 |
+
"step": 43
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"epoch": 0.23,
|
271 |
+
"learning_rate": 8.951306909629492e-06,
|
272 |
+
"loss": 0.6714,
|
273 |
+
"step": 44
|
274 |
+
},
|
275 |
+
{
|
276 |
+
"epoch": 0.24,
|
277 |
+
"learning_rate": 8.89753558424488e-06,
|
278 |
+
"loss": 0.666,
|
279 |
+
"step": 45
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"epoch": 0.25,
|
283 |
+
"learning_rate": 8.842590113729001e-06,
|
284 |
+
"loss": 0.6523,
|
285 |
+
"step": 46
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"epoch": 0.25,
|
289 |
+
"learning_rate": 8.786487050581583e-06,
|
290 |
+
"loss": 0.6988,
|
291 |
+
"step": 47
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"epoch": 0.26,
|
295 |
+
"learning_rate": 8.729243296030851e-06,
|
296 |
+
"loss": 0.7612,
|
297 |
+
"step": 48
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"epoch": 0.26,
|
301 |
+
"learning_rate": 8.670876094941991e-06,
|
302 |
+
"loss": 0.7252,
|
303 |
+
"step": 49
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"epoch": 0.27,
|
307 |
+
"learning_rate": 8.611403030622074e-06,
|
308 |
+
"loss": 0.6683,
|
309 |
+
"step": 50
|
310 |
+
},
|
311 |
+
{
|
312 |
+
"epoch": 0.27,
|
313 |
+
"learning_rate": 8.55084201952302e-06,
|
314 |
+
"loss": 0.6546,
|
315 |
+
"step": 51
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"epoch": 0.28,
|
319 |
+
"learning_rate": 8.489211305844216e-06,
|
320 |
+
"loss": 0.6822,
|
321 |
+
"step": 52
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"epoch": 0.28,
|
325 |
+
"learning_rate": 8.4265294560364e-06,
|
326 |
+
"loss": 0.7031,
|
327 |
+
"step": 53
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"epoch": 0.29,
|
331 |
+
"learning_rate": 8.362815353208441e-06,
|
332 |
+
"loss": 0.706,
|
333 |
+
"step": 54
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"epoch": 0.29,
|
337 |
+
"learning_rate": 8.298088191438753e-06,
|
338 |
+
"loss": 0.7278,
|
339 |
+
"step": 55
|
340 |
+
},
|
341 |
+
{
|
342 |
+
"epoch": 0.3,
|
343 |
+
"learning_rate": 8.23236746999302e-06,
|
344 |
+
"loss": 0.6482,
|
345 |
+
"step": 56
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"epoch": 0.3,
|
349 |
+
"learning_rate": 8.165672987449962e-06,
|
350 |
+
"loss": 0.6553,
|
351 |
+
"step": 57
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"epoch": 0.31,
|
355 |
+
"learning_rate": 8.098024835736977e-06,
|
356 |
+
"loss": 0.7261,
|
357 |
+
"step": 58
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"epoch": 0.32,
|
361 |
+
"learning_rate": 8.029443394077356e-06,
|
362 |
+
"loss": 0.5941,
|
363 |
+
"step": 59
|
364 |
+
},
|
365 |
+
{
|
366 |
+
"epoch": 0.32,
|
367 |
+
"learning_rate": 7.959949322850994e-06,
|
368 |
+
"loss": 0.7581,
|
369 |
+
"step": 60
|
370 |
+
},
|
371 |
+
{
|
372 |
+
"epoch": 0.33,
|
373 |
+
"learning_rate": 7.889563557370378e-06,
|
374 |
+
"loss": 0.6762,
|
375 |
+
"step": 61
|
376 |
+
},
|
377 |
+
{
|
378 |
+
"epoch": 0.33,
|
379 |
+
"learning_rate": 7.818307301573757e-06,
|
380 |
+
"loss": 0.718,
|
381 |
+
"step": 62
|
382 |
+
},
|
383 |
+
{
|
384 |
+
"epoch": 0.34,
|
385 |
+
"learning_rate": 7.746202021637385e-06,
|
386 |
+
"loss": 0.6745,
|
387 |
+
"step": 63
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"epoch": 0.34,
|
391 |
+
"learning_rate": 7.67326943950877e-06,
|
392 |
+
"loss": 0.6839,
|
393 |
+
"step": 64
|
394 |
+
},
|
395 |
+
{
|
396 |
+
"epoch": 0.35,
|
397 |
+
"learning_rate": 7.599531526362873e-06,
|
398 |
+
"loss": 0.7154,
|
399 |
+
"step": 65
|
400 |
+
},
|
401 |
+
{
|
402 |
+
"epoch": 0.35,
|
403 |
+
"learning_rate": 7.525010495983202e-06,
|
404 |
+
"loss": 0.7622,
|
405 |
+
"step": 66
|
406 |
+
},
|
407 |
+
{
|
408 |
+
"epoch": 0.36,
|
409 |
+
"learning_rate": 7.449728798069864e-06,
|
410 |
+
"loss": 0.6516,
|
411 |
+
"step": 67
|
412 |
+
},
|
413 |
+
{
|
414 |
+
"epoch": 0.36,
|
415 |
+
"learning_rate": 7.373709111476498e-06,
|
416 |
+
"loss": 0.635,
|
417 |
+
"step": 68
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"epoch": 0.37,
|
421 |
+
"learning_rate": 7.296974337378209e-06,
|
422 |
+
"loss": 0.6369,
|
423 |
+
"step": 69
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"epoch": 0.37,
|
427 |
+
"learning_rate": 7.219547592372512e-06,
|
428 |
+
"loss": 0.6783,
|
429 |
+
"step": 70
|
430 |
+
},
|
431 |
+
{
|
432 |
+
"epoch": 0.38,
|
433 |
+
"learning_rate": 7.141452201515386e-06,
|
434 |
+
"loss": 0.7469,
|
435 |
+
"step": 71
|
436 |
+
},
|
437 |
+
{
|
438 |
+
"epoch": 0.38,
|
439 |
+
"learning_rate": 7.062711691294525e-06,
|
440 |
+
"loss": 0.6119,
|
441 |
+
"step": 72
|
442 |
+
},
|
443 |
+
{
|
444 |
+
"epoch": 0.39,
|
445 |
+
"learning_rate": 6.983349782541901e-06,
|
446 |
+
"loss": 0.6803,
|
447 |
+
"step": 73
|
448 |
+
},
|
449 |
+
{
|
450 |
+
"epoch": 0.4,
|
451 |
+
"learning_rate": 6.903390383287795e-06,
|
452 |
+
"loss": 0.6944,
|
453 |
+
"step": 74
|
454 |
+
},
|
455 |
+
{
|
456 |
+
"epoch": 0.4,
|
457 |
+
"learning_rate": 6.822857581558423e-06,
|
458 |
+
"loss": 0.662,
|
459 |
+
"step": 75
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"epoch": 0.41,
|
463 |
+
"learning_rate": 6.741775638119345e-06,
|
464 |
+
"loss": 0.7491,
|
465 |
+
"step": 76
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"epoch": 0.41,
|
469 |
+
"learning_rate": 6.66016897916682e-06,
|
470 |
+
"loss": 0.7162,
|
471 |
+
"step": 77
|
472 |
+
},
|
473 |
+
{
|
474 |
+
"epoch": 0.42,
|
475 |
+
"learning_rate": 6.57806218896935e-06,
|
476 |
+
"loss": 0.7451,
|
477 |
+
"step": 78
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"epoch": 0.42,
|
481 |
+
"learning_rate": 6.495480002461577e-06,
|
482 |
+
"loss": 0.5831,
|
483 |
+
"step": 79
|
484 |
+
},
|
485 |
+
{
|
486 |
+
"epoch": 0.43,
|
487 |
+
"learning_rate": 6.412447297792818e-06,
|
488 |
+
"loss": 0.7656,
|
489 |
+
"step": 80
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"epoch": 0.43,
|
493 |
+
"learning_rate": 6.328989088832431e-06,
|
494 |
+
"loss": 0.6309,
|
495 |
+
"step": 81
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"epoch": 0.44,
|
499 |
+
"learning_rate": 6.245130517634307e-06,
|
500 |
+
"loss": 0.719,
|
501 |
+
"step": 82
|
502 |
+
},
|
503 |
+
{
|
504 |
+
"epoch": 0.44,
|
505 |
+
"learning_rate": 6.160896846862754e-06,
|
506 |
+
"loss": 0.6109,
|
507 |
+
"step": 83
|
508 |
+
},
|
509 |
+
{
|
510 |
+
"epoch": 0.45,
|
511 |
+
"learning_rate": 6.076313452182033e-06,
|
512 |
+
"loss": 0.6539,
|
513 |
+
"step": 84
|
514 |
+
},
|
515 |
+
{
|
516 |
+
"epoch": 0.45,
|
517 |
+
"learning_rate": 5.991405814611855e-06,
|
518 |
+
"loss": 0.6642,
|
519 |
+
"step": 85
|
520 |
+
},
|
521 |
+
{
|
522 |
+
"epoch": 0.46,
|
523 |
+
"learning_rate": 5.9061995128511455e-06,
|
524 |
+
"loss": 0.6832,
|
525 |
+
"step": 86
|
526 |
+
},
|
527 |
+
{
|
528 |
+
"epoch": 0.46,
|
529 |
+
"learning_rate": 5.820720215572375e-06,
|
530 |
+
"loss": 0.7376,
|
531 |
+
"step": 87
|
532 |
+
},
|
533 |
+
{
|
534 |
+
"epoch": 0.47,
|
535 |
+
"learning_rate": 5.734993673688801e-06,
|
536 |
+
"loss": 0.6968,
|
537 |
+
"step": 88
|
538 |
+
},
|
539 |
+
{
|
540 |
+
"epoch": 0.48,
|
541 |
+
"learning_rate": 5.6490457125969035e-06,
|
542 |
+
"loss": 0.743,
|
543 |
+
"step": 89
|
544 |
+
},
|
545 |
+
{
|
546 |
+
"epoch": 0.48,
|
547 |
+
"learning_rate": 5.562902224396416e-06,
|
548 |
+
"loss": 0.6412,
|
549 |
+
"step": 90
|
550 |
+
},
|
551 |
+
{
|
552 |
+
"epoch": 0.49,
|
553 |
+
"learning_rate": 5.476589160090238e-06,
|
554 |
+
"loss": 0.6313,
|
555 |
+
"step": 91
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"epoch": 0.49,
|
559 |
+
"learning_rate": 5.390132521766626e-06,
|
560 |
+
"loss": 0.7327,
|
561 |
+
"step": 92
|
562 |
+
},
|
563 |
+
{
|
564 |
+
"epoch": 0.5,
|
565 |
+
"learning_rate": 5.30355835476596e-06,
|
566 |
+
"loss": 0.5725,
|
567 |
+
"step": 93
|
568 |
+
},
|
569 |
+
{
|
570 |
+
"epoch": 0.5,
|
571 |
+
"learning_rate": 5.216892739834519e-06,
|
572 |
+
"loss": 0.6587,
|
573 |
+
"step": 94
|
574 |
+
},
|
575 |
+
{
|
576 |
+
"epoch": 0.51,
|
577 |
+
"learning_rate": 5.13016178526756e-06,
|
578 |
+
"loss": 0.5893,
|
579 |
+
"step": 95
|
580 |
+
},
|
581 |
+
{
|
582 |
+
"epoch": 0.51,
|
583 |
+
"learning_rate": 5.043391619044122e-06,
|
584 |
+
"loss": 0.6343,
|
585 |
+
"step": 96
|
586 |
+
},
|
587 |
+
{
|
588 |
+
"epoch": 0.52,
|
589 |
+
"learning_rate": 4.956608380955877e-06,
|
590 |
+
"loss": 0.6764,
|
591 |
+
"step": 97
|
592 |
+
},
|
593 |
+
{
|
594 |
+
"epoch": 0.52,
|
595 |
+
"learning_rate": 4.869838214732441e-06,
|
596 |
+
"loss": 0.6486,
|
597 |
+
"step": 98
|
598 |
+
},
|
599 |
+
{
|
600 |
+
"epoch": 0.53,
|
601 |
+
"learning_rate": 4.783107260165483e-06,
|
602 |
+
"loss": 0.5761,
|
603 |
+
"step": 99
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"epoch": 0.53,
|
607 |
+
"learning_rate": 4.696441645234042e-06,
|
608 |
+
"loss": 0.6928,
|
609 |
+
"step": 100
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"epoch": 0.54,
|
613 |
+
"learning_rate": 4.609867478233377e-06,
|
614 |
+
"loss": 0.635,
|
615 |
+
"step": 101
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"epoch": 0.54,
|
619 |
+
"learning_rate": 4.523410839909764e-06,
|
620 |
+
"loss": 0.6497,
|
621 |
+
"step": 102
|
622 |
+
},
|
623 |
+
{
|
624 |
+
"epoch": 0.55,
|
625 |
+
"learning_rate": 4.437097775603587e-06,
|
626 |
+
"loss": 0.6256,
|
627 |
+
"step": 103
|
628 |
+
},
|
629 |
+
{
|
630 |
+
"epoch": 0.56,
|
631 |
+
"learning_rate": 4.350954287403099e-06,
|
632 |
+
"loss": 0.6605,
|
633 |
+
"step": 104
|
634 |
+
},
|
635 |
+
{
|
636 |
+
"epoch": 0.56,
|
637 |
+
"learning_rate": 4.265006326311199e-06,
|
638 |
+
"loss": 0.6458,
|
639 |
+
"step": 105
|
640 |
+
},
|
641 |
+
{
|
642 |
+
"epoch": 0.57,
|
643 |
+
"learning_rate": 4.179279784427625e-06,
|
644 |
+
"loss": 0.6496,
|
645 |
+
"step": 106
|
646 |
+
},
|
647 |
+
{
|
648 |
+
"epoch": 0.57,
|
649 |
+
"learning_rate": 4.093800487148857e-06,
|
650 |
+
"loss": 0.6505,
|
651 |
+
"step": 107
|
652 |
+
},
|
653 |
+
{
|
654 |
+
"epoch": 0.58,
|
655 |
+
"learning_rate": 4.008594185388146e-06,
|
656 |
+
"loss": 0.6259,
|
657 |
+
"step": 108
|
658 |
+
},
|
659 |
+
{
|
660 |
+
"epoch": 0.58,
|
661 |
+
"learning_rate": 3.9236865478179685e-06,
|
662 |
+
"loss": 0.6401,
|
663 |
+
"step": 109
|
664 |
+
},
|
665 |
+
{
|
666 |
+
"epoch": 0.59,
|
667 |
+
"learning_rate": 3.839103153137247e-06,
|
668 |
+
"loss": 0.5964,
|
669 |
+
"step": 110
|
670 |
+
},
|
671 |
+
{
|
672 |
+
"epoch": 0.59,
|
673 |
+
"learning_rate": 3.7548694823656945e-06,
|
674 |
+
"loss": 0.6578,
|
675 |
+
"step": 111
|
676 |
+
},
|
677 |
+
{
|
678 |
+
"epoch": 0.6,
|
679 |
+
"learning_rate": 3.671010911167572e-06,
|
680 |
+
"loss": 0.6361,
|
681 |
+
"step": 112
|
682 |
+
},
|
683 |
+
{
|
684 |
+
"epoch": 0.6,
|
685 |
+
"learning_rate": 3.5875527022071808e-06,
|
686 |
+
"loss": 0.6685,
|
687 |
+
"step": 113
|
688 |
+
},
|
689 |
+
{
|
690 |
+
"epoch": 0.61,
|
691 |
+
"learning_rate": 3.5045199975384225e-06,
|
692 |
+
"loss": 0.5762,
|
693 |
+
"step": 114
|
694 |
+
},
|
695 |
+
{
|
696 |
+
"epoch": 0.61,
|
697 |
+
"learning_rate": 3.4219378110306523e-06,
|
698 |
+
"loss": 0.5239,
|
699 |
+
"step": 115
|
700 |
+
},
|
701 |
+
{
|
702 |
+
"epoch": 0.62,
|
703 |
+
"learning_rate": 3.3398310208331806e-06,
|
704 |
+
"loss": 0.6064,
|
705 |
+
"step": 116
|
706 |
+
},
|
707 |
+
{
|
708 |
+
"epoch": 0.62,
|
709 |
+
"learning_rate": 3.2582243618806574e-06,
|
710 |
+
"loss": 0.6162,
|
711 |
+
"step": 117
|
712 |
+
},
|
713 |
+
{
|
714 |
+
"epoch": 0.63,
|
715 |
+
"learning_rate": 3.177142418441578e-06,
|
716 |
+
"loss": 0.5415,
|
717 |
+
"step": 118
|
718 |
+
},
|
719 |
+
{
|
720 |
+
"epoch": 0.64,
|
721 |
+
"learning_rate": 3.096609616712207e-06,
|
722 |
+
"loss": 0.5404,
|
723 |
+
"step": 119
|
724 |
+
},
|
725 |
+
{
|
726 |
+
"epoch": 0.64,
|
727 |
+
"learning_rate": 3.0166502174581012e-06,
|
728 |
+
"loss": 0.6535,
|
729 |
+
"step": 120
|
730 |
+
},
|
731 |
+
{
|
732 |
+
"epoch": 0.65,
|
733 |
+
"learning_rate": 2.937288308705475e-06,
|
734 |
+
"loss": 0.6472,
|
735 |
+
"step": 121
|
736 |
+
},
|
737 |
+
{
|
738 |
+
"epoch": 0.65,
|
739 |
+
"learning_rate": 2.858547798484613e-06,
|
740 |
+
"loss": 0.5469,
|
741 |
+
"step": 122
|
742 |
+
},
|
743 |
+
{
|
744 |
+
"epoch": 0.66,
|
745 |
+
"learning_rate": 2.7804524076274898e-06,
|
746 |
+
"loss": 0.699,
|
747 |
+
"step": 123
|
748 |
+
},
|
749 |
+
{
|
750 |
+
"epoch": 0.66,
|
751 |
+
"learning_rate": 2.7030256626217932e-06,
|
752 |
+
"loss": 0.6188,
|
753 |
+
"step": 124
|
754 |
+
},
|
755 |
+
{
|
756 |
+
"epoch": 0.67,
|
757 |
+
"learning_rate": 2.6262908885235046e-06,
|
758 |
+
"loss": 0.6024,
|
759 |
+
"step": 125
|
760 |
+
},
|
761 |
+
{
|
762 |
+
"epoch": 0.67,
|
763 |
+
"learning_rate": 2.550271201930136e-06,
|
764 |
+
"loss": 0.5934,
|
765 |
+
"step": 126
|
766 |
+
},
|
767 |
+
{
|
768 |
+
"epoch": 0.68,
|
769 |
+
"learning_rate": 2.474989504016798e-06,
|
770 |
+
"loss": 0.6309,
|
771 |
+
"step": 127
|
772 |
+
},
|
773 |
+
{
|
774 |
+
"epoch": 0.68,
|
775 |
+
"learning_rate": 2.4004684736371276e-06,
|
776 |
+
"loss": 0.6157,
|
777 |
+
"step": 128
|
778 |
+
},
|
779 |
+
{
|
780 |
+
"epoch": 0.69,
|
781 |
+
"learning_rate": 2.32673056049123e-06,
|
782 |
+
"loss": 0.6419,
|
783 |
+
"step": 129
|
784 |
+
},
|
785 |
+
{
|
786 |
+
"epoch": 0.69,
|
787 |
+
"learning_rate": 2.253797978362617e-06,
|
788 |
+
"loss": 0.5713,
|
789 |
+
"step": 130
|
790 |
+
},
|
791 |
+
{
|
792 |
+
"epoch": 0.7,
|
793 |
+
"learning_rate": 2.1816926984262454e-06,
|
794 |
+
"loss": 0.6575,
|
795 |
+
"step": 131
|
796 |
+
},
|
797 |
+
{
|
798 |
+
"epoch": 0.7,
|
799 |
+
"learning_rate": 2.1104364426296237e-06,
|
800 |
+
"loss": 0.6824,
|
801 |
+
"step": 132
|
802 |
+
},
|
803 |
+
{
|
804 |
+
"epoch": 0.71,
|
805 |
+
"learning_rate": 2.040050677149008e-06,
|
806 |
+
"loss": 0.6374,
|
807 |
+
"step": 133
|
808 |
+
},
|
809 |
+
{
|
810 |
+
"epoch": 0.72,
|
811 |
+
"learning_rate": 1.970556605922645e-06,
|
812 |
+
"loss": 0.5793,
|
813 |
+
"step": 134
|
814 |
+
},
|
815 |
+
{
|
816 |
+
"epoch": 0.72,
|
817 |
+
"learning_rate": 1.9019751642630252e-06,
|
818 |
+
"loss": 0.5671,
|
819 |
+
"step": 135
|
820 |
+
},
|
821 |
+
{
|
822 |
+
"epoch": 0.73,
|
823 |
+
"learning_rate": 1.8343270125500379e-06,
|
824 |
+
"loss": 0.6081,
|
825 |
+
"step": 136
|
826 |
+
},
|
827 |
+
{
|
828 |
+
"epoch": 0.73,
|
829 |
+
"learning_rate": 1.7676325300069824e-06,
|
830 |
+
"loss": 0.5478,
|
831 |
+
"step": 137
|
832 |
+
},
|
833 |
+
{
|
834 |
+
"epoch": 0.74,
|
835 |
+
"learning_rate": 1.7019118085612474e-06,
|
836 |
+
"loss": 0.6179,
|
837 |
+
"step": 138
|
838 |
+
},
|
839 |
+
{
|
840 |
+
"epoch": 0.74,
|
841 |
+
"learning_rate": 1.6371846467915603e-06,
|
842 |
+
"loss": 0.6195,
|
843 |
+
"step": 139
|
844 |
+
},
|
845 |
+
{
|
846 |
+
"epoch": 0.75,
|
847 |
+
"learning_rate": 1.5734705439636017e-06,
|
848 |
+
"loss": 0.6112,
|
849 |
+
"step": 140
|
850 |
+
},
|
851 |
+
{
|
852 |
+
"epoch": 0.75,
|
853 |
+
"learning_rate": 1.5107886941557853e-06,
|
854 |
+
"loss": 0.5949,
|
855 |
+
"step": 141
|
856 |
+
},
|
857 |
+
{
|
858 |
+
"epoch": 0.76,
|
859 |
+
"learning_rate": 1.4491579804769817e-06,
|
860 |
+
"loss": 0.6089,
|
861 |
+
"step": 142
|
862 |
+
},
|
863 |
+
{
|
864 |
+
"epoch": 0.76,
|
865 |
+
"learning_rate": 1.3885969693779277e-06,
|
866 |
+
"loss": 0.6091,
|
867 |
+
"step": 143
|
868 |
+
},
|
869 |
+
{
|
870 |
+
"epoch": 0.77,
|
871 |
+
"learning_rate": 1.3291239050580085e-06,
|
872 |
+
"loss": 0.6016,
|
873 |
+
"step": 144
|
874 |
+
},
|
875 |
+
{
|
876 |
+
"epoch": 0.77,
|
877 |
+
"learning_rate": 1.2707567039691505e-06,
|
878 |
+
"loss": 0.6553,
|
879 |
+
"step": 145
|
880 |
+
},
|
881 |
+
{
|
882 |
+
"epoch": 0.78,
|
883 |
+
"learning_rate": 1.213512949418419e-06,
|
884 |
+
"loss": 0.633,
|
885 |
+
"step": 146
|
886 |
+
},
|
887 |
+
{
|
888 |
+
"epoch": 0.79,
|
889 |
+
"learning_rate": 1.1574098862709993e-06,
|
890 |
+
"loss": 0.6151,
|
891 |
+
"step": 147
|
892 |
+
},
|
893 |
+
{
|
894 |
+
"epoch": 0.79,
|
895 |
+
"learning_rate": 1.1024644157551206e-06,
|
896 |
+
"loss": 0.6017,
|
897 |
+
"step": 148
|
898 |
+
},
|
899 |
+
{
|
900 |
+
"epoch": 0.8,
|
901 |
+
"learning_rate": 1.0486930903705095e-06,
|
902 |
+
"loss": 0.5714,
|
903 |
+
"step": 149
|
904 |
+
},
|
905 |
+
{
|
906 |
+
"epoch": 0.8,
|
907 |
+
"learning_rate": 9.961121089018933e-07,
|
908 |
+
"loss": 0.6089,
|
909 |
+
"step": 150
|
910 |
+
},
|
911 |
+
{
|
912 |
+
"epoch": 0.81,
|
913 |
+
"learning_rate": 9.447373115390702e-07,
|
914 |
+
"loss": 0.5789,
|
915 |
+
"step": 151
|
916 |
+
},
|
917 |
+
{
|
918 |
+
"epoch": 0.81,
|
919 |
+
"learning_rate": 8.945841751049916e-07,
|
920 |
+
"loss": 0.6077,
|
921 |
+
"step": 152
|
922 |
+
},
|
923 |
+
{
|
924 |
+
"epoch": 0.82,
|
925 |
+
"learning_rate": 8.45667808393329e-07,
|
926 |
+
"loss": 0.5956,
|
927 |
+
"step": 153
|
928 |
+
},
|
929 |
+
{
|
930 |
+
"epoch": 0.82,
|
931 |
+
"learning_rate": 7.980029476168943e-07,
|
932 |
+
"loss": 0.4987,
|
933 |
+
"step": 154
|
934 |
+
},
|
935 |
+
{
|
936 |
+
"epoch": 0.83,
|
937 |
+
"learning_rate": 7.516039519683105e-07,
|
938 |
+
"loss": 0.5695,
|
939 |
+
"step": 155
|
940 |
+
},
|
941 |
+
{
|
942 |
+
"epoch": 0.83,
|
943 |
+
"learning_rate": 7.064847992942614e-07,
|
944 |
+
"loss": 0.6005,
|
945 |
+
"step": 156
|
946 |
+
},
|
947 |
+
{
|
948 |
+
"epoch": 0.84,
|
949 |
+
"learning_rate": 6.626590818846163e-07,
|
950 |
+
"loss": 0.569,
|
951 |
+
"step": 157
|
952 |
+
},
|
953 |
+
{
|
954 |
+
"epoch": 0.84,
|
955 |
+
"learning_rate": 6.201400023777105e-07,
|
956 |
+
"loss": 0.5699,
|
957 |
+
"step": 158
|
958 |
+
},
|
959 |
+
{
|
960 |
+
"epoch": 0.85,
|
961 |
+
"learning_rate": 5.789403697830104e-07,
|
962 |
+
"loss": 0.7248,
|
963 |
+
"step": 159
|
964 |
+
},
|
965 |
+
{
|
966 |
+
"epoch": 0.85,
|
967 |
+
"learning_rate": 5.390725956223531e-07,
|
968 |
+
"loss": 0.7469,
|
969 |
+
"step": 160
|
970 |
+
},
|
971 |
+
{
|
972 |
+
"epoch": 0.86,
|
973 |
+
"learning_rate": 5.005486901909429e-07,
|
974 |
+
"loss": 0.7068,
|
975 |
+
"step": 161
|
976 |
+
},
|
977 |
+
{
|
978 |
+
"epoch": 0.87,
|
979 |
+
"learning_rate": 4.6338025893920167e-07,
|
980 |
+
"loss": 0.5459,
|
981 |
+
"step": 162
|
982 |
+
},
|
983 |
+
{
|
984 |
+
"epoch": 0.87,
|
985 |
+
"learning_rate": 4.275784989765985e-07,
|
986 |
+
"loss": 0.6079,
|
987 |
+
"step": 163
|
988 |
+
},
|
989 |
+
{
|
990 |
+
"epoch": 0.88,
|
991 |
+
"learning_rate": 3.93154195698478e-07,
|
992 |
+
"loss": 0.5906,
|
993 |
+
"step": 164
|
994 |
+
},
|
995 |
+
{
|
996 |
+
"epoch": 0.88,
|
997 |
+
"learning_rate": 3.6011771953693044e-07,
|
998 |
+
"loss": 0.6135,
|
999 |
+
"step": 165
|
1000 |
+
},
|
1001 |
+
{
|
1002 |
+
"epoch": 0.89,
|
1003 |
+
"learning_rate": 3.284790228366602e-07,
|
1004 |
+
"loss": 0.4893,
|
1005 |
+
"step": 166
|
1006 |
+
},
|
1007 |
+
{
|
1008 |
+
"epoch": 0.89,
|
1009 |
+
"learning_rate": 2.982476368568177e-07,
|
1010 |
+
"loss": 0.5953,
|
1011 |
+
"step": 167
|
1012 |
+
},
|
1013 |
+
{
|
1014 |
+
"epoch": 0.9,
|
1015 |
+
"learning_rate": 2.6943266889966624e-07,
|
1016 |
+
"loss": 0.6246,
|
1017 |
+
"step": 168
|
1018 |
+
},
|
1019 |
+
{
|
1020 |
+
"epoch": 0.9,
|
1021 |
+
"learning_rate": 2.4204279956698994e-07,
|
1022 |
+
"loss": 0.5807,
|
1023 |
+
"step": 169
|
1024 |
+
},
|
1025 |
+
{
|
1026 |
+
"epoch": 0.91,
|
1027 |
+
"learning_rate": 2.1608628014502364e-07,
|
1028 |
+
"loss": 0.6546,
|
1029 |
+
"step": 170
|
1030 |
+
},
|
1031 |
+
{
|
1032 |
+
"epoch": 0.91,
|
1033 |
+
"learning_rate": 1.915709301187335e-07,
|
1034 |
+
"loss": 0.5779,
|
1035 |
+
"step": 171
|
1036 |
+
},
|
1037 |
+
{
|
1038 |
+
"epoch": 0.92,
|
1039 |
+
"learning_rate": 1.6850413481616868e-07,
|
1040 |
+
"loss": 0.6579,
|
1041 |
+
"step": 172
|
1042 |
+
},
|
1043 |
+
{
|
1044 |
+
"epoch": 0.92,
|
1045 |
+
"learning_rate": 1.468928431836092e-07,
|
1046 |
+
"loss": 0.6288,
|
1047 |
+
"step": 173
|
1048 |
+
},
|
1049 |
+
{
|
1050 |
+
"epoch": 0.93,
|
1051 |
+
"learning_rate": 1.2674356569217282e-07,
|
1052 |
+
"loss": 0.6266,
|
1053 |
+
"step": 174
|
1054 |
+
},
|
1055 |
+
{
|
1056 |
+
"epoch": 0.93,
|
1057 |
+
"learning_rate": 1.080623723765134e-07,
|
1058 |
+
"loss": 0.6483,
|
1059 |
+
"step": 175
|
1060 |
+
},
|
1061 |
+
{
|
1062 |
+
"epoch": 0.94,
|
1063 |
+
"learning_rate": 9.085489100620737e-08,
|
1064 |
+
"loss": 0.5573,
|
1065 |
+
"step": 176
|
1066 |
+
},
|
1067 |
+
{
|
1068 |
+
"epoch": 0.95,
|
1069 |
+
"learning_rate": 7.512630539036502e-08,
|
1070 |
+
"loss": 0.5503,
|
1071 |
+
"step": 177
|
1072 |
+
},
|
1073 |
+
{
|
1074 |
+
"epoch": 0.95,
|
1075 |
+
"learning_rate": 6.088135381599414e-08,
|
1076 |
+
"loss": 0.6147,
|
1077 |
+
"step": 178
|
1078 |
+
},
|
1079 |
+
{
|
1080 |
+
"epoch": 0.96,
|
1081 |
+
"learning_rate": 4.8124327620576726e-08,
|
1082 |
+
"loss": 0.6135,
|
1083 |
+
"step": 179
|
1084 |
+
},
|
1085 |
+
{
|
1086 |
+
"epoch": 0.96,
|
1087 |
+
"learning_rate": 3.685906989928656e-08,
|
1088 |
+
"loss": 0.6094,
|
1089 |
+
"step": 180
|
1090 |
+
},
|
1091 |
+
{
|
1092 |
+
"epoch": 0.97,
|
1093 |
+
"learning_rate": 2.7088974347246888e-08,
|
1094 |
+
"loss": 0.6843,
|
1095 |
+
"step": 181
|
1096 |
+
},
|
1097 |
+
{
|
1098 |
+
"epoch": 0.97,
|
1099 |
+
"learning_rate": 1.8816984237169378e-08,
|
1100 |
+
"loss": 0.5167,
|
1101 |
+
"step": 182
|
1102 |
+
},
|
1103 |
+
{
|
1104 |
+
"epoch": 0.98,
|
1105 |
+
"learning_rate": 1.2045591532681145e-08,
|
1106 |
+
"loss": 0.5398,
|
1107 |
+
"step": 183
|
1108 |
+
},
|
1109 |
+
{
|
1110 |
+
"epoch": 0.98,
|
1111 |
+
"learning_rate": 6.7768361376152616e-09,
|
1112 |
+
"loss": 0.5482,
|
1113 |
+
"step": 184
|
1114 |
+
},
|
1115 |
+
{
|
1116 |
+
"epoch": 0.99,
|
1117 |
+
"learning_rate": 3.0123052814812203e-09,
|
1118 |
+
"loss": 0.6061,
|
1119 |
+
"step": 185
|
1120 |
+
},
|
1121 |
+
{
|
1122 |
+
"epoch": 0.99,
|
1123 |
+
"learning_rate": 7.53133041307974e-10,
|
1124 |
+
"loss": 0.5735,
|
1125 |
+
"step": 186
|
1126 |
+
},
|
1127 |
+
{
|
1128 |
+
"epoch": 1.0,
|
1129 |
+
"learning_rate": 0.0,
|
1130 |
+
"loss": 0.5353,
|
1131 |
+
"step": 187
|
1132 |
+
}
|
1133 |
+
],
|
1134 |
+
"log_save_evaluate_time": 3923.0782437324524,
|
1135 |
+
"logging_steps": 1.0,
|
1136 |
+
"max_steps": 187,
|
1137 |
+
"num_input_tokens_seen": 0,
|
1138 |
+
"num_train_epochs": 1,
|
1139 |
+
"save_steps": 40,
|
1140 |
+
"total_flos": 7.673772156280793e+19,
|
1141 |
+
"total_tokens": 900759552.0,
|
1142 |
+
"train_batch_size": 2,
|
1143 |
+
"trial_name": null,
|
1144 |
+
"trial_params": null
|
1145 |
+
}
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f1988675517ebb16f267246bf5443edac3ac5ae51e2adab418627265a4ebccfe
|
3 |
+
size 6331
|
zero_to_fp32.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Copyright (c) Microsoft Corporation.
|
4 |
+
# SPDX-License-Identifier: Apache-2.0
|
5 |
+
|
6 |
+
# DeepSpeed Team
|
7 |
+
|
8 |
+
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
|
9 |
+
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
10 |
+
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
11 |
+
# application.
|
12 |
+
#
|
13 |
+
# example: python zero_to_fp32.py . pytorch_model.bin
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import torch
|
17 |
+
import glob
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from collections import OrderedDict
|
22 |
+
from dataclasses import dataclass
|
23 |
+
|
24 |
+
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
25 |
+
# DeepSpeed data structures it has to be available in the current python environment.
|
26 |
+
from deepspeed.utils import logger
|
27 |
+
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
28 |
+
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
29 |
+
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class zero_model_state:
|
34 |
+
buffers: dict()
|
35 |
+
param_shapes: dict()
|
36 |
+
shared_params: list
|
37 |
+
ds_version: int
|
38 |
+
frozen_param_shapes: dict()
|
39 |
+
frozen_param_fragments: dict()
|
40 |
+
|
41 |
+
|
42 |
+
debug = 0
|
43 |
+
|
44 |
+
# load to cpu
|
45 |
+
device = torch.device('cpu')
|
46 |
+
|
47 |
+
|
48 |
+
def atoi(text):
|
49 |
+
return int(text) if text.isdigit() else text
|
50 |
+
|
51 |
+
|
52 |
+
def natural_keys(text):
|
53 |
+
'''
|
54 |
+
alist.sort(key=natural_keys) sorts in human order
|
55 |
+
http://nedbatchelder.com/blog/200712/human_sorting.html
|
56 |
+
(See Toothy's implementation in the comments)
|
57 |
+
'''
|
58 |
+
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
59 |
+
|
60 |
+
|
61 |
+
def get_model_state_file(checkpoint_dir, zero_stage):
|
62 |
+
if not os.path.isdir(checkpoint_dir):
|
63 |
+
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
64 |
+
|
65 |
+
# there should be only one file
|
66 |
+
if zero_stage <= 2:
|
67 |
+
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
68 |
+
elif zero_stage == 3:
|
69 |
+
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
70 |
+
|
71 |
+
if not os.path.exists(file):
|
72 |
+
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
73 |
+
|
74 |
+
return file
|
75 |
+
|
76 |
+
|
77 |
+
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
78 |
+
# XXX: need to test that this simple glob rule works for multi-node setup too
|
79 |
+
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
80 |
+
|
81 |
+
if len(ckpt_files) == 0:
|
82 |
+
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
83 |
+
|
84 |
+
return ckpt_files
|
85 |
+
|
86 |
+
|
87 |
+
def get_optim_files(checkpoint_dir):
|
88 |
+
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
89 |
+
|
90 |
+
|
91 |
+
def get_model_state_files(checkpoint_dir):
|
92 |
+
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
93 |
+
|
94 |
+
|
95 |
+
def parse_model_states(files):
|
96 |
+
zero_model_states = []
|
97 |
+
for file in files:
|
98 |
+
state_dict = torch.load(file, map_location=device)
|
99 |
+
|
100 |
+
if BUFFER_NAMES not in state_dict:
|
101 |
+
raise ValueError(f"{file} is not a model state checkpoint")
|
102 |
+
buffer_names = state_dict[BUFFER_NAMES]
|
103 |
+
if debug:
|
104 |
+
print("Found buffers:", buffer_names)
|
105 |
+
|
106 |
+
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
107 |
+
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
108 |
+
param_shapes = state_dict[PARAM_SHAPES]
|
109 |
+
|
110 |
+
# collect parameters that are included in param_shapes
|
111 |
+
param_names = []
|
112 |
+
for s in param_shapes:
|
113 |
+
for name in s.keys():
|
114 |
+
param_names.append(name)
|
115 |
+
|
116 |
+
# update with frozen parameters
|
117 |
+
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
118 |
+
if frozen_param_shapes is not None:
|
119 |
+
if debug:
|
120 |
+
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
121 |
+
param_names += list(frozen_param_shapes.keys())
|
122 |
+
|
123 |
+
# handle shared params
|
124 |
+
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
|
125 |
+
|
126 |
+
ds_version = state_dict.get(DS_VERSION, None)
|
127 |
+
|
128 |
+
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
129 |
+
|
130 |
+
z_model_state = zero_model_state(buffers=buffers,
|
131 |
+
param_shapes=param_shapes,
|
132 |
+
shared_params=shared_params,
|
133 |
+
ds_version=ds_version,
|
134 |
+
frozen_param_shapes=frozen_param_shapes,
|
135 |
+
frozen_param_fragments=frozen_param_fragments)
|
136 |
+
zero_model_states.append(z_model_state)
|
137 |
+
|
138 |
+
return zero_model_states
|
139 |
+
|
140 |
+
|
141 |
+
def parse_optim_states(files, ds_checkpoint_dir):
|
142 |
+
|
143 |
+
total_files = len(files)
|
144 |
+
state_dicts = []
|
145 |
+
for f in files:
|
146 |
+
state_dicts.append(torch.load(f, map_location=device))
|
147 |
+
|
148 |
+
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
149 |
+
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
150 |
+
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
151 |
+
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
152 |
+
|
153 |
+
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
154 |
+
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
155 |
+
# use the max of the partition_count to get the dp world_size.
|
156 |
+
|
157 |
+
if type(world_size) is list:
|
158 |
+
world_size = max(world_size)
|
159 |
+
|
160 |
+
if world_size != total_files:
|
161 |
+
raise ValueError(
|
162 |
+
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
163 |
+
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
164 |
+
)
|
165 |
+
|
166 |
+
# the groups are named differently in each stage
|
167 |
+
if zero_stage <= 2:
|
168 |
+
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
|
169 |
+
elif zero_stage == 3:
|
170 |
+
fp32_groups_key = FP32_FLAT_GROUPS
|
171 |
+
else:
|
172 |
+
raise ValueError(f"unknown zero stage {zero_stage}")
|
173 |
+
|
174 |
+
if zero_stage <= 2:
|
175 |
+
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
176 |
+
elif zero_stage == 3:
|
177 |
+
# if there is more than one param group, there will be multiple flattened tensors - one
|
178 |
+
# flattened tensor per group - for simplicity merge them into a single tensor
|
179 |
+
#
|
180 |
+
# XXX: could make the script more memory efficient for when there are multiple groups - it
|
181 |
+
# will require matching the sub-lists of param_shapes for each param group flattened tensor
|
182 |
+
|
183 |
+
fp32_flat_groups = [
|
184 |
+
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
|
185 |
+
]
|
186 |
+
|
187 |
+
return zero_stage, world_size, fp32_flat_groups
|
188 |
+
|
189 |
+
|
190 |
+
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
|
191 |
+
"""
|
192 |
+
Returns fp32 state_dict reconstructed from ds checkpoint
|
193 |
+
|
194 |
+
Args:
|
195 |
+
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
196 |
+
|
197 |
+
"""
|
198 |
+
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
199 |
+
|
200 |
+
optim_files = get_optim_files(ds_checkpoint_dir)
|
201 |
+
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
202 |
+
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
203 |
+
|
204 |
+
model_files = get_model_state_files(ds_checkpoint_dir)
|
205 |
+
|
206 |
+
zero_model_states = parse_model_states(model_files)
|
207 |
+
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
208 |
+
|
209 |
+
if zero_stage <= 2:
|
210 |
+
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
|
211 |
+
elif zero_stage == 3:
|
212 |
+
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
|
213 |
+
|
214 |
+
|
215 |
+
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
216 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
217 |
+
return
|
218 |
+
|
219 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
220 |
+
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
|
221 |
+
|
222 |
+
if debug:
|
223 |
+
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
|
224 |
+
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
225 |
+
|
226 |
+
wanted_params = len(frozen_param_shapes)
|
227 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
228 |
+
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
|
229 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
230 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
231 |
+
|
232 |
+
total_params = 0
|
233 |
+
total_numel = 0
|
234 |
+
for name, shape in frozen_param_shapes.items():
|
235 |
+
total_params += 1
|
236 |
+
unpartitioned_numel = shape.numel()
|
237 |
+
total_numel += unpartitioned_numel
|
238 |
+
|
239 |
+
state_dict[name] = frozen_param_fragments[name]
|
240 |
+
|
241 |
+
if debug:
|
242 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
243 |
+
|
244 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
245 |
+
|
246 |
+
|
247 |
+
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
248 |
+
param_shapes = zero_model_states[0].param_shapes
|
249 |
+
|
250 |
+
# Reconstruction protocol:
|
251 |
+
#
|
252 |
+
# XXX: document this
|
253 |
+
|
254 |
+
if debug:
|
255 |
+
for i in range(world_size):
|
256 |
+
for j in range(len(fp32_flat_groups[0])):
|
257 |
+
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
258 |
+
|
259 |
+
# XXX: memory usage doubles here (zero2)
|
260 |
+
num_param_groups = len(fp32_flat_groups[0])
|
261 |
+
merged_single_partition_of_fp32_groups = []
|
262 |
+
for i in range(num_param_groups):
|
263 |
+
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
264 |
+
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
265 |
+
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
266 |
+
avail_numel = sum(
|
267 |
+
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
268 |
+
|
269 |
+
if debug:
|
270 |
+
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
271 |
+
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
272 |
+
# not asserting if there is a mismatch due to possible padding
|
273 |
+
print(f"Have {avail_numel} numels to process.")
|
274 |
+
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
275 |
+
|
276 |
+
# params
|
277 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
278 |
+
# out-of-core computing solution
|
279 |
+
total_numel = 0
|
280 |
+
total_params = 0
|
281 |
+
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
282 |
+
offset = 0
|
283 |
+
avail_numel = full_single_fp32_vector.numel()
|
284 |
+
for name, shape in shapes.items():
|
285 |
+
|
286 |
+
unpartitioned_numel = shape.numel()
|
287 |
+
total_numel += unpartitioned_numel
|
288 |
+
total_params += 1
|
289 |
+
|
290 |
+
if debug:
|
291 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
292 |
+
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
293 |
+
offset += unpartitioned_numel
|
294 |
+
|
295 |
+
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
296 |
+
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
297 |
+
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
298 |
+
# live optimizer object, so we are checking that the numbers are within the right range
|
299 |
+
align_to = 2 * world_size
|
300 |
+
|
301 |
+
def zero2_align(x):
|
302 |
+
return align_to * math.ceil(x / align_to)
|
303 |
+
|
304 |
+
if debug:
|
305 |
+
print(f"original offset={offset}, avail_numel={avail_numel}")
|
306 |
+
|
307 |
+
offset = zero2_align(offset)
|
308 |
+
avail_numel = zero2_align(avail_numel)
|
309 |
+
|
310 |
+
if debug:
|
311 |
+
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
312 |
+
|
313 |
+
# Sanity check
|
314 |
+
if offset != avail_numel:
|
315 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
316 |
+
|
317 |
+
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
318 |
+
|
319 |
+
|
320 |
+
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
|
321 |
+
state_dict = OrderedDict()
|
322 |
+
|
323 |
+
# buffers
|
324 |
+
buffers = zero_model_states[0].buffers
|
325 |
+
state_dict.update(buffers)
|
326 |
+
if debug:
|
327 |
+
print(f"added {len(buffers)} buffers")
|
328 |
+
|
329 |
+
_zero2_merge_frozen_params(state_dict, zero_model_states)
|
330 |
+
|
331 |
+
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
332 |
+
|
333 |
+
# recover shared parameters
|
334 |
+
for pair in zero_model_states[0].shared_params:
|
335 |
+
if pair[1] in state_dict:
|
336 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
337 |
+
|
338 |
+
return state_dict
|
339 |
+
|
340 |
+
|
341 |
+
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
342 |
+
remainder = unpartitioned_numel % world_size
|
343 |
+
padding_numel = (world_size - remainder) if remainder else 0
|
344 |
+
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
345 |
+
return partitioned_numel, padding_numel
|
346 |
+
|
347 |
+
|
348 |
+
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
349 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
350 |
+
return
|
351 |
+
|
352 |
+
if debug:
|
353 |
+
for i in range(world_size):
|
354 |
+
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
|
355 |
+
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
356 |
+
|
357 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
358 |
+
wanted_params = len(frozen_param_shapes)
|
359 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
360 |
+
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
361 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
362 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
363 |
+
|
364 |
+
total_params = 0
|
365 |
+
total_numel = 0
|
366 |
+
for name, shape in zero_model_states[0].frozen_param_shapes.items():
|
367 |
+
total_params += 1
|
368 |
+
unpartitioned_numel = shape.numel()
|
369 |
+
total_numel += unpartitioned_numel
|
370 |
+
|
371 |
+
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
372 |
+
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
373 |
+
|
374 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
375 |
+
|
376 |
+
if debug:
|
377 |
+
print(
|
378 |
+
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
379 |
+
)
|
380 |
+
|
381 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
382 |
+
|
383 |
+
|
384 |
+
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
385 |
+
param_shapes = zero_model_states[0].param_shapes
|
386 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
387 |
+
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
388 |
+
# param, re-consolidating each param, while dealing with padding if any
|
389 |
+
|
390 |
+
# merge list of dicts, preserving order
|
391 |
+
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
392 |
+
|
393 |
+
if debug:
|
394 |
+
for i in range(world_size):
|
395 |
+
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
|
396 |
+
|
397 |
+
wanted_params = len(param_shapes)
|
398 |
+
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
399 |
+
# not asserting if there is a mismatch due to possible padding
|
400 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
401 |
+
print(f"Trainable params: Have {avail_numel} numels to process.")
|
402 |
+
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
|
403 |
+
|
404 |
+
# params
|
405 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
406 |
+
# out-of-core computing solution
|
407 |
+
offset = 0
|
408 |
+
total_numel = 0
|
409 |
+
total_params = 0
|
410 |
+
for name, shape in param_shapes.items():
|
411 |
+
|
412 |
+
unpartitioned_numel = shape.numel()
|
413 |
+
total_numel += unpartitioned_numel
|
414 |
+
total_params += 1
|
415 |
+
|
416 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
417 |
+
|
418 |
+
if debug:
|
419 |
+
print(
|
420 |
+
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
421 |
+
)
|
422 |
+
|
423 |
+
# XXX: memory usage doubles here
|
424 |
+
state_dict[name] = torch.cat(
|
425 |
+
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
|
426 |
+
0).narrow(0, 0, unpartitioned_numel).view(shape)
|
427 |
+
offset += partitioned_numel
|
428 |
+
|
429 |
+
offset *= world_size
|
430 |
+
|
431 |
+
# Sanity check
|
432 |
+
if offset != avail_numel:
|
433 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
434 |
+
|
435 |
+
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
436 |
+
|
437 |
+
|
438 |
+
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
|
439 |
+
state_dict = OrderedDict()
|
440 |
+
|
441 |
+
# buffers
|
442 |
+
buffers = zero_model_states[0].buffers
|
443 |
+
state_dict.update(buffers)
|
444 |
+
if debug:
|
445 |
+
print(f"added {len(buffers)} buffers")
|
446 |
+
|
447 |
+
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
|
448 |
+
|
449 |
+
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
450 |
+
|
451 |
+
# recover shared parameters
|
452 |
+
for pair in zero_model_states[0].shared_params:
|
453 |
+
if pair[1] in state_dict:
|
454 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
455 |
+
|
456 |
+
return state_dict
|
457 |
+
|
458 |
+
|
459 |
+
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
|
460 |
+
"""
|
461 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
462 |
+
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
463 |
+
via a model hub.
|
464 |
+
|
465 |
+
Args:
|
466 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder
|
467 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
468 |
+
|
469 |
+
Returns:
|
470 |
+
- pytorch ``state_dict``
|
471 |
+
|
472 |
+
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
|
473 |
+
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
474 |
+
the checkpoint.
|
475 |
+
|
476 |
+
A typical usage might be ::
|
477 |
+
|
478 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
479 |
+
# do the training and checkpoint saving
|
480 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
481 |
+
model = model.cpu() # move to cpu
|
482 |
+
model.load_state_dict(state_dict)
|
483 |
+
# submit to model hub or save the model to share with others
|
484 |
+
|
485 |
+
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
486 |
+
application. i.e. you will need to re-initialize the deepspeed engine, since
|
487 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
488 |
+
|
489 |
+
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
490 |
+
|
491 |
+
"""
|
492 |
+
if tag is None:
|
493 |
+
latest_path = os.path.join(checkpoint_dir, 'latest')
|
494 |
+
if os.path.isfile(latest_path):
|
495 |
+
with open(latest_path, 'r') as fd:
|
496 |
+
tag = fd.read().strip()
|
497 |
+
else:
|
498 |
+
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
499 |
+
|
500 |
+
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
501 |
+
|
502 |
+
if not os.path.isdir(ds_checkpoint_dir):
|
503 |
+
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
504 |
+
|
505 |
+
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
|
506 |
+
|
507 |
+
|
508 |
+
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
|
509 |
+
"""
|
510 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
511 |
+
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
512 |
+
|
513 |
+
Args:
|
514 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
515 |
+
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
|
516 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
517 |
+
"""
|
518 |
+
|
519 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
520 |
+
print(f"Saving fp32 state dict to {output_file}")
|
521 |
+
torch.save(state_dict, output_file)
|
522 |
+
|
523 |
+
|
524 |
+
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
525 |
+
"""
|
526 |
+
1. Put the provided model to cpu
|
527 |
+
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
528 |
+
3. Load it into the provided model
|
529 |
+
|
530 |
+
Args:
|
531 |
+
- ``model``: the model object to update
|
532 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
533 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
534 |
+
|
535 |
+
Returns:
|
536 |
+
- ``model`: modified model
|
537 |
+
|
538 |
+
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
539 |
+
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
540 |
+
conveniently placed for you in the checkpoint folder.
|
541 |
+
|
542 |
+
A typical usage might be ::
|
543 |
+
|
544 |
+
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
545 |
+
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
546 |
+
# submit to model hub or save the model to share with others
|
547 |
+
|
548 |
+
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
549 |
+
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
550 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
551 |
+
|
552 |
+
"""
|
553 |
+
logger.info(f"Extracting fp32 weights")
|
554 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
555 |
+
|
556 |
+
logger.info(f"Overwriting model with fp32 weights")
|
557 |
+
model = model.cpu()
|
558 |
+
model.load_state_dict(state_dict, strict=False)
|
559 |
+
|
560 |
+
return model
|
561 |
+
|
562 |
+
|
563 |
+
if __name__ == "__main__":
|
564 |
+
|
565 |
+
parser = argparse.ArgumentParser()
|
566 |
+
parser.add_argument("checkpoint_dir",
|
567 |
+
type=str,
|
568 |
+
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
569 |
+
parser.add_argument(
|
570 |
+
"output_file",
|
571 |
+
type=str,
|
572 |
+
help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
|
573 |
+
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
574 |
+
args = parser.parse_args()
|
575 |
+
|
576 |
+
debug = args.debug
|
577 |
+
|
578 |
+
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
|