Clemspace committed on
Commit
cb9e677
1 Parent(s): 0bdd0e6

Initial model upload

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +201 -0
  2. README.md +497 -3
  3. __pycache__/train.cpython-310.pyc +0 -0
  4. example/7B.yaml +37 -0
  5. finetune/__init__.py +0 -0
  6. finetune/__pycache__/__init__.cpython-310.pyc +0 -0
  7. finetune/__pycache__/__init__.cpython-38.pyc +0 -0
  8. finetune/__pycache__/args.cpython-310.pyc +0 -0
  9. finetune/__pycache__/args.cpython-38.pyc +0 -0
  10. finetune/__pycache__/checkpointing.cpython-310.pyc +0 -0
  11. finetune/__pycache__/checkpointing.cpython-38.pyc +0 -0
  12. finetune/__pycache__/distributed.cpython-310.pyc +0 -0
  13. finetune/__pycache__/distributed.cpython-38.pyc +0 -0
  14. finetune/__pycache__/eval.cpython-310.pyc +0 -0
  15. finetune/__pycache__/loss.cpython-310.pyc +0 -0
  16. finetune/__pycache__/mixed_precision.cpython-310.pyc +0 -0
  17. finetune/__pycache__/utils.cpython-310.pyc +0 -0
  18. finetune/__pycache__/wrapped_model.cpython-310.pyc +0 -0
  19. finetune/args.py +116 -0
  20. finetune/checkpointing.py +246 -0
  21. finetune/data/__init__.py +0 -0
  22. finetune/data/__pycache__/__init__.cpython-310.pyc +0 -0
  23. finetune/data/__pycache__/__init__.cpython-38.pyc +0 -0
  24. finetune/data/__pycache__/args.cpython-310.pyc +0 -0
  25. finetune/data/__pycache__/args.cpython-38.pyc +0 -0
  26. finetune/data/__pycache__/data_loader.cpython-310.pyc +0 -0
  27. finetune/data/__pycache__/dataset.cpython-310.pyc +0 -0
  28. finetune/data/__pycache__/dataset.cpython-38.pyc +0 -0
  29. finetune/data/__pycache__/exceptions.cpython-310.pyc +0 -0
  30. finetune/data/__pycache__/exceptions.cpython-38.pyc +0 -0
  31. finetune/data/__pycache__/tokenize.cpython-310.pyc +0 -0
  32. finetune/data/__pycache__/tokenize.cpython-38.pyc +0 -0
  33. finetune/data/args.py +61 -0
  34. finetune/data/data_loader.py +126 -0
  35. finetune/data/dataset.py +475 -0
  36. finetune/data/exceptions.py +56 -0
  37. finetune/data/tokenize.py +355 -0
  38. finetune/distributed.py +59 -0
  39. finetune/eval.py +77 -0
  40. finetune/loss.py +16 -0
  41. finetune/mixed_precision.py +47 -0
  42. finetune/monitoring/__init__.py +0 -0
  43. finetune/monitoring/__pycache__/__init__.cpython-310.pyc +0 -0
  44. finetune/monitoring/__pycache__/metrics_logger.cpython-310.pyc +0 -0
  45. finetune/monitoring/__pycache__/utils.cpython-310.pyc +0 -0
  46. finetune/monitoring/metrics_logger.py +226 -0
  47. finetune/monitoring/utils.py +34 -0
  48. finetune/utils.py +83 -0
  49. finetune/wrapped_model.py +227 -0
  50. huggingface.ipynb +40 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md CHANGED
@@ -1,3 +1,497 @@
- ---
- license: gpl-3.0
- ---
# Mistral-finetune

<a target="_blank" href="https://colab.research.google.com/github/mistralai/mistral-finetune/blob/main/tutorials/mistral_finetune_7b.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

`mistral-finetune` is a lightweight codebase that enables memory-efficient and performant finetuning of Mistral's models.
It is based on [LoRA](https://arxiv.org/abs/2106.09685), a training paradigm where most weights are frozen and only 1-2% additional weights in the form of low-rank matrix perturbations are trained.

For maximum efficiency it is recommended to use an A100 or H100 GPU. The codebase is optimized
for multi-GPU, single-node training setups, but for smaller models, such as the 7B, a single GPU suffices.

> **Note**
>
> - The goal of this repository is to provide a simple, guided entrypoint to finetune Mistral models.
> As such, it is fairly opinionated (especially around data formatting) and does not aim at being exhaustive
> across multiple model architectures or hardware types.
> For more generic approaches, you can check out some other great projects like
> [torchtune](https://pytorch.org/torchtune/stable/overview.html).

## Installation

To get started with Mistral LoRA fine-tuning, follow these steps:

1. Clone this repository:
```
cd $HOME && git clone https://github.com/mistralai/mistral-finetune.git
```

2. Install all required dependencies:
```
cd mistral-finetune
pip install -r requirements.txt
```

## Model download

We recommend fine-tuning one of the official Mistral models, which you can download here:

| Model | Link | Checksum |
|----------------|---------------------------------------------------------------------------------------------------------|-----------------------------------|
| 7B Base V3 | [7B Base](https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar) | `0663b293810d7571dad25dae2f2a5806`|
| 7B Instruct v3 | [7B Instruct v3](https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar) | `80b71fcb6416085bcb4efad86dfb4d52`|
| 8x7B Base V1 | [8x7B Base](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) | (HF link) |
| 8x7B Instruct V1 | [8x7B Instruct](https://models.mistralcdn.com/mixtral-8x7b-v0-1/Mixtral-8x7B-v0.1-Instruct.tar) | `8e2d3930145dc43d3084396f49d38a3f` |
| 8x22B Instruct V3 | [8x22B Instruct](https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-Instruct-v0.3.tar) | `471a02a6902706a2f1e44a693813855b`|
| 8x22B Base V3 | [8x22B Base](https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-v0.3.tar) | `a2fa75117174f87d1197e3a4eb50371a`|

**Important Notice**: For 8x7B Base V1 and 8x7B Instruct V1, it is necessary to use our v3 tokenizer and extend the vocabulary size to 32768 prior to fine-tuning. For detailed instructions on this process, please refer to the ["Model extension"](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#model-extension) section.

E.g., to download the 7B base model you can run the following command:
```sh
mkdir -p ${HOME}/mistral_models
cd ${HOME} && wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar
tar -xf mistral-7B-v0.3.tar -C mistral_models
```
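
As a quick sanity check, you can compare the downloaded archive against the checksum column of the table above (assuming the listed checksums are MD5 digests), for example:

```py
# Compute the MD5 digest of the downloaded archive and compare it to the
# value listed in the model table above.
import hashlib

def md5_of(path: str, chunk_size: int = 2**20) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

print(md5_of("mistral-7B-v0.3.tar"))  # expected: 0663b293810d7571dad25dae2f2a5806
```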

Make sure to modify your training script and add the path to the downloaded
folder as `model_id_or_path`.

E.g., modify [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) to include the absolute path to `$HOME/mistral_models/7B`:

```
model_id_or_path: "/Users/johndoe/mistral_models/7B"
```

## Prepare dataset

To ensure effective training, `mistral-finetune` has strict
requirements for how the training data has to be formatted.

All data files must be stored in jsonl format.

You can build two types of data files:

### _Pretrain_:

Pretrain data corresponds to plain text data stored in the `"text"` key. E.g.:

```jsonl
{"text": "Text contained in document n°1"}
{"text": "Text contained in document n°2"}
```
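
If you are assembling such a file yourself, a minimal sketch (the file name and document list are placeholders) looks like this:

```py
# Write a pretraining dataset in the expected jsonl layout: one JSON object
# with a "text" key per line.
import json

documents = ["Text contained in document n°1", "Text contained in document n°2"]
with open("pretrain.jsonl", "w", encoding="utf-8") as f:
    for text in documents:
        f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
```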

### _Instruct_:

Currently two different types of instruction-following data are supported:

- _Instruct_: conversational data stored in the `"messages"` key in the form of a list. Each list item is a dictionary containing the `"content"` and `"role"` keys. `"role"` is a string that is one of "user", "assistant" or "system". The loss will only be computed if "role" == "assistant". E.g.:

```jsonl
{
  "messages": [
    {
      "role": "user",
      "content": "User interaction n°1 contained in document n°1"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°1 contained in document n°1"
    },
    {
      "role": "user",
      "content": "User interaction n°2 contained in document n°1"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°2 contained in document n°1"
    }
  ]
}
{
  "messages": [
    {
      "role": "user",
      "content": "User interaction n°1 contained in document n°2"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°1 contained in document n°2"
    },
    {
      "role": "user",
      "content": "User interaction n°2 contained in document n°2"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°2 contained in document n°2",
      "weight": 0, # don't train on n°2
    },
    {
      "role": "user",
      "content": "User interaction n°3 contained in document n°2"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°3 contained in document n°2"
    }
  ]
}
```

- _Function calling_: conversational data stored in the `"messages"` key in the form of a list. Each list item is a dictionary containing the `"role"` and `"content"` or `"tool_calls"` keys. `"role"` is a string that is one of "user", "assistant", "system", or "tool". The loss will only be computed if "role" == "assistant".

**Note**: In function calling the `"id"` of `"tool_calls"` and the `"tool_call_id"` are randomly generated strings of exactly 9 characters. We recommend generating these automatically
in a data preparation script, as is done [here](https://github.com/mistralai/mistral-finetune/blob/208b25c0f7299bb78d06cea25b82adee03834319/utils/reformat_data_glaive.py#L74).
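
For illustration, one way to generate such ids (a minimal sketch; refer to the linked script for the approach used in this repository):

```py
# Generate a random 9-character alphanumeric id for "tool_calls"/"tool_call_id".
import random
import string

def generate_tool_call_id(length: int = 9) -> str:
    return "".join(random.choices(string.ascii_letters + string.digits, k=length))

print(generate_tool_call_id())  # e.g. "TX92Jm8Zi"
```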

E.g.:

```jsonl
{
  "messages": [
    {
      "role": "system",
      "content": "You are an helpful assistant who has access to the following functions to help the user, you can use the functions if needed"
    },
    {
      "role": "user",
      "content": "Can you help me generate an anagram of the word \"listen\"?"
    },
    {
      "role": "assistant",
      "tool_calls": [
        {
          "id": "TX92Jm8Zi",
          "type": "function",
          "function": {
            "name": "generate_anagram",
            "arguments": "{\"word\": \"listen\"}"
          }
        }
      ]
    },
    {
      "role": "tool",
      "content": "{\"anagram\": \"silent\"}",
      "tool_call_id": "TX92Jm8Zi"
    },
    {
      "role": "assistant",
      "content": "The anagram of the word \"listen\" is \"silent\"."
    },
    {
      "role": "user",
      "content": "That's amazing! Can you generate an anagram for the word \"race\"?"
    },
    {
      "role": "assistant",
      "tool_calls": [
        {
          "id": "3XhQnxLsT",
          "type": "function",
          "function": {
            "name": "generate_anagram",
            "arguments": "{\"word\": \"race\"}"
          }
        }
      ]
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "generate_anagram",
        "description": "Generate an anagram of a given word",
        "parameters": {
          "type": "object",
          "properties": {
            "word": {
              "type": "string",
              "description": "The word to generate an anagram of"
            }
          },
          "required": [
            "word"
          ]
        }
      }
    }
  ]
}
```

## Verify dataset

Before starting a training run you should verify that your dataset is correctly formatted and get an
estimate of the training time. You can do so by using the [./utils/validate_data](https://github.com/mistralai/mistral-finetune/blob/main/utils/validate_data.py) script.

Note that this step is crucial to ensure that the data is correctly formatted.

### Instruction following

Let's go over a simple example to train a model on instruction following:

- 1. **Load a chunk of [UltraChat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)**

Create the data folder and navigate to it.
```sh
cd $HOME && mkdir -p data && cd $HOME/data
```

Load the data into a Pandas DataFrame.

**Note**: Make sure to have pandas and pyarrow installed (`pip install pandas pyarrow`).

```py
import pandas as pd

df = pd.read_parquet('https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/data/test_gen-00000-of-00001-3d4cd8309148a71f.parquet')
```
- 2. Split into train and eval

```py
df_train = df.sample(frac=0.95, random_state=200)
df_eval = df.drop(df_train.index)
```

- 3. Save data to jsonl

```py
df_train.to_json("ultrachat_chunk_train.jsonl", orient="records", lines=True)
df_eval.to_json("ultrachat_chunk_eval.jsonl", orient="records", lines=True)
```

- 4. Modify your training yaml to include the ultrachat dataset and verify the yaml

Modify [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) to include the absolute path to `$HOME/data/ultrachat_chunk_train.jsonl` (as well as a dataset mixing weight) for training and `$HOME/data/ultrachat_chunk_eval.jsonl` for eval, *e.g.*

```
data:
  instruct_data: "/Users/johndoe/data/ultrachat_chunk_train.jsonl"
  eval_instruct_data: "/Users/johndoe/data/ultrachat_chunk_eval.jsonl"
```

Now you can verify your training yaml to make sure the data is correctly formatted and to get an estimate of your training time.

```
cd $HOME/mistral-finetune
python -m utils.validate_data --train_yaml example/7B.yaml
```

Upon completion you should see an error report with many of the following errors:

```
The data in line 1412 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
The data in line 1413 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
The data in line 1414 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
The data in line 1415 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
```

Many conversations end with the 'user' role, which is unnecessary: we only train on 'assistant' messages, so keeping these trailing user messages would only add processing overhead.

You can make use of [./utils/reformat_data.py](https://github.com/mistralai/mistral-finetune/blob/main/utils/reformat_data.py) to correct the data:

```
cd $HOME/mistral-finetune
python -m utils.reformat_data $HOME/data/ultrachat_chunk_train.jsonl
python -m utils.reformat_data $HOME/data/ultrachat_chunk_eval.jsonl
```

You should see that a couple of samples will be skipped.

- 5. Potentially change the number of training steps

Upon correction of the dataset, run the script again:

```
cd $HOME/mistral-finetune
python -m utils.validate_data --train_yaml example/7B.yaml
```

You should get a summary of the data input and training parameters:

```
Train States
--------------------
{
    "expected": {
        "eta": "00:52:44",
        "data_tokens": 25169147,
        "train_tokens": 131072000,
        "epochs": "5.21",
        "max_steps": 500,
        "data_tokens_per_dataset": {
            "/Users/johndoe/data/ultrachat_chunk_train.jsonl": "25169147.0"
        },
        "train_tokens_per_dataset": {
            "/Users/johndoe/data/ultrachat_chunk_train.jsonl": "131072000.0"
        },
        "epochs_per_dataset": {
            "/Users/johndoe/data/ultrachat_chunk_train.jsonl": "5.2"
        }
    },
}
```

Having `max_steps` set to 500 would lead to iterating through the dataset roughly 5 times, which is reasonable but might
be a bit too much. A lower value, such as the `max_steps: 300` recommended in the [Start training](#start-training) section below, would only take about 30 min on an 8xH100 cluster.

### Function calling

Next let's go over a more advanced use case to fine-tune a model on function calling.
Function calling requires the data to be in the format [explained above](#instruct). Let's go over an example.

- 1. **Load a chat-formatted version of the [Glaive function calling dataset](https://huggingface.co/datasets/Locutusque/function-calling-chatml)**

Create the data folder and navigate to it.
```sh
cd $HOME && mkdir -p data && cd $HOME/data
```

Load the data into a Pandas DataFrame.

**Note**: Make sure to have pandas and pyarrow installed (`pip install pandas pyarrow`).

```py
import pandas as pd

df = pd.read_parquet('https://huggingface.co/datasets/Locutusque/function-calling-chatml/resolve/main/data/train-00000-of-00001-f0b56c6983b4a78f.parquet')
```
- 2. Split into train and eval

```py
df_train = df.sample(frac=0.95, random_state=200)
df_eval = df.drop(df_train.index)
```

- 3. Save data to jsonl

```py
df_train.to_json("glaive_train.jsonl", orient="records", lines=True)
df_eval.to_json("glaive_eval.jsonl", orient="records", lines=True)
```

- 4. Reformat dataset

As one can see, the dataset does not follow the required function calling format, so it will need to be reformatted. Among other things, `"from"` should be renamed to `"user"` and superfluous `"\n"` characters should be removed.
For this dataset you can make use of [`./utils/reformat_data_glaive.py`](https://github.com/mistralai/mistral-finetune/blob/main/utils/reformat_data_glaive.py):

```
cd $HOME/mistral-finetune
python -m utils.reformat_data_glaive $HOME/data/glaive_train.jsonl
python -m utils.reformat_data_glaive $HOME/data/glaive_eval.jsonl
```

Running this command will make sure that most samples are in the correct format.

**Note**: It is impossible to write reformatting scripts that work for all kinds of datasets.
If you have datasets that don't yet follow the required format above, you will most probably have to
create a reformatting script yourself (mistral-chat or ChatGPT is your best friend here!); a minimal sketch is shown below.
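
The `"conversations"`, `"from"` and `"value"` keys in this sketch are hypothetical and stand in for whatever layout your source dataset uses:

```py
# Hypothetical example: convert a dataset with "conversations"/"from"/"value"
# entries into the "messages"/"role"/"content" format expected by mistral-finetune.
import json
import sys

ROLE_MAP = {"human": "user", "gpt": "assistant", "system": "system"}  # assumed source roles

in_path = sys.argv[1]
with open(in_path, "r", encoding="utf-8") as f_in, open(in_path + ".reformatted", "w", encoding="utf-8") as f_out:
    for line in f_in:
        sample = json.loads(line)
        messages = [
            {"role": ROLE_MAP.get(turn["from"], turn["from"]), "content": turn["value"].strip()}
            for turn in sample["conversations"]
        ]
        f_out.write(json.dumps({"messages": messages}, ensure_ascii=False) + "\n")
```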

- 5. Validate dataset

You can now validate the dataset by setting `data.instruct_data` and `data.eval_instruct_data` to
`$HOME/data/glaive_train.jsonl` and `$HOME/data/glaive_eval.jsonl` in `example/7B.yaml` respectively.

The reformatted datasets still have some errors which can be removed with `--create_corrected`. For this, make sure to add
`--create_corrected` as follows:

```
cd $HOME/mistral-finetune
python -m utils.validate_data --train_yaml example/7B.yaml --create_corrected
```

Running this command will show a couple of errors and save two new datasets `$HOME/data/glaive_train.jsonl.corrected` and `$HOME/data/glaive_eval.jsonl.corrected`. Make sure to use these two datasets in `example/7B.yaml` and run the command again. Now the dataset should be correctly formatted!


## Start training

Having followed the [dataset verification section](#verify-dataset), we can now start training.
For faster training, we recommend setting `max_steps` to only 300. Make sure to set `run_dir` to your experiment folder and optionally set `wandb.project` to a Weights & Biases project for logging, *e.g.*:
```
max_steps: 300
run_dir: "/Users/johndoe/ultra_chat_test"
wandb.project: ultra_chat
```

Optionally you can also set other `wandb` fields, such as `wandb.key` (see the configuration section below).

Save the training configuration and start training! Make sure to set `--nproc-per-node` to the number of available GPUs.

```
cd $HOME/mistral-finetune
torchrun --nproc-per-node 8 --master_port $RANDOM -m train example/7B.yaml
```
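
If you are training on a single GPU (which the introduction notes is sufficient for the 7B model), the same launch applies with one process, e.g. `torchrun --nproc-per-node 1 --master_port $RANDOM -m train example/7B.yaml`.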

Training on ultra-chat should take around 30 min on an 8xH100 node and the resulting weights should give an MT Bench score around 6.3.

Training on glaive should take around 1h on an 8xH100 node and the resulting weights should work nicely for function calling.

## Customizing training configuration

The example config [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) defines reasonable parameters for learning rate, weight decay, etc., but you are advised to
customize these settings for your use case.

Generally, a training configuration should fill the following parameters:

- `model_id_or_path` defines the model to start training from. This can be a path to a pre-trained model or a local model directory.
- `run_dir` defines the directory where training checkpoints and metrics are stored.
- `seq_len` defines the sequence length for training. This is the maximum length of input sequences the model will process. Samples are packed to reach a length of `seq_len` for maximum training efficiency.
- `batch_size` defines the number of training examples used per GPU. **Note**: The overall effective batch size (in tokens) across all GPUs equals `num_gpus` x `batch_size` x `seq_len`.
- `max_steps` defines the maximum number of training steps. This is the total number of iterations the training process will run. It can be adjusted based on the specific needs of your training scenario. The total number of tokens seen during training is `max_steps` x `num_gpus` x `batch_size` x `seq_len` (see the worked example after this list).
- `optim.lr` defines the learning rate. This is the initial learning rate for the optimizer.
- `optim.weight_decay` defines weight decay. Weight decay is a regularization technique used to prevent overfitting by penalizing large weights. We recommend leaving it at 0.1.
- `optim.pct_start` defines the percentage of the total training steps used for the learning rate warm-up phase before it starts to decrease. It corresponds to pct_start of PyTorch's OneCycleLR.
- `lora.rank` defines the size of the LoRA (Low-Rank Adaptation) adapters. We recommend 64 or less, which adjusts the rank of the low-rank decomposition used in LoRA.
- `seed` defines the random seed for initialization and data shuffling/sampling. Setting a seed ensures reproducibility of results.
- `log_freq` defines the logging frequency. This specifies how often (in steps) to log training metrics.
- `data.instruct_data` is the path to the instruction data used for training. This field has to be filled with one or multiple data sources in the format explained above. Each data source should either be a path to a jsonl file or a path to a directory containing jsonl files, followed by a weighting that defines the importance of this dataset: `<path/to/data_source>:<weight>`. E.g.: `data.instruct_data: "/path/to/data1.jsonl:5.,/path/to/data2.jsonl:1.,/path/to/dir_of_jsonls:1."`
- `data.data` is an optional path to additional pretraining data in the format explained above. Note that this field can be left blank.
- `data.eval_instruct_data` is an optional path to evaluation instruction data to run cross-validation every `eval_freq` steps. Cross-validation metrics are displayed as `loss` and `perplexity`.
- `eval_freq` defines how often (in steps) to evaluate the model. This specifies the interval at which the model is evaluated on the validation set.
- `no_eval` is a flag to enable or disable intermediate evaluation. Setting it to False enables periodic evaluation during training.
- `ckpt_freq` defines how often (in steps) to save checkpoints. This specifies the interval at which the model's state is saved.
- `ckpt_only_lora` defines whether to only save the trained LoRA checkpoints or whether the trained LoRA should directly be merged into the base model and saved. **Note**: When setting `ckpt_only_lora=False` make sure that you have enough CPU and GPU memory to save the full model on a single process (this is usually only possible for the 7B model).
- `wandb.key` is used to pass your Weights & Biases (wandb) API key for logging. This allows you to log training metrics to the wandb dashboard.
- `wandb.project` defines the wandb project name. This is where the training run will be logged in the wandb interface.
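
As a worked example of the token arithmetic above, here is one configuration consistent with the `validate_data` summary shown earlier (the GPU count and batch size are illustrative assumptions):

```py
# Effective batch size (in tokens) and total training tokens, as described above.
num_gpus, batch_size, seq_len, max_steps = 8, 1, 32768, 500

tokens_per_step = num_gpus * batch_size * seq_len   # 262,144 tokens per optimizer step
total_train_tokens = tokens_per_step * max_steps    # 131,072,000 tokens in total

print(tokens_per_step, total_train_tokens)
```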

## Inference

Once your model is trained, you should try it out for inference. We recommend using [mistral-inference](https://github.com/mistralai/mistral-inference).

Make sure to have `mistral_inference` correctly installed:
```
pip install mistral_inference
```

Assuming your `lora.safetensors` is saved under `$HOME/ultra_chat_test/checkpoints/checkpoint_000300/consolidated/lora.safetensors`, you can chat with the model using `mistral_inference`, *e.g.*:

```sh
mistral-chat /mnt/slow/runs/patrick/mistral-finetune/7B/ --max_tokens 256 --temperature 1.0 --instruct --lora_path $HOME/ultra_chat_test/checkpoints/checkpoint_000300/consolidated/lora.safetensors
```

## Model extension

**Important**: Note that one can only fine-tune Mistral models that are compatible with the v3 tokenizer, which entails that the models have a vocabulary size of 32768, not 32000. One can however easily extend older versions with a vocabulary size of 32000 to a vocabulary size of 32768 by using:
```
python -m utils.extend_model_vocab --original_model_ckpt /folder/to/old/model --extended_model_ckpt /folder/to/extended/model
```

Once the extension has worked, one can fine-tune using the newly created model checkpoint in `/folder/to/extended/model`.

## FAQ:

> - What's the best practice of fine-tuning MoEs?

We see a higher degree of performance variance when fine-tuning MoE models. It's not unusual to find that fine-tuning MoE models with different seeds can lead to a high variance in performance. We did not observe such high variance with dense models. Therefore, we suggest running multiple instances of the same fine-tuning process on MoE models and selecting the one that performs best.

> - How can I determine the number of tokens used during the model training process?

You can use the following script to find out: https://github.com/mistralai/mistral-finetune/blob/main/utils/validate_data.py. This script accepts a .yaml training file as input and returns the number of tokens the model is being trained on.
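
For example, running `python -m utils.validate_data --train_yaml example/7B.yaml` prints the expected `train_tokens` (among other statistics), as shown in the [Verify dataset](#verify-dataset) section.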

> - What should I do if I encounter a CUDA out-of-memory error?

One possible solution is to reduce the batch size per GPU. The per-GPU batch size (in tokens) is equal to `seq_len` x `batch_size`. Try setting `batch_size` to 1 and reducing `seq_len`. You can define `batch_size` and `seq_len` in the .yaml file.
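
For instance, in the training yaml you could set `batch_size: 1` and lower `seq_len` (e.g. from the 32768 used in `example/7B.yaml` to a smaller value such as 8192; the exact value depends on your GPU memory and data).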
__pycache__/train.cpython-310.pyc ADDED
Binary file (6.27 kB). View file
 
example/7B.yaml ADDED
@@ -0,0 +1,37 @@
# data
data:
  instruct_data: "/root/data/mol_instructions_train.jsonl"  # Fill this with the path to your training data
  data: ""  # Optionally fill with pretraining data
  eval_instruct_data: ""  # Optionally fill with evaluation data

# model
model_id_or_path: "/root/mistral_models/7B-v0.3"  # Path to downloaded model
lora:
  rank: 64

# optim
seq_len: 32768
batch_size: 2
# TODO: try other values
max_steps: 500
optim:
  lr: 5.e-5
  weight_decay: 0.05
  pct_start: 0.05

# other
seed: 99
log_freq: 1
eval_freq: 100
no_eval: True
ckpt_freq: 100

ckpt_only_lora: False  # Set to `True` to save only the trained LoRA adapters; `False` merges the LoRA adapters into the base model and saves the full fine-tuned model

run_dir: "/root/mistral-finetune/runseed99"

wandb:
  project: "CHEMISTral7b-ft"
  offline: False  # Set to True if you want to use wandb in offline mode
  key: "aaf77f83a4e316f6a8b47fa975ab6b5e73c7c8df"  # Optionally set your WandB API key
  run_name: "runseed99"  # Optionally name your WandB run
finetune/__init__.py ADDED
File without changes
finetune/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (136 Bytes). View file
 
finetune/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (134 Bytes). View file
 
finetune/__pycache__/args.cpython-310.pyc ADDED
Binary file (3.81 kB). View file
 
finetune/__pycache__/args.cpython-38.pyc ADDED
Binary file (3.79 kB). View file
 
finetune/__pycache__/checkpointing.cpython-310.pyc ADDED
Binary file (8.73 kB). View file
 
finetune/__pycache__/checkpointing.cpython-38.pyc ADDED
Binary file (8.67 kB). View file
 
finetune/__pycache__/distributed.cpython-310.pyc ADDED
Binary file (2.02 kB). View file
 
finetune/__pycache__/distributed.cpython-38.pyc ADDED
Binary file (2.05 kB). View file
 
finetune/__pycache__/eval.cpython-310.pyc ADDED
Binary file (2.24 kB). View file
 
finetune/__pycache__/loss.cpython-310.pyc ADDED
Binary file (569 Bytes). View file
 
finetune/__pycache__/mixed_precision.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
finetune/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.94 kB). View file
 
finetune/__pycache__/wrapped_model.cpython-310.pyc ADDED
Binary file (7.49 kB). View file
 
finetune/args.py ADDED
@@ -0,0 +1,116 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from simple_parsing.helpers import Serializable

from model.args import LoraArgs

from .data.args import DataArgs


@dataclass
class OptimArgs(Serializable):
    lr: float = 3e-4
    weight_decay: float = 0.1
    pct_start: float = 0.3


@dataclass
class WandbArgs(Serializable):
    project: Optional[str] = None  # Fill this argument to use wandb.
    offline: bool = False
    key: Optional[str] = None
    run_name: Optional[str] = None

    def __post_init__(self) -> None:
        if self.project is not None:
            try:
                import wandb  # noqa: F401
            except ImportError:
                raise ImportError("`wandb` not installed. Either make sure `wandb` is installed or set `wandb.project` to None.")

            if len(self.project) == 0:
                raise ValueError("`wandb.project` must not be an empty string.")


@dataclass
class MLFlowArgs(Serializable):
    tracking_uri: Optional[str] = None
    experiment_name: Optional[str] = None

    def __post_init__(self) -> None:
        if self.tracking_uri is not None:
            try:
                import mlflow  # noqa: F401
            except ImportError:
                raise ImportError("`mlflow` not installed. Either make sure `mlflow` is installed or set `mlflow.tracking_uri` to None.")

            if self.experiment_name is None:
                raise ValueError("If `mlflow.tracking_uri` is set, `mlflow.experiment_name` must be set as well.")


@dataclass
class TrainArgs(Serializable):
    data: DataArgs

    # if specified, instruct_tokenizer and model will be loaded
    model_id_or_path: str  # Path to the directory containing the initial model, or a model id such as "mistral-small"

    run_dir: str  # Path to the directory where everything will be saved. It needs to be empty.
    # Name of the wandb run; if None it will be set to the name of the run_dir.

    optim: OptimArgs = field(default_factory=OptimArgs)
    seed: int = 0
    # Number of steps to accumulate gradients before doing an optimizer step.
    num_microbatches: int = 1

    seq_len: int = 2048  # Number of tokens per batch per device.
    batch_size: int = 1
    max_norm: float = 1.0  # Gradient clipping.
    max_steps: int = 100  # Number of training steps.
    log_freq: int = 1  # Number of steps between each logging.

    # Number of steps between each checkpoint saving. If less than 1, only the last checkpoint will be saved.
    ckpt_freq: int = 0
    ckpt_only_lora: bool = True
    # If True, no checkpoint will be saved. This is useful for development.
    no_ckpt: bool = False
    num_ckpt_keep: Optional[int] = 3
    eval_freq: int = 0
    no_eval: bool = True

    # Efficiency
    # Determines whether gradient checkpointing should be utilized or not during the training process. Gradient checkpointing can be beneficial in reducing memory usage at the cost of slightly longer training times.
    checkpoint: bool = True

    world_size: Optional[int] = field(init=False, default=None)

    # logging
    wandb: WandbArgs = field(default_factory=WandbArgs)
    mlflow: MLFlowArgs = field(default_factory=MLFlowArgs)

    # LoRA
    lora: Optional[LoraArgs] = field(default_factory=LoraArgs)

    def __post_init__(self) -> None:
        assert getattr(self, "world_size", None) is None
        self.world_size = int(os.environ.get("WORLD_SIZE", -1))

        if self.wandb.offline:
            command = f"cd {self.run_dir}; wandb sync --sync-all"
            logging.info(f"to sync wandb offline, run: {command}")

        assert self.num_microbatches >= 1

        assert self.num_ckpt_keep is None or self.num_ckpt_keep >= 1

        if self.model_id_or_path is not None:
            assert Path(self.model_id_or_path).exists(), f"{self.model_id_or_path} does not exist."

        if not self.ckpt_only_lora:
            logging.warning(
                "You have disabled `ckpt_only_lora` and are thus merging the trained LoRA checkpoint into the base model upon checkpointing. This might lead to OOM errors - make sure you have enough CPU and GPU memory."
            )
finetune/checkpointing.py ADDED
@@ -0,0 +1,246 @@
import json
import logging
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Union

import safetensors.torch
import torch
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
from torch.distributed import barrier
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from model.transformer import LoRALinear

from .distributed import get_rank, get_world_size
from .utils import TrainState

logger = logging.getLogger("checkpointing")


def main_logger_info(message: str) -> None:
    if get_rank() == 0:
        logger.info(message)


class Checkpointer:
    """A class to save PyTorch model and optimizer states"""

    def __init__(
        self,
        model: FullyShardedDataParallel,
        state: TrainState,
        run_dir: Union[Path, str],
        optimizer: Optional[torch.optim.Optimizer] = None,
        num_ckpt_keep: Optional[int] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.state = state
        self.run_dir = Path(run_dir)
        self.rank = get_rank()
        self.num_ckpt_keep = num_ckpt_keep

    @property
    def ckpt_dir(self) -> Path:
        return self.run_dir / "checkpoints"

    @property
    def dst_dir(self) -> Path:
        return self.ckpt_dir / f"checkpoint_{self.state.step:06d}" / "consolidated"

    @staticmethod
    def consolidated_path(
        ckpt_dir: Path, use_safetensors: bool, save_only_lora: Optional[bool] = False
    ) -> Path:
        suffix = "safetensors" if use_safetensors else "00.pth"
        prefix = "lora" if save_only_lora else "consolidated"

        return ckpt_dir / f"{prefix}.{suffix}"

    @staticmethod
    def _tmp(ckpt_dir: Path) -> Path:
        return ckpt_dir.with_name(f"tmp.{ckpt_dir.name}")

    def write_params_info(self, tmp_dst: Path):
        params_path = tmp_dst / "params.json"
        with open(params_path, "w") as f:
            model_args = self.model.args.to_dict()

            f.write(json.dumps(model_args, indent=4))

    def delete_old_ckpts(self) -> List[Path]:
        all_saved_ckpts = [d for d in self.ckpt_dir.iterdir() if d.is_dir()]

        # Sort directories by creation time (newest to oldest)
        all_saved_ckpts.sort(key=lambda x: x.stat().st_ctime, reverse=True)

        ckpts_to_delete = all_saved_ckpts[self.num_ckpt_keep :]

        for ckpt_to_delete in ckpts_to_delete:
            try:
                shutil.rmtree(ckpt_to_delete)
                main_logger_info(f"Deleted ckpt: {ckpt_to_delete}")
            except OSError as e:
                main_logger_info(f"Error deleting directory {ckpt_to_delete}: {e}")

        return ckpts_to_delete

    @staticmethod
    def get_lora_states(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {k: v for k, v in state_dict.items() if "lora" in k}

    @staticmethod
    def get_non_lora_states(
        state_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        return {
            k: v
            for k, v in state_dict.items()
            if not any(l_key in k for l_key in ["lora", "frozen"])
        }

    @torch.no_grad()
    def retrieve_save_states(
        self, save_only_lora: bool, save_dtype: torch.dtype
    ) -> Dict[str, torch.Tensor]:
        if save_only_lora:
            assert (
                self.model.args.lora.enable
            ), "Cannot save LoRA checkpoint as LoRA training is not enabled."

        # remove all potential hooks
        for module in self.model.modules():
            if isinstance(module, LoRALinear) and hasattr(module, "_merge_lora_handle"):
                module._merge_lora_handle.remove()  # type: ignore

        # merge weights if we don't just save LoRA
        if not save_only_lora:

            def merge_lora(
                m: torch.nn.Module,
                destination: Dict[str, torch.Tensor],
                prefix: str,
                *args,
            ):
                weight = m.merge_weight()  # type: ignore
                destination[prefix + "weight"] = weight

            for module in self.model.modules():
                if isinstance(module, LoRALinear):
                    module._merge_lora_handle = module._register_state_dict_hook(
                        merge_lora
                    )

        offload_to_cpu = get_world_size() > 1
        if save_only_lora:

            def is_trainable_fsdp(
                module: Union[torch.nn.Module, FullyShardedDataParallel]
            ):
                is_fsdp = isinstance(module, FullyShardedDataParallel)
                all_params_have_grads = is_fsdp and all(
                    p.requires_grad is True for p in module.parameters()
                )

                # need to make sure only lowest fsdp wrap is used
                is_leaf_node = is_fsdp and len(list(module.module.children())) == 0  # type: ignore

                return is_fsdp and all_params_have_grads and is_leaf_node

            # extract all modules with only trainable weights
            modules = {
                k: m for k, m in self.model.named_modules() if is_trainable_fsdp(m)
            }

            states = {}
            for key, module in modules.items():
                assert isinstance(
                    module, FullyShardedDataParallel
                ), "`module` should be an instance of `FullyShardedDataParallel`"
                parent_prefix = key.replace("_fsdp_wrapped_module.", "").replace(
                    "_checkpoint_wrapped_module.", ""
                )
                with module.summon_full_params(
                    module, writeback=True, offload_to_cpu=offload_to_cpu
                ):
                    states.update(
                        {
                            f"{parent_prefix}.{k}": v.to(dtype=save_dtype)
                            for k, v in module.state_dict().items()
                        }
                    )
        else:
            # make sure you have enough CPU RAM available to save the full model
            assert isinstance(
                self.model, FullyShardedDataParallel
            ), "`self.model` should be an instance of `FullyShardedDataParallel`"
            with self.model.summon_full_params(
                self.model, writeback=True, offload_to_cpu=offload_to_cpu
            ):
                states = self.get_non_lora_states(self.model.state_dict())
                states = {k: v.to(dtype=save_dtype) for k, v in states.items()}

        states = dict(sorted(states.items()))
        return states

    @staticmethod
    def save_tokenizer(instruct_tokenizer: InstructTokenizerBase, tmp_dst: Path):
        serialized_spm = instruct_tokenizer.tokenizer._model.serialized_model_proto()  # type: ignore

        tokenizer_path = tmp_dst / "tokenizer.model.v3"

        with open(tokenizer_path, "wb") as f:
            f.write(serialized_spm)

    @torch.no_grad()
    def save_checkpoint(
        self,
        save_only_lora: bool,
        dtype: torch.dtype = torch.float16,
        instruct_tokenizer: Optional[InstructTokenizerBase] = None,
    ):
        tmp_dst = self._tmp(self.dst_dir)
        main_logger_info(
            f"Dumping checkpoint in {self.dst_dir} using tmp name: {tmp_dst.name}"
        )

        assert not self.dst_dir.exists(), f"dst exists {self.dst_dir}"
        tmp_dst.mkdir(parents=True, exist_ok=True)

        states: Dict[str, torch.Tensor] = self.retrieve_save_states(
            save_only_lora, dtype
        )

        barrier()

        if self.rank == 0:
            # save checkpoint in tmp path
            safetensors.torch.save_file(
                states,
                self.consolidated_path(
                    tmp_dst, use_safetensors=True, save_only_lora=save_only_lora
                ),  # always use safetensors for checkpointing
            )

            self.write_params_info(tmp_dst)

            # save tokenizer
            if instruct_tokenizer is not None:
                self.save_tokenizer(instruct_tokenizer, tmp_dst)

            assert not self.dst_dir.exists(), f"should not happen! {self.dst_dir}"
            tmp_dst.rename(self.dst_dir)

            logger.info(
                f"Done dumping checkpoint in {self.dst_dir} for step: {self.state.step}"
            )

            # delete old checkpoints, keeping only the most recent `num_ckpt_keep`
            if self.num_ckpt_keep is not None:
                ckpts_to_delete = self.delete_old_ckpts()
                logger.info(
                    f"Done deleting checkpoints {', '.join([str(c) for c in ckpts_to_delete])}"
                )

        main_logger_info("Done!")
finetune/data/__init__.py ADDED
File without changes
finetune/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (141 Bytes). View file
 
finetune/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (139 Bytes). View file
 
finetune/data/__pycache__/args.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
finetune/data/__pycache__/args.cpython-38.pyc ADDED
Binary file (1.33 kB). View file
 
finetune/data/__pycache__/data_loader.cpython-310.pyc ADDED
Binary file (4.26 kB). View file
 
finetune/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
finetune/data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (11.1 kB). View file
 
finetune/data/__pycache__/exceptions.cpython-310.pyc ADDED
Binary file (2.57 kB). View file
 
finetune/data/__pycache__/exceptions.cpython-38.pyc ADDED
Binary file (2.91 kB). View file
 
finetune/data/__pycache__/tokenize.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
finetune/data/__pycache__/tokenize.cpython-38.pyc ADDED
Binary file (10.3 kB). View file
 
finetune/data/args.py ADDED
@@ -0,0 +1,61 @@
1
+ import logging
2
+ from dataclasses import dataclass, field
3
+
4
+ from simple_parsing.helpers import Serializable
5
+
6
+ logger = logging.getLogger("data")
7
+
8
+
9
+ @dataclass()
10
+ class InstructArgs(Serializable):
11
+ shuffle: bool = True
12
+
13
+ # For function calling training examples only the last tool call
14
+ # of the assistant message can be used for training. Therefore,
15
+ # we chunk longer function calling conversations into multiple
16
+ # training samples to not lose any data. E.g.:
17
+ # [[
18
+ # UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1
19
+ # UserMessage_2, AssistantToolCallMessage_2, ToolMessage_2, AssistantMessage_2
20
+ # ]]
21
+ # => is chunked into two training samples:
22
+ # [[
23
+ # UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1
24
+ # ],
25
+ # [
26
+ # UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1
27
+ # UserMessage_2, AssistantToolCallMessage_2, ToolMessage_2, AssistantMessage_2
28
+ # ]]
29
+ # NOTE: Only if your data is already pre-chunked should this argument be set to False
30
+ dynamic_chunk_fn_call: bool = True
31
+
32
+
33
+ @dataclass()
34
+ class DataArgs(Serializable):
35
+ # The data arguments `data` and `instruct_data` are a string in the format
36
+ # "data_source_dir_1:weight_1,data_source_dir_2:weight_2,...". The weight
37
+ # will be used to sample the data sources. If the sum of the weights is
38
+ # not 1 when concatenating the two arguments `data` and `instruct_data`,
39
+ # it will be normalized. The data sources folders must contain jsonl files.
40
+ # If the value is an empty string, no data will be used for the corresponding
41
+ # data type.
42
+ data: str = (
43
+ "" # Each line in the jsonl files inside the data source directories must be a dictionary with a "text" key. See Readme for more details. Can be left empty.
44
+ )
45
+ shuffle: bool = False
46
+ instruct_data: str = (
47
+ "" # Each line in the jsonl files inside the data source directories must be a dictionary with an "interactions" key. See Readme for more details. Can be left empty.
48
+ )
49
+ eval_instruct_data: str = (
50
+ "" # Each line in the jsonl files inside the data source directories must be a dictionary with an "interactions" key. See Readme for more details. Can be left empty.
51
+ )
52
+ instruct: InstructArgs = field(default_factory=InstructArgs)
53
+
54
+ def __post_init__(self) -> None:
55
+ if (
56
+ self.instruct.shuffle is False
57
+ and self.instruct.dynamic_chunk_fn_call is True
58
+ ):
59
+ raise ValueError(
60
+ "Make sure to either enable `data.instruct.shuffle=True` or `data.instruct.dynamic_chunk_fn_call=False`. Dynamic chunking is only possible if data is loaded and shuffled before training."
61
+ )
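A minimal usage sketch of the weighted-source format described in the comments above (the directory names are hypothetical; only the weight normalization later done by `parse_data_sources` is illustrated):

# Hypothetical example: two instruct sources with weights 3 and 1.
# After normalization the sampling probabilities become 0.75 and 0.25.
from finetune.data.args import DataArgs

args = DataArgs(instruct_data="/data/chat_a:3.0,/data/chat_b:1.0")
weights = [float(src.split(":")[1]) for src in args.instruct_data.split(",")]
print([w / sum(weights) for w in weights])  # [0.75, 0.25]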
finetune/data/data_loader.py ADDED
@@ -0,0 +1,126 @@
1
+ import dataclasses
2
+ from typing import Any, Iterator, List, Optional
3
+
4
+ import numpy as np
5
+ from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
6
+
7
+ from .args import DataArgs
8
+ from .dataset import build_dataset
9
+
10
+
11
+ @dataclasses.dataclass
12
+ class Batch:
13
+ x: np.ndarray
14
+ y: np.ndarray
15
+ sizes: List[int]
16
+ y_mask: Optional[np.ndarray] = None
17
+ is_pad_only: bool = False
18
+
19
+ def __post_init__(self):
20
+ assert self.x.ndim == 1
21
+ assert self.x.shape == self.y.shape
22
+ assert self.x.dtype == np.int64
23
+ assert self.y.dtype == np.int64
24
+ assert isinstance(self.sizes, list)
25
+ assert sum(self.sizes) == self.x.size == self.y.size
26
+
27
+ if self.y_mask is not None:
28
+ assert self.y_mask.size == self.y.size, (self.y_mask.shape, self.y.shape)
29
+ assert self.y_mask.dtype == bool
30
+ assert sum(self.sizes) == self.y_mask.size
31
+ assert not self.y_mask.all()
32
+ assert self.y_mask.any()
33
+
34
+ if self.is_pad_only:
35
+ assert np.sum(np.abs(self.y)) == 0
36
+ assert np.sum(np.abs(self.x)) == 0
37
+ assert self.y_mask is None
38
+ # create all 0's mask for pad samples
39
+ self.y_mask = np.zeros_like(self.x, dtype=bool)
40
+
41
+
42
+
43
+
44
+ @dataclasses.dataclass
45
+ class BatchList:
46
+ x: List[List[int]] = dataclasses.field(default_factory=list)
47
+ y: List[List[int]] = dataclasses.field(default_factory=list)
48
+ sizes: List[List[int]] = dataclasses.field(default_factory=list)
49
+ y_mask: List[Optional[List[int]]] = dataclasses.field(default_factory=list)
50
+
51
+ def __post_init__(self):
52
+ assert self.x == [], "`BatchList` has to be empty at init."
53
+ assert self.y == [], "`BatchList` has to be empty at init."
54
+ assert self.sizes == [], "`BatchList` has to be empty at init."
55
+ assert self.y_mask == [], "`BatchList` has to be empty at init."
56
+
57
+ def __len__(self) -> int:
58
+ return len(self.x)
59
+
60
+ def add(self, x: List[int], y: List[int], sizes: List[int], y_mask: Optional[List[int]] = None):
61
+ self.x.append(x)
62
+ self.y.append(y)
63
+ self.sizes.append(sizes)
64
+ self.y_mask.append(y_mask)
65
+
66
+ def empty(self):
67
+ self.x = []
68
+ self.y = []
69
+ self.sizes = []
70
+ self.y_mask = []
71
+
72
+ @staticmethod
73
+ def flatten_to_numpy(list_of_lists: List[List[Any]], dtype: np.dtype) -> np.array:
74
+ return np.array([el for sublist in list_of_lists for el in sublist], dtype=dtype)
75
+
76
+ def create_batch(self) -> Batch:
77
+ x_np: np.array = self.flatten_to_numpy(self.x, dtype=np.int64)
78
+ y_np: np.array = self.flatten_to_numpy(self.y, dtype=np.int64)
79
+ sizes = sum(self.sizes, []) # noqa
80
+
81
+ y_mask_np: Optional[np.array] = self.flatten_to_numpy(self.y_mask, dtype=bool)
82
+ y_mask_np = None if y_mask_np.all() else y_mask_np
83
+
84
+ return Batch(x_np, y_np, sizes, y_mask_np)
85
+
86
+
87
+
88
+
89
+ def build_data_loader(
90
+ instruct_tokenizer: InstructTokenizerBase,
91
+ args: DataArgs,
92
+ batch_size: int,
93
+ seq_len: int,
94
+ seed: Optional[int],
95
+ rank: int,
96
+ world_size: int,
97
+ is_eval: bool,
98
+ ) -> Iterator[Batch]:
99
+ pretrain_data = args.data if not is_eval else ""
100
+ instruct_data = args.instruct_data if not is_eval else args.eval_instruct_data
101
+
102
+ dataset = build_dataset(
103
+ pretrain_data=pretrain_data,
104
+ instruct_data=instruct_data,
105
+ instruct_args=args.instruct,
106
+ instruct_tokenizer=instruct_tokenizer,
107
+ seq_len=seq_len,
108
+ seed=seed,
109
+ rank=rank,
110
+ world_size=world_size,
111
+ is_eval=is_eval,
112
+ shuffle_pretrain=args.shuffle,
113
+ )
114
+
115
+ batch_list = BatchList()
116
+ for sample in dataset:
117
+ assert all(s >= 0 for s in sample.sizes)
118
+
119
+ batch_list.add(sample.x, sample.y, sample.sizes, sample.mask)
120
+
121
+ if len(batch_list) == batch_size:
122
+ batch: Batch = batch_list.create_batch()
123
+ yield batch
124
+
125
+ batch_list.empty()
126
+
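A minimal, self-contained sketch of how `BatchList` flattens variable-length samples into a single `Batch` (the token ids below are made up for illustration):

from finetune.data.data_loader import BatchList

bl = BatchList()
bl.add(x=[1, 2, 3], y=[2, 3, 4], sizes=[3], y_mask=[False, True, True])
bl.add(x=[5, 6], y=[6, 7], sizes=[2], y_mask=[True, False])

batch = bl.create_batch()        # flattens into 1-D int64 arrays
assert batch.x.shape == (5,)
assert batch.sizes == [3, 2]     # per-sample lengths are kept for seqlens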
finetune/data/dataset.py ADDED
@@ -0,0 +1,475 @@
1
+ import dataclasses
2
+ import itertools
3
+ import json
4
+ import logging
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch.distributed as dist
11
+ from mistral_common.protocol.instruct.messages import (
12
+ FinetuningAssistantMessage,
13
+ SystemMessage,
14
+ )
15
+ from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
16
+
17
+ from finetune.distributed import get_rank
18
+
19
+ from .args import InstructArgs
20
+ from .tokenize import (
21
+ Mask,
22
+ SampleType,
23
+ TokenSample,
24
+ TrainingInstructSample,
25
+ build_instruct_sample,
26
+ encode,
27
+ )
28
+
29
+ logger = logging.getLogger("dataset")
30
+
31
+
32
+ _LOADED_DATASETS: Dict[Path, List[str]] = {}
33
+
34
+
35
+ def main_logger_info(message: str) -> None:
36
+ if dist.is_initialized() and get_rank() == 0:
37
+ logger.info(message)
38
+
39
+
40
+ def load_file(path: Path, world_size: int, rank: int) -> List[str]:
41
+ lines = []
42
+ with path.open() as f:
43
+ for idx, line in enumerate(f):
44
+ if not idx % world_size == rank:
45
+ continue
46
+ lines.append(line)
47
+ return lines
48
+
49
+
50
+ def maybe_load_local_dataset(
51
+ path: Path, chunk: bool, rank: int, world_size: int, instruct_tokenizer: InstructTokenizerBase, sample_type: SampleType
52
+ ) -> List[TokenSample]:
53
+ global _LOADED_DATASETS
54
+
55
+ if path in _LOADED_DATASETS:
56
+ return _LOADED_DATASETS[path]
57
+
58
+ main_logger_info(f"Loading {path} ...")
59
+ lines: List[str] = load_file(path, rank=rank, world_size=world_size)
60
+
61
+ if chunk:
62
+ lines += maybe_chunk_lines(lines)
63
+
64
+ tokens_list: List[TokenSample] = []
65
+ for line in lines:
66
+ data = json.loads(line)
67
+
68
+ token_sample: TokenSample = encode(
69
+ data,
70
+ instruct_tokenizer=instruct_tokenizer,
71
+ as_type=sample_type,
72
+ )
73
+ tokens_list.append(token_sample)
74
+
75
+ main_logger_info(f"{path} loaded and tokenized.")
76
+ _LOADED_DATASETS[path] = tokens_list
77
+
78
+ return _LOADED_DATASETS[path]
79
+
80
+
81
+ @dataclass
82
+ class DataDir:
83
+ path: Path
84
+ sample_type: SampleType
85
+
86
+ @property
87
+ def jsonl_files(self):
88
+ assert self.path.exists(), f"Make sure that {self.path} exists"
89
+ jsonl_files = list(self.path.rglob("*jsonl"))
90
+ assert (
91
+ len(jsonl_files) > 0
92
+ ), f"{self.path} does not seem to have any files ending with '.jsonl'"
93
+ return jsonl_files
94
+
95
+
96
+ @dataclass
97
+ class DataFile:
98
+ path: Path
99
+ sample_type: SampleType
100
+
101
+ @property
102
+ def jsonl_files(self):
103
+ assert self.path.exists(), f"Make sure that {self.path} exists"
104
+ return [self.path]
105
+
106
+
107
+ def parse_data_sources(
108
+ pretrain_data: str,
109
+ instruct_data: str,
110
+ ) -> Tuple[List[Union[DataDir, DataFile]], List[float]]:
111
+ seen: Set[str] = set()
112
+ sources: List[Union[DataDir, DataFile]] = []
113
+ weights: List[float] = []
114
+ for sample_sources, sample_type in [
115
+ (pretrain_data, SampleType.PRETRAIN),
116
+ (instruct_data, SampleType.INSTRUCT),
117
+ ]:
118
+ for source in sample_sources.strip().split(","):
119
+ if not source:
120
+ continue
121
+
122
+ source_items = source.strip().split(":")
123
+ if len(source_items) == 1:
124
+ path_ = source_items[0]
125
+ weight = 1.0
126
+ elif len(source_items) == 2:
127
+ path_, weight_ = source_items
128
+ weight = float(weight_)
129
+ else:
130
+ raise ValueError(
131
+ f"{source} is not correctly formatted. Make sure to format each data source as <path/to/data>:<weight> or just <path/to/data>"
132
+ )
133
+
134
+ assert (
135
+ path_ not in seen
136
+ ), f"{path_} seems to be duplicated. Make sure to only add it once."
137
+ assert (
138
+ weight > 0
139
+ ), f"Make sure to define strictly positive data sampling weights, not {weight}"
140
+
141
+ data: Union[DataDir, DataFile]
142
+ if Path(path_).is_dir():
143
+ data = DataDir(path=Path(path_), sample_type=sample_type)
144
+ elif Path(path_).is_file():
145
+ data = DataFile(path=Path(path_), sample_type=sample_type)
146
+ else:
147
+ raise FileNotFoundError(
148
+ f"The path {path_} does not exist. Make sure {path_} is either a file or directory that contains training data."
149
+ )
150
+
151
+ sources.append(data)
152
+ weights.append(weight)
153
+
154
+ seen.add(path_)
155
+
156
+ sum_weights = sum(weights)
157
+ n_weights = [weight / sum_weights for weight in weights]
158
+
159
+ assert min(n_weights) > 0
160
+ assert (
161
+ abs(1 - sum(n_weights)) < 1e-8
162
+ ), f"Defined data sampling weights {weights} must sum to 1."
163
+ return sources, n_weights
164
+
165
+
166
+ @dataclasses.dataclass()
167
+ class SequenceMaskAndSizes:
168
+ """
169
+ Concatenation of samples to reach a given size
170
+ """
171
+
172
+ x: List[int]
173
+ y: List[int]
174
+ mask: Mask
175
+ sizes: List[int]
176
+
177
+ def __post_init__(self):
178
+ assert sum(self.sizes) == len(self.x) == len(self.y) == len(self.mask)
179
+
180
+
181
+ def sequence_iterator(
182
+ ds_it: Iterator[TokenSample],
183
+ seq_len: int,
184
+ is_finite: bool,
185
+ ) -> Iterator[SequenceMaskAndSizes]:
186
+ """
187
+ Creates sequences of length `seq_len` from the dataset iterator by concatenating samples.
188
+ """
189
+ x_buffer: List[int] = []
190
+ y_buffer: List[int] = []
191
+ mask_buffer: Mask = []
192
+
193
+ sizes: List[int] = []
194
+ n_missing = seq_len
195
+ for sample in ds_it:
196
+ assert 0 <= len(x_buffer) < seq_len, len(x_buffer)
197
+ assert n_missing == seq_len - len(
198
+ x_buffer
199
+ ), f"n_missing: {n_missing} | seq_len - len(x_buffer) {seq_len - len(x_buffer)}"
200
+
201
+ tokens, mask = sample.tokens, sample.masks[1:]
202
+ x, y = tokens[:-1], tokens[1:]
203
+ cur_pos = 0
204
+
205
+ while cur_pos < len(x):
206
+ size = len(x[cur_pos : cur_pos + n_missing])
207
+
208
+ curr_mask = mask[cur_pos : cur_pos + n_missing]
209
+ if not any(curr_mask):
210
+ cur_pos += size
211
+ # we have a sequence with a mask filled with False
212
+ continue
213
+
214
+ x_buffer.extend(x[cur_pos : cur_pos + n_missing])
215
+ y_buffer.extend(y[cur_pos : cur_pos + n_missing])
216
+ mask_buffer.extend(curr_mask)
217
+ n_missing -= size
218
+ sizes.append(size)
219
+
220
+ cur_pos += size
221
+
222
+ if n_missing == 0:
223
+ assert len(mask_buffer) == len(x_buffer) == seq_len == len(y_buffer)
224
+ assert sum(sizes) == seq_len
225
+ # we don't want to yield sequences with a mask filled with False
226
+ if any(mask_buffer):
227
+ yield SequenceMaskAndSizes(
228
+ x=x_buffer,
229
+ y=y_buffer,
230
+ mask=mask_buffer,
231
+ sizes=sizes,
232
+ )
233
+ x_buffer, y_buffer = [], []
234
+ mask_buffer = []
235
+ sizes = []
236
+ n_missing = seq_len
237
+
238
+ if is_finite:
239
+ # if dataloader is in eval, pad to seq length
240
+ if any(mask_buffer):
241
+ mask_buffer.extend(n_missing * [False])
242
+ x_buffer.extend(n_missing * [0])
243
+ y_buffer.extend(n_missing * [0])
244
+ sizes.append(n_missing)
245
+
246
+ yield SequenceMaskAndSizes(
247
+ x=x_buffer,
248
+ y=y_buffer,
249
+ mask=mask_buffer,
250
+ sizes=sizes,
251
+ )
252
+
253
+
254
+ def build_dataset(
255
+ pretrain_data: str,
256
+ instruct_data: str,
257
+ instruct_args: InstructArgs,
258
+ instruct_tokenizer: InstructTokenizerBase,
259
+ seq_len: int,
260
+ seed: Optional[int],
261
+ rank: int,
262
+ world_size: int,
263
+ is_eval: bool,
264
+ shuffle_pretrain: bool = False,
265
+ ) -> Iterator[SequenceMaskAndSizes]:
266
+ sources, probabilities = parse_data_sources(
267
+ pretrain_data=pretrain_data, instruct_data=instruct_data
268
+ )
269
+
270
+ def do_shuffle(source: Union[DataDir, DataFile]) -> bool:
271
+ shuffle = {
272
+ SampleType.PRETRAIN: shuffle_pretrain,
273
+ SampleType.INSTRUCT: instruct_args.shuffle,
274
+ }[source.sample_type]
275
+
276
+ return not is_eval and shuffle
277
+
278
+ dataset_iterators = [
279
+ get_dataset_iterator(
280
+ source,
281
+ instruct_args=instruct_args,
282
+ instruct_tokenizer=instruct_tokenizer,
283
+ rank=rank,
284
+ world_size=world_size,
285
+ is_finite=is_eval,
286
+ seed=seed,
287
+ shuffle_at_epoch=do_shuffle(source),
288
+ )
289
+ for source in sources
290
+ ]
291
+
292
+ sequence_iterators = [
293
+ sequence_iterator(
294
+ ds_it=it,
295
+ seq_len=seq_len,
296
+ is_finite=is_eval,
297
+ )
298
+ for it in dataset_iterators
299
+ ]
300
+
301
+ if is_eval:
302
+ combined_iterator = itertools.chain.from_iterable(sequence_iterators)
303
+ else:
304
+ # make sure random_seed is different per rank and original seed
305
+ random_seed = np.array((seed, rank))
306
+ rng = np.random.RandomState(seed=random_seed)
307
+ combined_iterator = interleave_iterators(
308
+ sequence_iterators, probabilities=probabilities, rng=rng
309
+ )
310
+
311
+ return combined_iterator
312
+
313
+
314
+ def get_rng(seed: int, rank: int) -> np.random.RandomState:
315
+ random_seed = np.array((seed, rank))
316
+ rng = np.random.RandomState(seed=random_seed)
317
+ return rng
318
+
319
+
320
+ def get_dataset_iterator(
321
+ source: Union[DataDir, DataFile],
322
+ instruct_args: InstructArgs,
323
+ instruct_tokenizer: InstructTokenizerBase,
324
+ rank: int,
325
+ world_size: int,
326
+ is_finite: bool,
327
+ seed: Optional[int],
328
+ shuffle_at_epoch: bool,
329
+ ) -> Iterator[TokenSample]:
330
+ jsonl_files = source.jsonl_files
331
+ rng: Optional[np.random.RandomState] = (
332
+ get_rng(seed, rank) if seed is not None else None
333
+ )
334
+
335
+ chunk_dataset = (
336
+ instruct_args.dynamic_chunk_fn_call
337
+ and source.sample_type == SampleType.INSTRUCT
338
+ )
339
+
340
+ if not is_finite:
341
+ # train mode
342
+ while True:
343
+ for jsonl_file in jsonl_files:
344
+ if shuffle_at_epoch:
345
+ assert rng is not None, "`seed` has to be passed when shuffling"
346
+ # will preload all data into RAM, shuffle and yield
347
+ yield from preload_and_yield(
348
+ jsonl_file,
349
+ chunk_dataset=chunk_dataset,
350
+ rank=rank,
351
+ world_size=world_size,
352
+ rng=rng,
353
+ instruct_tokenizer=instruct_tokenizer,
354
+ sample_type=source.sample_type,
355
+ )
356
+ else:
357
+ # will read data on-the-fly and yield
358
+ main_logger_info(f"Lazily loading {jsonl_file} ...")
359
+ yield from lazy_load_and_yield(
360
+ jsonl_file,
361
+ rank=rank,
362
+ world_size=world_size,
363
+ instruct_tokenizer=instruct_tokenizer,
364
+ sample_type=source.sample_type,
365
+ )
366
+ else:
367
+ # eval mode
368
+ for jsonl_file in jsonl_files:
369
+ # No need to shuffle for eval
370
+ yield from lazy_load_and_yield(
371
+ jsonl_file,
372
+ rank=rank,
373
+ world_size=world_size,
374
+ instruct_tokenizer=instruct_tokenizer,
375
+ sample_type=source.sample_type,
376
+ )
377
+
378
+
379
+ def preload_and_yield(
380
+ jsonl_file: Path,
381
+ chunk_dataset: bool,
382
+ rank: int,
383
+ world_size: int,
384
+ rng: np.random.RandomState,
385
+ instruct_tokenizer: InstructTokenizerBase,
386
+ sample_type: SampleType,
387
+ ) -> Iterator[TokenSample]:
388
+ # only instruct data has to be chunked
389
+ # load dataset if not already loaded. Make sure to only load 1/world_size dataset
390
+ tokens_list = maybe_load_local_dataset(
391
+ jsonl_file, chunk=chunk_dataset, rank=rank, world_size=world_size, instruct_tokenizer=instruct_tokenizer, sample_type=sample_type
392
+ )
393
+
394
+ if sample_type == SampleType.PRETRAIN:
395
+ assert chunk_dataset is False, "Pretrain data should not have chunking enabled."
396
+
397
+ main_logger_info(f"Shuffling {jsonl_file} ...")
398
+ rng.shuffle(tokens_list)
399
+
400
+ for token_sample in tokens_list:
401
+ yield token_sample
402
+
403
+ def lazy_load_and_yield(
404
+ jsonl_file: Path,
405
+ rank: int,
406
+ world_size: int,
407
+ instruct_tokenizer: InstructTokenizerBase,
408
+ sample_type: SampleType,
409
+ ):
410
+ with jsonl_file.open() as file_handle:
411
+ for idx, line in enumerate(file_handle):
412
+ if not idx % world_size == rank:
413
+ continue
414
+
415
+ data = json.loads(line)
416
+ yield encode(
417
+ data,
418
+ instruct_tokenizer=instruct_tokenizer,
419
+ as_type=sample_type,
420
+ )
421
+
422
+
423
+ def maybe_chunk_lines(lines: List[str]) -> List[str]:
424
+ extra_lines: List[str] = []
425
+ for line in lines:
426
+ data = json.loads(line)
427
+ # multi-turn fn call data will be chunked and shorter conversations are added additionally
428
+ maybe_chunked_lines = maybe_chunk_data(data)
429
+ extra_lines.extend([json.dumps(line) for line in maybe_chunked_lines])
430
+
431
+ return extra_lines
432
+
433
+
434
+ def maybe_chunk_data(data: Dict[str, Any]) -> List[Dict[str, Any]]:
435
+ # think about always allowing both open-ai and non-open-ai data
436
+ sample = build_instruct_sample(data)
437
+
438
+ def num_assistant_messages(sample: TrainingInstructSample) -> int:
439
+ return len(
440
+ [m for m in sample.messages if isinstance(m, FinetuningAssistantMessage)]
441
+ )
442
+
443
+ chunk_data = []
444
+ while sample.only_last is True and num_assistant_messages(sample) > 1:
445
+ assert sample == build_instruct_sample(sample.dict())
446
+ last_message = sample.messages.pop()
447
+
448
+ # 1. First pop until and including last assistant message
449
+ system_message = None
450
+ while not isinstance(last_message, FinetuningAssistantMessage):
451
+ last_message = sample.messages.pop()
452
+ if isinstance(last_message, SystemMessage):
453
+ system_message = last_message
454
+
455
+ # 2. Second pop until and excluding last assistant message
456
+ prev_last_message = sample.messages[-1]
457
+ while not isinstance(prev_last_message, FinetuningAssistantMessage):
458
+ last_message = sample.messages.pop()
459
+ if isinstance(last_message, SystemMessage):
460
+ system_message = last_message
461
+
462
+ prev_last_message = sample.messages[-1]
463
+
464
+ # if system_message is not None, append again
465
+ if system_message is not None:
466
+ sample.messages.append(system_message)
467
+ chunk_data.append(sample.dict())
468
+
469
+ return chunk_data
470
+
471
+
472
+ def interleave_iterators(iterators: List[Iterator], probabilities, rng):
473
+ while True:
474
+ it_id = rng.choice(range(len(iterators)), p=probabilities)
475
+ yield next(iterators[it_id])
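A self-contained sketch of the weighted interleaving performed by `interleave_iterators` (toy iterators; the probabilities are illustrative and would normally come from `parse_data_sources`):

import itertools

import numpy as np

from finetune.data.dataset import interleave_iterators

rng = np.random.RandomState(seed=0)
mixed = interleave_iterators(
    [itertools.cycle("A"), itertools.cycle("B")],
    probabilities=[0.75, 0.25],  # normalized source weights
    rng=rng,
)
print([next(mixed) for _ in range(8)])  # mostly 'A', occasionally 'B'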
finetune/data/exceptions.py ADDED
@@ -0,0 +1,56 @@
1
+ class MessageFormatError(Exception):
2
+ def __init__(self, message, data):
3
+ self._message = message
4
+ self._begin_data = data[:20]
5
+ super().__init__()
6
+
7
+ def __str__(self):
8
+ return f"A message starting with {self._begin_data} is incorrectly formatted. " + self._message
9
+
10
+
11
+ class ToolCallFormatError(Exception):
12
+ def __init__(self, message, data):
13
+ self._message = message
14
+ self._begin_data = data[:20]
15
+ super().__init__()
16
+
17
+ def __str__(self):
18
+ return f"A tool call assistant message starting with {self._begin_data} of the conversation is incorrectly formatted. " + self._message
19
+
20
+
21
+ class FunctionFormatError(Exception):
22
+ def __init__(self, message, data):
23
+ self._message = message
24
+ self._begin_data = data[:20]
25
+ super().__init__()
26
+
27
+ def __str__(self):
28
+ return (
29
+ f"A function of the conversation starting with {self._begin_data} is incorrectly formatted. "
30
+ + self._message
31
+ )
32
+
33
+
34
+ class ConversationFormatError(Exception):
35
+ def __init__(self, message, data):
36
+ self._message = message
37
+ self._begin_data = data[:20]
38
+ super().__init__()
39
+
40
+ def __str__(self):
41
+ return (
42
+ f"A conversation starting with {self._begin_data} is incorrectly formatted. " + self._message
43
+ )
44
+
45
+
46
+ class UnrecognizedRoleError(Exception):
47
+ def __init__(self, role, allowed_roles):
48
+ self._role = role
49
+ self._allowed_roles = allowed_roles
50
+ super().__init__()
51
+
52
+ def __str__(self):
53
+ return (
54
+ f"The following role: {self._role} is not recognized. "
55
+ + f"Make sure that each role is one of {self._allowed_roles}."
56
+ )
finetune/data/tokenize.py ADDED
@@ -0,0 +1,355 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ from mistral_common.protocol.instruct.messages import (
7
+ FinetuningAssistantMessage,
8
+ Roles,
9
+ SystemMessage,
10
+ ToolMessage,
11
+ UserMessage,
12
+ )
13
+ from mistral_common.protocol.instruct.tool_calls import (
14
+ Function,
15
+ FunctionCall,
16
+ Tool,
17
+ ToolCall,
18
+ )
19
+ from mistral_common.protocol.instruct.validator import (
20
+ MistralRequestValidatorV3,
21
+ ValidationMode,
22
+ )
23
+ from mistral_common.tokens.instruct.request import InstructRequest
24
+ from mistral_common.tokens.tokenizers.base import Tokenizer
25
+ from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
26
+
27
+ from .exceptions import (
28
+ ConversationFormatError,
29
+ FunctionFormatError,
30
+ MessageFormatError,
31
+ ToolCallFormatError,
32
+ UnrecognizedRoleError,
33
+ )
34
+
35
+ logger = logging.getLogger("tokenize")
36
+
37
+ Sequence = List[int]
38
+ Mask = List[bool]
39
+
40
+
41
+ class TrainingInstructSample(InstructRequest):
42
+ available_tools: Optional[List[Tool]] = None
43
+ only_last: bool = False
44
+
45
+
46
+ @dataclass()
47
+ class TokenSample:
48
+ tokens: Sequence
49
+ masks: Mask
50
+
51
+
52
+ class SampleType(str, Enum):
53
+ PRETRAIN = "pretrain"
54
+ INSTRUCT = "instruct"
55
+
56
+
57
+ def encode(
58
+ data: Dict[str, Any],
59
+ instruct_tokenizer: InstructTokenizerBase,
60
+ as_type: SampleType,
61
+ ) -> TokenSample:
62
+ sample: Union[str, TrainingInstructSample]
63
+ if as_type == SampleType.PRETRAIN:
64
+ sample = get_pretrain_sample(data)
65
+ elif as_type == SampleType.INSTRUCT:
66
+ sample = build_instruct_sample(data)
67
+
68
+ return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer)
69
+
70
+
71
+ def get_pretrain_sample(data: Dict[str, Any]) -> str:
72
+ content_keys = ["text", "content"]
73
+ assert not all(
74
+ k in data for k in content_keys
75
+ ), "Make sure to have either 'text' or 'content' in your data. Not both."
76
+ assert any(
77
+ data.get(k) is not None for k in content_keys
78
+ ), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"
79
+
80
+ # get first non-None value
81
+ sample = None
82
+ for key in content_keys:
83
+ sample = data[key] if key in data else sample
84
+
85
+ assert isinstance(sample, str), sample
86
+
87
+ return sample
88
+
89
+
90
+ def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample:
91
+ messages: List[
92
+ SystemMessage | UserMessage | FinetuningAssistantMessage | ToolMessage
93
+ ] = []
94
+ # optional data fields that might be set
95
+ available_tools: Optional[List[Tool]] = data.get("available_tools")
96
+ system_prompt = data.get("system_prompt")
97
+
98
+ messages_keys = ["messages", "interactions"]
99
+ content_keys = ["content", "text"] # both are accepted
100
+ allowed_roles = [role.value for role in Roles]
101
+
102
+ if not any(messages_key in data for messages_key in messages_keys):
103
+ err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
104
+ raise ConversationFormatError(err, str(data))
105
+
106
+ if all(messages_key in data for messages_key in messages_keys):
107
+ err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
108
+ raise ConversationFormatError(err, str(data))
109
+
110
+ # get first non-None value
111
+ data_messages: Optional[List[Dict[str, Any]]] = None
112
+ for key in messages_keys:
113
+ data_messages = data[key] if key in data else data_messages
114
+
115
+ assert data_messages is not None, "data_messages can't be None"
116
+
117
+ if "available_tools" in data and "tools" in data:
118
+ err = "The conversation contains both an `available_tools` and `tools` key. You can only have one."
119
+ raise ConversationFormatError(err, str(data))
120
+
121
+ if data.get("tools", None) is not None and len(data["tools"]) > 0:
122
+ available_tools = _parse_available_tools(data["tools"])
123
+ elif (
124
+ data.get("available_tools", None) is not None
125
+ and len(data["available_tools"]) > 0
126
+ ):
127
+ available_tools = _parse_available_tools(data["available_tools"])
128
+
129
+ for data_message in data_messages:
130
+ is_tool_call = data_message.get("tool_calls") is not None
131
+
132
+ if "role" not in data_message:
133
+ err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
134
+ raise MessageFormatError(err, str(data))
135
+
136
+ role = data_message["role"]
137
+
138
+ if all(key in data_message for key in content_keys):
139
+ err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
140
+ raise MessageFormatError(err, str(data))
141
+
142
+ content: Optional[str] = None
143
+ for key in content_keys:
144
+ content = content if content is not None else data_message.get(key)
145
+
146
+ # non-function call message must have content
147
+ if not is_tool_call and content is None:
148
+ err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}. Make sure that the message includes one of '{content_keys}' keys."
149
+ raise MessageFormatError(err, str(data))
150
+
151
+ if role not in allowed_roles:
152
+ raise UnrecognizedRoleError(role, allowed_roles)
153
+
154
+ if data_message["role"] == "user":
155
+ assert content is not None
156
+ messages.append(UserMessage(content=content))
157
+ elif data_message["role"] == "assistant":
158
+ tool_calls: Optional[List[ToolCall]] = None
159
+
160
+ if is_tool_call:
161
+ tool_calls = _parse_tool_calls(data_message["tool_calls"])
162
+
163
+ weight = data_message.get("weight")
164
+ messages.append(
165
+ FinetuningAssistantMessage(
166
+ content=content, tool_calls=tool_calls, weight=weight
167
+ )
168
+ )
169
+ elif data_message["role"] == "system":
170
+ if system_prompt is not None:
171
+ err = "Multiple messages with role 'system' encountered. Only one is allowed."
172
+ raise MessageFormatError(err, str(data))
173
+
174
+ system_prompt = content
175
+ elif data_message["role"] == "tool":
176
+ assert content is not None
177
+ tool_message = _parse_tool_message(content, data_message)
178
+ messages.append(tool_message)
179
+
180
+ # validate created messages
181
+ validator = MistralRequestValidatorV3(ValidationMode.finetuning)
182
+ validator.validate_messages(messages)
183
+ validator._validate_tools(available_tools or [])
184
+
185
+ # whether to train only on last assistant message
186
+ only_last = data.get("only_last", False) or available_tools is not None
187
+
188
+ return TrainingInstructSample(
189
+ messages=messages,
190
+ system_prompt=system_prompt,
191
+ available_tools=available_tools,
192
+ only_last=only_last,
193
+ )
194
+
195
+
196
+ def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]:
197
+ available_tools = []
198
+ for tool in tools:
199
+ if "function" not in tool:
200
+ raise FunctionFormatError(
201
+ "A tool dict does not have a 'function' key.", str(tool)
202
+ )
203
+
204
+ func_data = tool["function"]
205
+
206
+ for key in ["name", "description", "parameters"]:
207
+ if key not in func_data:
208
+ raise FunctionFormatError(
209
+ f"A function dict does not have a {key} key.", str(func_data)
210
+ )
211
+
212
+ if not isinstance(func_data["parameters"], dict):
213
+ raise FunctionFormatError(
214
+ f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. If the function has no parameters, pass an empty dict.", str(func_data)
215
+ )
216
+
217
+ description = func_data["description"]
218
+ function = Function(
219
+ name=func_data["name"],
220
+ description=description,
221
+ parameters=func_data["parameters"],
222
+ )
223
+
224
+ available_tools.append(Tool(function=function))
225
+ return available_tools
226
+
227
+
228
+ def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]:
229
+ for key in ["id", "function"]:
230
+ if not all(key in call for call in calls):
231
+ err = f"A tool call of an assistant message does not have a {key} key"
232
+ raise ToolCallFormatError(err, str(calls))
233
+
234
+ for key in ["name", "arguments"]:
235
+ if not all(key in call["function"] for call in calls):
236
+ err = (
237
+ f"A tool call function of an assistant message does not have a {key} key"
238
+ )
239
+ raise ToolCallFormatError(err, str(calls))
240
+
241
+ if not all(isinstance(call["function"]["arguments"], str) for call in calls):
242
+ err = "A tool call function of an assistant message does not have an 'arguments' key of type str"
243
+ raise ToolCallFormatError(err, str(calls))
244
+
245
+ tool_calls = [
246
+ ToolCall(
247
+ id=call["id"],
248
+ function=FunctionCall(
249
+ name=call["function"]["name"],
250
+ arguments=call["function"]["arguments"],
251
+ ),
252
+ )
253
+ for call in calls
254
+ ]
255
+ return tool_calls
256
+
257
+
258
+ def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage:
259
+ if "tool_call_id" not in data_message:
260
+ err = f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'."
261
+ raise MessageFormatError(err, str(data_message))
262
+
263
+ call_id = data_message["tool_call_id"]
264
+ # name is deprecated in v3, but we'll add it nevertheless for now
265
+ name = data_message.get("name")
266
+
267
+ return ToolMessage(content=content, tool_call_id=call_id, name=name)
268
+
269
+
270
+ def tokenize(
271
+ sample: Union[str, TrainingInstructSample],
272
+ instruct_tokenizer: InstructTokenizerBase,
273
+ ) -> TokenSample:
274
+ if isinstance(sample, str):
275
+ tokenizer: Tokenizer = instruct_tokenizer.tokenizer
276
+ return tokenize_pretrain(sample, tokenizer)
277
+ elif isinstance(sample, TrainingInstructSample):
278
+ return tokenize_instruct(sample, instruct_tokenizer)
279
+
280
+ raise ValueError(
281
+ f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
282
+ )
283
+
284
+
285
+ def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample:
286
+ tokens = tokenizer.encode(sample, bos=True, eos=True)
287
+ masks = [True] * len(tokens)
288
+ return TokenSample(tokens, masks)
289
+
290
+
291
+ def tokenize_instruct(
292
+ sample: TrainingInstructSample,
293
+ instruct_tokenizer: InstructTokenizerBase,
294
+ ) -> TokenSample:
295
+ tokens: List[int] = instruct_tokenizer.start()
296
+ masks: List[bool] = [False]
297
+
298
+ mask_all_but_last = sample.only_last
299
+
300
+ # find first and last user message
301
+ user_messages = [
302
+ i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage)
303
+ ]
304
+ first_user_idx = user_messages[0] if user_messages else -1
305
+ last_user_idx = user_messages[-1] if user_messages else -1
306
+
307
+ for msg_idx, message in enumerate(sample.messages):
308
+ if isinstance(message, UserMessage):
309
+ curr_tokens = instruct_tokenizer.encode_user_message(
310
+ message,
311
+ available_tools=sample.available_tools,
312
+ is_last=msg_idx == last_user_idx,
313
+ is_first=msg_idx == first_user_idx,
314
+ system_prompt=sample.system_prompt,
315
+ )
316
+ curr_masks = [False] * len(curr_tokens) # only predict bot answers
317
+ elif isinstance(message, ToolMessage):
318
+ curr_tokens = instruct_tokenizer.encode_tool_message(
319
+ message, is_before_last_user_message=msg_idx < last_user_idx
320
+ )
321
+ curr_masks = [False] * len(curr_tokens) # only predict bot answers
322
+ elif isinstance(message, FinetuningAssistantMessage):
323
+ is_last_message = msg_idx == (len(sample.messages) - 1)
324
+
325
+ # we don't want to predict a random call id
326
+ message = maybe_remove_call_id(message, is_last_message=is_last_message)
327
+
328
+ curr_tokens = instruct_tokenizer.encode_assistant_message(
329
+ message, is_before_last_user_message=False
330
+ )
331
+
332
+ is_weighted = message.weight is None or message.weight == 1
333
+ is_relevant = (not mask_all_but_last) or is_last_message
334
+ if is_weighted and is_relevant:
335
+ curr_masks = [True] * len(curr_tokens) # only predict bot answers
336
+ else:
337
+ # in function calling we only backprop through last message
338
+ curr_masks = [False] * len(curr_tokens)
339
+
340
+ tokens.extend(curr_tokens)
341
+ masks.extend(curr_masks)
342
+
343
+ return TokenSample(tokens, masks)
344
+
345
+
346
+ def maybe_remove_call_id(message: FinetuningAssistantMessage, is_last_message: bool):
347
+ if message.tool_calls is None or not is_last_message:
348
+ return message
349
+
350
+ # remove call id
351
+ message.tool_calls = [
352
+ ToolCall(function=call.function) for call in message.tool_calls
353
+ ]
354
+
355
+ return message
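For illustration, a minimal record of the kind `build_instruct_sample` accepts (the content strings are made up; this assumes the `mistral_common` validator accepts a plain two-turn conversation in finetuning mode):

from finetune.data.tokenize import build_instruct_sample

record = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
}
sample = build_instruct_sample(record)
print(type(sample).__name__, len(sample.messages))  # TrainingInstructSample 2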
finetune/distributed.py ADDED
@@ -0,0 +1,59 @@
1
+ import logging
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import List, Union
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ logger = logging.getLogger("distributed")
10
+
11
+ BACKEND = "nccl"
12
+
13
+
14
+ @lru_cache()
15
+ def get_rank() -> int:
16
+ return dist.get_rank()
17
+
18
+
19
+ @lru_cache()
20
+ def get_world_size() -> int:
21
+ return dist.get_world_size()
22
+
23
+
24
+ def visible_devices() -> List[int]:
25
+ return [int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
26
+
27
+
28
+ def set_device():
29
+ logger.info(f"torch.cuda.device_count: {torch.cuda.device_count()}")
30
+ logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
31
+ logger.info(f"local rank: {int(os.environ['LOCAL_RANK'])}")
32
+
33
+ assert torch.cuda.is_available()
34
+
35
+ assert len(visible_devices()) == torch.cuda.device_count()
36
+
37
+ if torch.cuda.device_count() == 1:
38
+ # gpus-per-task set to 1
39
+ torch.cuda.set_device(0)
40
+ return
41
+
42
+ local_rank = int(os.environ["LOCAL_RANK"])
43
+ logger.info(f"Set cuda device to {local_rank}")
44
+
45
+ assert 0 <= local_rank < torch.cuda.device_count(), (
46
+ local_rank,
47
+ torch.cuda.device_count(),
48
+ )
49
+ torch.cuda.set_device(local_rank)
50
+
51
+
52
+ def avg_aggregate(metric: Union[float, int]) -> Union[float, int]:
53
+ buffer = torch.tensor([metric], dtype=torch.float32, device="cuda")
54
+ dist.all_reduce(buffer, op=dist.ReduceOp.SUM)
55
+ return buffer[0].item() / get_world_size()
56
+
57
+
58
+ def is_torchrun() -> bool:
59
+ return "TORCHELASTIC_RESTART_COUNT" in os.environ
finetune/eval.py ADDED
@@ -0,0 +1,77 @@
1
+ import logging
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch.cuda
6
+ import torch.distributed as dist
7
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
8
+
9
+ from .data.data_loader import Batch
10
+ from .distributed import get_rank, get_world_size
11
+ from .loss import compute_loss_with_mask
12
+ from .utils import TrainState
13
+
14
+ logger = logging.getLogger("eval")
15
+
16
+
17
+ def main_logger_info(message: str) -> None:
18
+ if get_rank() == 0:
19
+ logger.info(message)
20
+
21
+
22
+ def evaluate(
23
+ model: FullyShardedDataParallel,
24
+ batches: List[Batch],
25
+ state: TrainState,
26
+ ):
27
+ # Create fake samples to make FSDP happy for unbalanced data
28
+ num_samples = torch.tensor([len(batches)], device="cuda", dtype=torch.long)
29
+ all_num_samples = [torch.zeros_like(num_samples) for _ in range(get_world_size())]
30
+
31
+ torch.distributed.all_gather(all_num_samples, num_samples)
32
+
33
+ total_num_samples = int(torch.tensor(all_num_samples).sum().item())
34
+ max_num_samples = int(torch.tensor(all_num_samples).max().item())
35
+
36
+ for _ in range(max_num_samples - int(num_samples.item())):
37
+ pad_x = np.zeros_like(batches[-1].x)
38
+ pad_y = np.zeros_like(batches[-1].y)
39
+ pad_sizes = batches[-1].sizes.copy()
40
+
41
+ pad_batch = Batch(pad_x, pad_y, pad_sizes, is_pad_only=True)
42
+ batches.append(pad_batch)
43
+
44
+ # eval mode!
45
+ model.eval()
46
+
47
+ eval_loss = torch.tensor(0.0).cuda()
48
+ main_logger_info("Start eval...")
49
+ for batch in batches:
50
+ x = torch.from_numpy(batch.x).cuda()
51
+ y = torch.from_numpy(batch.y).cuda()
52
+ y_mask = (
53
+ torch.from_numpy(batch.y_mask).cuda() if batch.y_mask is not None else None
54
+ )
55
+
56
+ with torch.no_grad():
57
+ output = model(
58
+ input_ids=x,
59
+ seqlens=batch.sizes,
60
+ )
61
+
62
+ if y_mask is None or y_mask.sum() > 0:
63
+ eval_loss += compute_loss_with_mask(output, y, y_mask)
64
+
65
+ assert batch.is_pad_only or y.abs().sum() != 0, "Pad sample is used to compute loss."
66
+
67
+ # sum loss
68
+ main_logger_info("Eval finished!")
69
+
70
+ dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
71
+ eval_loss /= total_num_samples
72
+
73
+ state.this_eval_loss = eval_loss.item()
74
+ state.this_eval_perplexity = (2**eval_loss).item()
75
+
76
+ # train mode!
77
+ model.train()
finetune/loss.py ADDED
@@ -0,0 +1,16 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def compute_loss_with_mask(
8
+ logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor]
9
+ ):
10
+ if target_mask is None:
11
+ return F.cross_entropy(logits, target, reduction="mean")
12
+
13
+ mb_loss = F.cross_entropy(logits, target, reduction="none")
14
+ mb_loss = torch.sum(mb_loss * target_mask) / torch.sum(target_mask)
15
+
16
+ return mb_loss
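A quick sketch of the masked loss: only positions where the mask is True contribute to the average (toy logits and targets for illustration):

import torch

from finetune.loss import compute_loss_with_mask

logits = torch.randn(4, 10)                      # 4 positions, vocab size 10
target = torch.tensor([1, 2, 3, 4])
mask = torch.tensor([True, True, False, False])  # e.g. only assistant tokens
loss = compute_loss_with_mask(logits, target, mask)
print(loss.item())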
finetune/mixed_precision.py ADDED
@@ -0,0 +1,47 @@
1
+ from typing import Iterable
2
+
3
+ import torch
4
+
5
+
6
+ def prepare_mixed_precision(
7
+ params: Iterable[torch.nn.Parameter],
8
+ param_dtype: torch.dtype,
9
+ optim_dtype: torch.dtype,
10
+ ):
11
+ """Attaches a freshly allocated `optim_dtype` copy to every parameter that requires a gradient and casts the parameter itself to `param_dtype`."""
12
+ with torch.no_grad():
13
+ for p in params:
14
+ if p.requires_grad:
15
+ # Mixed precision: let's save a fp32 param tensor to each params that require a grad
16
+ p._mp_param = torch.empty_like(p, dtype=optim_dtype) # type: ignore
17
+ p._mp_param.copy_(p.to(optim_dtype)) # type: ignore
18
+
19
+ p.data = p.data.to(param_dtype)
20
+
21
+
22
+ def upcast_mixed_precision(
23
+ params: Iterable[torch.nn.Parameter], optim_dtype: torch.dtype
24
+ ):
25
+ """Make sure to run this function BEFORE optimizer.step() so that all weights and optimizer states are updated in fp32 in .step()"""
26
+ with torch.no_grad():
27
+ for p in params:
28
+ if p.requires_grad and p.grad is not None:
29
+ # store original tensor in p._temp
30
+ p._temp = p.data # type: ignore
31
+ # upcast data for the optimizer step
32
+ p.data = p._mp_param # type: ignore
33
+ p.grad = p.grad.to(optim_dtype)
34
+
35
+
36
+ def downcast_mixed_precision(
37
+ params: Iterable[torch.nn.Parameter], param_dtype: torch.dtype
38
+ ):
39
+ """Make sure to run this function AFTER optimizer.step() as optimizer.step() will update data underlying p.data and p._mp_param pointers"""
40
+ with torch.no_grad():
41
+ for p in params:
42
+ if p.requires_grad and p.grad is not None:
43
+ # copy fp32 weights into bfloat16 tensor
44
+ p._temp.copy_(p.data) # type: ignore
45
+ # set _temp again to the data tensor
46
+ p.data = p._temp # type: ignore
47
+ p.grad = p.grad.to(param_dtype)
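A sketch of the intended call order around an optimizer step, assuming a toy module and bfloat16/float32 dtypes (the real training loop wires this up with FSDP):

import torch

from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
prepare_mixed_precision(model.parameters(), param_dtype=torch.bfloat16, optim_dtype=torch.float32)

loss = model(torch.randn(2, 8, dtype=torch.bfloat16)).sum()
loss.backward()
upcast_mixed_precision(model.parameters(), optim_dtype=torch.float32)     # before step
opt.step()
downcast_mixed_precision(model.parameters(), param_dtype=torch.bfloat16)  # after step
opt.zero_grad()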
finetune/monitoring/__init__.py ADDED
File without changes
finetune/monitoring/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (147 Bytes). View file
 
finetune/monitoring/__pycache__/metrics_logger.cpython-310.pyc ADDED
Binary file (5.46 kB). View file
 
finetune/monitoring/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.27 kB). View file
 
finetune/monitoring/metrics_logger.py ADDED
@@ -0,0 +1,226 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from datetime import datetime, timedelta
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Optional, Union
7
+
8
+ from torch.utils.tensorboard import SummaryWriter
9
+
10
+ from finetune.args import MLFlowArgs, TrainArgs, WandbArgs
11
+ from finetune.utils import TrainState
12
+
13
+ logger = logging.getLogger("metrics_logger")
14
+
15
+ GB = 1024**3
16
+
17
+
18
+ def get_train_logs(
19
+ state: TrainState,
20
+ loss: float,
21
+ lr: float,
22
+ peak_allocated_mem: float,
23
+ allocated_mem: float,
24
+ train_args: TrainArgs,
25
+ ) -> Dict[str, Union[float, int]]:
26
+ metrics = {
27
+ "lr": lr,
28
+ "step": state.step,
29
+ "loss": loss,
30
+ "percent_done": 100 * state.step / train_args.max_steps,
31
+ "peak_allocated_mem": peak_allocated_mem / GB,
32
+ "allocated_mem": allocated_mem / GB,
33
+ "wps": state.wps,
34
+ "avg_wps": state.avg_wps,
35
+ "eta_in_seconds": state.eta,
36
+ }
37
+
38
+ return metrics
39
+
40
+
41
+ def get_eval_logs(
42
+ step: int,
43
+ train_loss: float,
44
+ perplexity: Optional[float],
45
+ eval_loss: Optional[float],
46
+ ) -> Dict[str, Union[float, int]]:
47
+ eval_dict = {"step": step, "train_loss": train_loss}
48
+
49
+ if perplexity is not None:
50
+ eval_dict["perplexity"] = perplexity
51
+
52
+ if eval_loss is not None:
53
+ eval_dict["eval_loss"] = eval_loss
54
+ return eval_dict
55
+
56
+
57
+ def train_log_msg(
58
+ state: TrainState, logs: Dict[str, Union[float, int]], loss: float
59
+ ) -> str:
60
+ metrics: Dict[str, Union[float, int, datetime]] = dict(logs) # shallow copy
61
+ metrics.pop("eta_in_seconds")
62
+
63
+ metrics["eta"] = datetime.now() + timedelta(seconds=state.eta)
64
+ metrics["step"] = state.step
65
+ metrics["loss"] = loss
66
+
67
+ parts = []
68
+ for key, fmt, new_name in [
69
+ ("step", "06", None),
70
+ ("percent_done", "03.1f", "done (%)"),
71
+ ("loss", ".3f", None),
72
+ ("lr", ".1e", None),
73
+ ("peak_allocated_mem", ".1f", "peak_alloc_mem (GB)"),
74
+ ("allocated_mem", ".1f", "alloc_mem (GB)"),
75
+ ("wps", ".1f", "words_per_second"),
76
+ ("avg_wps", ".1f", "avg_words_per_second"),
77
+ ("eta", "%Y-%m-%d %H:%M:%S", "ETA"),
78
+ ]:
79
+ name = key if new_name is None else new_name
80
+ try:
81
+ parts.append(f"{name}: {metrics[key]:>{fmt}}")
82
+ except KeyError:
83
+ logger.error(f"{key} not found in {sorted(metrics.keys())}")
84
+ raise
85
+
86
+ return " - ".join(parts)
87
+
88
+
89
+ def eval_log_msg(logs: Dict[str, Union[float, int]]) -> str:
90
+ parts = []
91
+ for key, fmt, new_name in [
92
+ ("step", "06", None),
93
+ ("perplexity", ".3f", "eval_perplexity"),
94
+ ("eval_loss", ".3f", None),
95
+ ("train_loss", ".3f", None),
96
+ ]:
97
+ name = key if new_name is None else new_name
98
+ if key in logs:
99
+ parts.append(f"{name}: {logs[key]:>{fmt}}")
100
+
101
+ return " - ".join(parts)
102
+
103
+
104
+ class MetricsLogger:
105
+ def __init__(
106
+ self,
107
+ dst_dir: Path,
108
+ tag: str,
109
+ is_master: bool,
110
+ wandb_args: WandbArgs,
111
+ mlflow_args: MLFlowArgs,
112
+ config: Optional[Dict[str, Any]] = None,
113
+ ):
114
+ self.dst_dir = dst_dir
115
+ self.tag = tag
116
+ self.is_master = is_master
117
+ self.jsonl_path = dst_dir / f"metrics.{tag}.jsonl"
118
+ self.tb_dir = dst_dir / "tb"
119
+ self.summary_writer: Optional[SummaryWriter] = None
120
+
121
+ if not self.is_master:
122
+ return
123
+
124
+ filename_suffix = f".{tag}"
125
+ self.tb_dir.mkdir(exist_ok=True)
126
+ self.summary_writer = SummaryWriter(
127
+ log_dir=str(self.tb_dir),
128
+ max_queue=1000,
129
+ filename_suffix=filename_suffix,
130
+ )
131
+ self.is_wandb = wandb_args.project is not None
132
+ self.is_mlflow = mlflow_args.tracking_uri is not None
133
+
134
+ if self.is_wandb:
135
+ import wandb
136
+
137
+ if wandb_args.key is not None:
138
+ wandb.login(key=wandb_args.key) # LLM
139
+ if wandb_args.offline:
140
+ os.environ["WANDB_MODE"] = "offline"
141
+ if wandb.run is None:
142
+ logger.info("initializing wandb")
143
+ wandb.init(
144
+ config=config,
145
+ dir=dst_dir,
146
+ project=wandb_args.project,
147
+ job_type="training",
148
+ name=wandb_args.run_name or dst_dir.name,
149
+ resume=False,
150
+ )
151
+
152
+ self.wandb_log = wandb.log
153
+
154
+ if self.is_mlflow:
155
+ import mlflow
156
+
157
+ mlflow.set_tracking_uri(mlflow_args.tracking_uri)
158
+ mlflow.set_experiment(mlflow_args.experiment_name or dst_dir.name)
159
+
160
+ if tag == "train":
161
+ mlflow.start_run()
162
+
163
+ self.mlflow_log = mlflow.log_metric
164
+
165
+ def log(self, metrics: Dict[str, Union[float, int]], step: int):
166
+ if not self.is_master:
167
+ return
168
+
169
+ metrics_to_ignore = {"step"}
170
+ assert self.summary_writer is not None
171
+ for key, value in metrics.items():
172
+ if key in metrics_to_ignore:
173
+ continue
174
+ assert isinstance(value, (int, float)), (key, value)
175
+ self.summary_writer.add_scalar(
176
+ tag=f"{self.tag}.{key}", scalar_value=value, global_step=step
177
+ )
178
+
179
+ if self.is_mlflow:
180
+ self.mlflow_log(f"{self.tag}.{key}", value, step=step)
181
+
182
+ if self.is_wandb:
183
+ # grouping in wandb is done with /
184
+ self.wandb_log(
185
+ {
186
+ f"{self.tag}/{key}": value
187
+ for key, value in metrics.items()
188
+ if key not in metrics_to_ignore
189
+ },
190
+ step=step,
191
+ )
192
+
193
+ metrics_: Dict[str, Any] = dict(metrics) # shallow copy
194
+ if "step" in metrics_:
195
+ assert step == metrics_["step"]
196
+ else:
197
+ metrics_["step"] = step
198
+ metrics_["at"] = datetime.utcnow().isoformat()
199
+ with self.jsonl_path.open("a") as fp:
200
+ fp.write(f"{json.dumps(metrics_)}\n")
201
+
202
+ def close(self):
203
+ if not self.is_master:
204
+ return
205
+
206
+ if self.summary_writer is not None:
207
+ self.summary_writer.close()
208
+ self.summary_writer = None
209
+
210
+ if self.is_wandb:
211
+ import wandb
212
+
213
+ # to be sure we are not hanging while finishing
214
+ wandb.finish()
215
+
216
+ if self.is_mlflow:
217
+ import mlflow
218
+
219
+ mlflow.end_run()
220
+
221
+ def __del__(self):
222
+ if self.summary_writer is not None:
223
+ raise RuntimeError(
224
+ "MetricsLogger not closed properly! You should "
225
+ "make sure the close() method is called!"
226
+ )
finetune/monitoring/utils.py ADDED
@@ -0,0 +1,34 @@
1
+ import datetime
2
+ import logging
3
+ import sys
4
+ import time
5
+
6
+
7
+ class DeltaTimeFormatter(logging.Formatter):
8
+ def format(self, record):
9
+ delta = datetime.timedelta(
10
+ seconds=int(record.relativeCreated / 1000)
11
+ ) # no milliseconds
12
+ record.delta = delta
13
+ return super().format(record)
14
+
15
+
16
+ def set_logger(level: int = logging.INFO):
17
+ root = logging.getLogger()
18
+ root.handlers.clear()
19
+ root.setLevel(level)
20
+ tz, *_ = time.tzname
21
+
22
+ LOGFORMAT = "%(asctime)s - %(delta)s - %(name)s - %(levelname)s - %(message)s"
23
+ TIMEFORMAT = f"%Y-%m-%d %H:%M:%S ({tz})"
24
+ formatter = DeltaTimeFormatter(LOGFORMAT, TIMEFORMAT)
25
+
26
+ handler = logging.StreamHandler(sys.stdout)
27
+ handler.setLevel(level)
28
+ handler.setFormatter(formatter)
29
+ root.addHandler(handler)
30
+
31
+ handler = logging.StreamHandler(sys.stderr)
32
+ handler.setLevel(logging.WARNING)
33
+ handler.setFormatter(formatter)
34
+ root.addHandler(handler)
finetune/utils.py ADDED
@@ -0,0 +1,83 @@
1
+ import contextlib
2
+ import dataclasses
3
+ import datetime
4
+ import logging
5
+ import time
6
+ from typing import Optional, Protocol
7
+
8
+ import torch
9
+
10
+ logger = logging.getLogger("utils")
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class TrainState:
15
+ max_steps: int
16
+ step: int = 0
17
+ elapsed_time: float = 0.0
18
+ n_seen_tokens: int = 0
19
+ this_step_time: float = 0.0
20
+ begin_step_time: float = 0.0
21
+ this_eval_perplexity: Optional[float] = None
22
+ this_eval_loss: Optional[float] = None
23
+
24
+ def start_step(self):
25
+ self.step += 1
26
+ self.begin_step_time = time.time()
27
+
28
+ def end_step(self, n_batch_tokens: int):
29
+ self.this_step_time = time.time() - self.begin_step_time
30
+ self.this_step_tokens = n_batch_tokens
31
+
32
+ self.elapsed_time += self.this_step_time
33
+ self.n_seen_tokens += self.this_step_tokens
34
+
35
+ self.begin_step_time = time.time()
36
+
37
+ @property
38
+ def wps(self):
39
+ return self.this_step_tokens / self.this_step_time
40
+
41
+ @property
42
+ def avg_wps(self):
43
+ return self.n_seen_tokens / self.elapsed_time
44
+
45
+ @property
46
+ def eta(self):
47
+ steps_left = self.max_steps - self.step
48
+ avg_time_per_step = self.elapsed_time / self.step
49
+
50
+ return steps_left * avg_time_per_step
51
+
52
+
53
+ def set_random_seed(seed: int) -> None:
54
+ """Set random seed for reproducibility."""
55
+ torch.manual_seed(seed)
56
+ torch.cuda.manual_seed(seed)
57
+
58
+
59
+ class Closable(Protocol):
60
+ def close(self):
61
+ pass
62
+
63
+
64
+ @contextlib.contextmanager
65
+ def logged_closing(thing: Closable, name: str):
66
+ """
67
+ Logging the closing to be sure something is not hanging at exit time
68
+ """
69
+ try:
70
+ setattr(thing, "wrapped_by_closing", True)
71
+ yield
72
+ finally:
73
+ logger.info(f"Closing: {name}")
74
+ try:
75
+ thing.close()
76
+ except Exception:
77
+ logger.error(f"Error while closing {name}!")
78
+ raise
79
+ logger.info(f"Closed: {name}")
80
+
81
+
82
+ def now_as_str() -> str:
83
+ return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
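A small sketch of the `TrainState` bookkeeping used by the training loop (the token count and sleep are stand-ins for a real step):

import time

from finetune.utils import TrainState

state = TrainState(max_steps=10)
state.start_step()
time.sleep(0.01)                      # stand-in for forward/backward/step
state.end_step(n_batch_tokens=4096)
print(f"{state.wps:.0f} tokens/s, ETA {state.eta:.2f}s")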
finetune/wrapped_model.py ADDED
@@ -0,0 +1,227 @@
1
+ import functools
2
+ import json
3
+ import logging
4
+ import math
5
+ from pathlib import Path
6
+ from typing import Callable, Union
7
+
8
+ import safetensors
9
+ import torch
10
+ import torch.distributed.fsdp.wrap as torch_wrap
11
+ from torch.distributed.fsdp import BackwardPrefetch
12
+ from torch.distributed.fsdp.api import ShardingStrategy
13
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
14
+
15
+ from model.args import ModelArgs, MoeArgs
16
+ from model.transformer import Transformer, TransformerBlock
17
+
18
+ from .args import LoraArgs
19
+ from .checkpointing import Checkpointer
20
+ from .distributed import (
21
+ get_rank,
22
+ get_world_size,
23
+ )
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ def main_logger_info(message: str) -> None:
29
+ if get_rank() == 0:
30
+ logger.info(message)
31
+
32
+
33
+ def get_fsdp_policy(is_lora: bool) -> Callable[[torch.nn.Module], bool]:
34
+ """
35
+ This function instantiates the FSDP wrap policy.
36
+ - Each Transformer block becomes its own FSDP group so that only a single Transformer block is sharded at a time
37
+ - If LoRA is enabled, we additionally create separate FSDP sub-groups for every trainable and non-trainable parameter group
38
+ since this is a requirement for mixed requires_grad=True/False training. See: https://pytorch.org/docs/stable/fsdp.html
39
+ """
40
+
41
+ # Each transformer block becomes a FSDP group, each being sharded seperately
42
+ transformer_block_wrap_policy = functools.partial(
43
+ torch_wrap.transformer_auto_wrap_policy,
44
+ transformer_layer_cls=(TransformerBlock,),
45
+ )
46
+
47
+ if not is_lora:
48
+ return transformer_block_wrap_policy
49
+
50
+ def fsdp_lora_policy_fn(module):
51
+ return all(p.requires_grad for p in module.parameters())
52
+
53
+ # For LoRA training, trainable and non-trainable parameters need to be put into
54
+ # different FSDP groups
55
+ fsdp_lora_policy = functools.partial(
56
+ torch_wrap.lambda_auto_wrap_policy, lambda_fn=fsdp_lora_policy_fn
57
+ )
58
+
59
+ policies = [fsdp_lora_policy, transformer_block_wrap_policy]
60
+
61
+ return functools.partial(torch_wrap._or_policy, policies=policies)
62
+
63
+
+def log_train_params(model: Union[torch.nn.Module, FullyShardedDataParallel]):
+    world_size = get_world_size()
+
+    num_params = world_size * sum(p.numel() for p in model.parameters())
+    num_train_params = world_size * sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+
+    main_logger_info(
+        f"{num_train_params:,.0f} out of {num_params:,.0f} parameters are finetuned ({num_train_params / num_params * 100:.2f}%)."
+    )
+
+
+def initialize_lora_parameters(model: torch.nn.Module, param_dtype: torch.dtype):
+    """
+    Initialize LoRA layers with Kaiming uniform and zeros.
+    See original paper for more info: https://arxiv.org/abs/2106.09685 and
+    original github repo: https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L122
+    """
+    for m_name, module in model.named_modules():
+        if all(p.is_meta for p in module.parameters()):
+            for p_name, param in module.named_parameters():
+                module._parameters[p_name] = torch.nn.Parameter(
+                    torch.empty_like(param, device="cpu", dtype=param_dtype)
+                )
+                param = module._parameters[p_name]
+
+                if m_name.split(".")[-1] == "lora_A":
+                    torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
+                elif m_name.split(".")[-1] == "lora_B":
+                    torch.nn.init.zeros_(param)
+                else:
+                    raise ValueError(
+                        "Only LoRA layers should be randomly initialized."
+                    )
+
+
+def load_model(
+    folder: Path,
+    lora: LoraArgs,
+    checkpoint: bool,
+    param_dtype: torch.dtype,
+) -> FullyShardedDataParallel:
+    with open(folder / "params.json", "r") as f:
+        args = json.loads(f.read())
+
+    model_args = ModelArgs(
+        lora=lora,
+        dim=args["dim"],
+        n_layers=args["n_layers"],
+        head_dim=args["head_dim"],
+        hidden_dim=args["hidden_dim"],
+        n_heads=args["n_heads"],
+        n_kv_heads=args["n_kv_heads"],
+        norm_eps=args["norm_eps"],
+        vocab_size=args["vocab_size"],
+    )
+
+    if model_args.vocab_size == 32000:
+        raise ValueError(
+            f"Fine-tuning is not supported for older model versions with vocab_size 32000. Make sure to extend your model to vocab_size=32768 using `python -m utils.extend_model_vocab --original_model_ckpt {folder} --extended_model_ckpt {folder}_extended`."
+        )
+
+    assert (
+        model_args.vocab_size >= 32768
+    ), "Make sure to use a model with a vocab size of at least 32768"
+
+    if args.get("rope_theta") is not None:
+        model_args.rope_theta = args["rope_theta"]
+
+    if args.get("moe") is not None:
+        model_args.moe = MoeArgs(**args["moe"])
+
+    with torch.device("meta"):
+        model = Transformer(args=model_args, checkpoint=checkpoint)
+
+    if get_rank() == 0:
+        state_dict = load_state_dict(folder, dtype=param_dtype)
+
+        model.load_state_dict(state_dict, assign=True)  # type: ignore
+        logger.info("Loaded model on cpu!")
+
+        if lora.enable:
+            logger.info("Initializing lora layers ...")
+            # initialize LoRA layers
+            initialize_lora_parameters(model, param_dtype)
+
+        assert not any(
+            p.is_meta for p in model.parameters()
+        ), "All parameters should be initialized by now"
+        assert all(
+            p.dtype == param_dtype for p in model.parameters()
+        ), f"All parameters should have dtype {param_dtype}"
+
+        logger.info("Finished initialization!")
+        param_init_fn = None
+    else:
+
+        def param_init_fn(m):
+            m.to_empty(device=torch.cuda.current_device(), recurse=False)
+            m.to(param_dtype)
+
+        assert all(
+            p.is_meta for p in model.parameters()
+        ), "All parameters should be on meta"
+
+    torch.distributed.barrier()
+
+    # only finetune LoRA parameters and freeze before wrapping
+    if lora.enable:
+        for name, param in model.named_parameters():
+            if "lora" in name:
+                param.requires_grad = True
+            else:
+                param.requires_grad = False
+
+    auto_wrap_policy = get_fsdp_policy(model_args.lora.enable)
+
+    main_logger_info(f"Sharding model over {get_world_size()} GPUs ...")
+
+    wrapped_model = FullyShardedDataParallel(
+        model,
+        sharding_strategy=ShardingStrategy.FULL_SHARD,
+        auto_wrap_policy=auto_wrap_policy,
+        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+        limit_all_gathers=True,
+        device_id=torch.cuda.current_device(),
+        sync_module_states=True,
+        param_init_fn=param_init_fn,
+    )
+    main_logger_info("Model sharded!")
+
+    log_train_params(wrapped_model)
+
+    return wrapped_model
+
+
+@torch.no_grad()
+def load_state_dict(path: Path, dtype: torch.dtype):
+    assert path.is_dir(), path
+
+    this_safetensors_path = Checkpointer.consolidated_path(path, use_safetensors=True)
+    this_torch_path = Checkpointer.consolidated_path(path, use_safetensors=False)
+
+    assert (
+        this_safetensors_path.exists() or this_torch_path.exists()
+    ), f"Either {this_safetensors_path} or {this_torch_path} must exist."
+    assert not (
+        this_safetensors_path.exists() and this_torch_path.exists()
+    ), f"Only one of {this_safetensors_path} or {this_torch_path} should exist."
+
+    if this_safetensors_path.exists():
+        logger.info(f"Reloading model from {this_safetensors_path} ...")
+        model_state_dict = safetensors.torch.load_file(this_safetensors_path)
+    else:
+        logger.info(f"Reloading model from {this_torch_path} ...")
+        model_state_dict = torch.load(this_torch_path)
+
+    logger.info(f"Converting model to dtype {dtype} ...")
+
+    for k, v in model_state_dict.items():
+        model_state_dict[k] = v.to(dtype)
+
+    return model_state_dict
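As a rough usage sketch (not part of this commit), `load_model` is meant to be called after the distributed process group and CUDA device are set up, e.g. under `torchrun`. The checkpoint path below is a placeholder and the `LoraArgs` field defaults are assumed:

    # Hypothetical call site for load_model (paths and LoRA settings are placeholders).
    from pathlib import Path

    import torch
    import torch.distributed as dist

    from finetune.args import LoraArgs
    from finetune.wrapped_model import load_model

    def build_model() -> torch.nn.Module:
        dist.init_process_group(backend="nccl")  # torchrun provides rank/world-size env vars
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

        return load_model(
            folder=Path("/path/to/mistral_model"),  # must contain params.json and consolidated weights
            lora=LoraArgs(enable=True),             # other LoRA fields (e.g. rank) left at assumed defaults
            checkpoint=True,                        # activation checkpointing inside Transformer
            param_dtype=torch.bfloat16,
        )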
huggingface.ipynb ADDED
@@ -0,0 +1,41 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import shutil\n",
+    "from huggingface_hub import HfApi, HfFolder, Repository\n",
+    "\n",
+    "repo_id = \"your_username/your_model_name\"\n",
+    "repo_local_path = \"./path_to_your_model\"\n",
+    "\n",
+    "# Create the repository object and clone the repo\n",
+    "repo = Repository(local_dir=repo_local_path, clone_from=repo_id)\n",
+    "\n",
+    "# Copy your model files to the repository\n",
+    "model_files = [\"config.json\", \"pytorch_model.bin\", \"tokenizer_config.json\", \"vocab.json\"]\n",
+    "for file in model_files:\n",
+    "    shutil.copy(file, repo_local_path)\n",
+    "\n",
+    "# Push the model files to the repository\n",
+    "repo.push_to_hub(commit_message=\"Initial model upload\")\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "chemistralpy310",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
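Note that `Repository` performs a full git clone and requires git and git-lfs to be installed locally. If only an upload is needed, `HfApi.upload_folder` from the same library is a lighter path. A sketch with placeholder names, assuming the token was stored via `huggingface-cli login`:

    # Alternative upload sketch using HfApi (no local git clone required).
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_folder(
        folder_path="./path_to_your_model",       # local directory containing the model files
        repo_id="your_username/your_model_name",  # target repo on the Hub
        commit_message="Initial model upload",
    )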