Text Generation
Transformers
PyTorch
mpt
Composer
MosaicML
llm-foundry
custom_code
text-generation-inference

LLM-foundry update March 26, 2024 23:50:31

#73
act_ckpt.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import torch
3
+ from .attention import ATTN_CLASS_REGISTRY
4
+ from .blocks import MPTBlock
5
+ from .ffn import FFN_CLASS_REGISTRY
6
+ from .norm import NORM_CLASS_REGISTRY
7
+
8
+ def pass_on_block_idx(parent: torch.nn.Module):
9
+ if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
10
+ return
11
+ for child in parent.children():
12
+ child.block_idx = parent.block_idx
13
+ child.max_block_idx = parent.max_block_idx
14
+ if child.children():
15
+ pass_on_block_idx(child)
16
+
17
+ def get_act_ckpt_module(mod_name: str) -> Any:
18
+ """Get the module type from the module name."""
19
+ if mod_name.lower() == 'mptblock':
20
+ mod_type = MPTBlock
21
+ elif mod_name in ATTN_CLASS_REGISTRY:
22
+ mod_type = ATTN_CLASS_REGISTRY[mod_name]
23
+ elif mod_name in FFN_CLASS_REGISTRY:
24
+ mod_type = FFN_CLASS_REGISTRY[mod_name]
25
+ elif mod_name in NORM_CLASS_REGISTRY:
26
+ mod_type = NORM_CLASS_REGISTRY[mod_name]
27
+ else:
28
+ msg = ', '.join(list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
29
+ raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
30
+ return mod_type
31
+
32
+ def parse_ele_str(ele: str, max_block_idx: int) -> list:
33
+ """Parse a string in target_blocks and return a list of block ids to add.
34
+
35
+ Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
36
+ to the first n, the middle m, the last k, and the range [i, j).
37
+ """
38
+ to_add = None
39
+ if ele.startswith('first-'):
40
+ assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
41
+ to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
42
+ elif ele.startswith('last-'):
43
+ assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
44
+ to_add = list(range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
45
+ elif ele.startswith('middle-'):
46
+ assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
47
+ num = int(ele[7:])
48
+ start = max(max_block_idx // 2 - num // 2, 0)
49
+ end = min(start + num, max_block_idx + 1)
50
+ to_add = list(range(start, end))
51
+ elif ele.startswith('range-'):
52
+ r = ele[6:].split('-')
53
+ assert len(r) == 2, f'Invalid target_blocks element {ele}'
54
+ start, end = (int(r[0]), int(r[1]))
55
+ start = max(start, 0)
56
+ end = min(end, max_block_idx + 1)
57
+ to_add = list(range(start, end))
58
+ else:
59
+ raise ValueError(f'Invalid target_blocks element {ele}')
60
+ return to_add
61
+
62
+ def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
63
+ """Parse the user input and return a list of block ids."""
64
+ candidate_block_ids = []
65
+ if isinstance(target_blocks, int):
66
+ candidate_block_ids = list(range(target_blocks))
67
+ elif isinstance(target_blocks, list):
68
+ for ele in target_blocks:
69
+ if isinstance(ele, int):
70
+ candidate_block_ids.append(ele)
71
+ elif isinstance(ele, str):
72
+ to_add = parse_ele_str(ele, max_block_idx)
73
+ candidate_block_ids.extend(to_add)
74
+ else:
75
+ raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
76
+ elif isinstance(target_blocks, str):
77
+ target_blocks = target_blocks.replace(' ', '')
78
+ for ele in target_blocks.split(','):
79
+ to_add = parse_ele_str(ele, max_block_idx)
80
+ candidate_block_ids.extend(to_add)
81
+ else:
82
+ raise ValueError(f'target_blocks must be either a single intege, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}')
83
+ candidate_block_ids = list(set(candidate_block_ids))
84
+ return candidate_block_ids
85
+
86
+ def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
87
+ """Check if the block ids in the mapping overlap with each other."""
88
+ all_blocks = [None] * (max_block_idx + 1)
89
+ for k, v in mapping.items():
90
+ if v == -1:
91
+ v = list(range(max_block_idx + 1))
92
+ for vv in v:
93
+ if vv < 0 or vv > max_block_idx:
94
+ continue
95
+ elif all_blocks[vv] is not None:
96
+ raise ValueError(f'Block {vv} is assigned to both {k} and {all_blocks[vv]}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')
97
+ else:
98
+ all_blocks[vv] = k
99
+
100
+ def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
101
+ act_ckpt_mod_to_blocks = {}
102
+ if act_ckpt_target is None or act_ckpt_target == []:
103
+ mod = top_module
104
+ act_ckpt_mod_to_blocks[mod] = -1
105
+ elif isinstance(act_ckpt_target, str):
106
+ mod = get_act_ckpt_module(act_ckpt_target)
107
+ act_ckpt_mod_to_blocks[mod] = -1
108
+ elif isinstance(act_ckpt_target, list):
109
+ for target in act_ckpt_target:
110
+ mod = get_act_ckpt_module(target)
111
+ act_ckpt_mod_to_blocks[mod] = -1
112
+ elif isinstance(act_ckpt_target, dict):
113
+ for k, v in act_ckpt_target.items():
114
+ mod = get_act_ckpt_module(k)
115
+ block_ids = get_target_block_list(v, max_block_idx)
116
+ act_ckpt_mod_to_blocks[mod] = block_ids
117
+ else:
118
+ raise ValueError(f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}')
119
+ return act_ckpt_mod_to_blocks
async_eval_callback.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run the eval loop asynchronously as part of a MosaicML platform run.
2
+
3
+ This callback is currently experimental. The API may change in the future.
4
+ """
5
+ import logging
6
+ import os
7
+ import warnings
8
+ from collections import Counter
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Optional, Tuple, Union
11
+ from .interfaces import CallbackWithConfig
12
+ from mcli import Run, RunConfig, create_run, get_run
13
+ log = logging.getLogger(__name__)
14
+ REQUIRED_PARAMS_FOR_EVAL = {'device_eval_batch_size', 'icl_tasks', 'max_seq_len', 'model', 'tokenizer'}
15
+ OPTIONAL_PARAMS_FOR_EVAL = {'dist_timeout', 'eval_gauntlet', 'eval_loader', 'fsdp_config', 'eval_subset_num_batches', 'icl_subset_num_batches', 'loggers', 'precision', 'python_log_level', 'seed'}
16
+ RUN_NAME_PREFIX = 'eval'
17
+ MAX_RUN_NAME_BASE_LENGTH = 55
18
+
19
+ def get_run_name(training_run_name: str, current_interval: str) -> str:
20
+ """Get the new eval run name.
21
+
22
+ Args:
23
+ training_run_name: The name of the current training run
24
+ current_interval: The current interval string of the training run
25
+
26
+ Returns:
27
+ The new run name
28
+ """
29
+ name_without_uuid_suffix = training_run_name.rsplit('-', 1)[0]
30
+ max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len(current_interval) - 2
31
+ if len(name_without_uuid_suffix) > max_length:
32
+ new_name = name_without_uuid_suffix[:max_length]
33
+ log.warning(f'Training run name {name_without_uuid_suffix} may be too long,' + f' truncating to {new_name}')
34
+ name_without_uuid_suffix = new_name
35
+ return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'
36
+
37
+ def get_eval_parameters(parameters: Dict[str, Any], checkpoint: str, training_run_name: str) -> Dict[str, Any]:
38
+ """Get the parameters needed for the eval run.
39
+
40
+ Args:
41
+ parameters: The parameters from the training run
42
+ checkpoint: The path to the latest checkpoint
43
+ training_run_name: The name of the training run
44
+
45
+ Returns:
46
+ The parameters needed for the eval run as a dict
47
+ """
48
+ looking_for = REQUIRED_PARAMS_FOR_EVAL.copy()
49
+ subset_keys = {}
50
+ for key in parameters:
51
+ if key in OPTIONAL_PARAMS_FOR_EVAL:
52
+ subset_keys[key] = parameters[key]
53
+ elif key in REQUIRED_PARAMS_FOR_EVAL:
54
+ subset_keys[key] = parameters[key]
55
+ looking_for.remove(key)
56
+ if looking_for:
57
+ raise Exception(f'Missing the following required parameters for async eval: {looking_for}')
58
+ for logger, config in subset_keys.get('loggers', {}).items():
59
+ if logger == 'wandb':
60
+ config['group'] = config.pop('name', training_run_name)
61
+ model = subset_keys.pop('model')
62
+ model_name = model.get('name', None)
63
+ if not model_name:
64
+ raise Exception(f'Async evaluation requires "name" keys for models')
65
+ new_models = {'model_name': model_name, 'model': model, 'load_path': checkpoint}
66
+ tokenizer = subset_keys.pop('tokenizer', None)
67
+ if tokenizer is not None:
68
+ new_models['tokenizer'] = tokenizer
69
+ subset_keys['models'] = [new_models]
70
+ return subset_keys
71
+
72
+ def validate_interval(interval: Union[str, int, Time], save_interval: Union[str, int, Time]) -> Time:
73
+ new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH)
74
+ async_interval = Time.from_input(interval, TimeUnit.EPOCH)
75
+ if new_save_interval.unit != async_interval.unit:
76
+ raise ValueError('Save interval and async eval interval must be in the same unit')
77
+ if async_interval < new_save_interval:
78
+ raise ValueError('Async eval interval must be equal or greater (less frequent) than save interval')
79
+ if async_interval.value % new_save_interval.value != 0:
80
+ raise ValueError('Async eval interval must be a multiple of save interval')
81
+ return async_interval
82
+
83
+ def validate_eval_run_config(eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
84
+ if not eval_run_config:
85
+ return {}
86
+ run_config = eval_run_config.copy()
87
+ supported_keys = {'image', 'command', 'compute', 'scheduling'}
88
+ found_unsupported = set()
89
+ for key in run_config:
90
+ if key not in supported_keys:
91
+ found_unsupported.add(key)
92
+ if found_unsupported:
93
+ raise ValueError(f"Unsupported eval run config keys found: {', '.join(found_unsupported)}" + f'. Supported keys: {supported_keys}')
94
+ return run_config
95
+ CHECKS_PER_INTERVAL = 4
96
+
97
+ class AsyncEval(CallbackWithConfig):
98
+ """Run the eval loop asynchronously as part of a MosaicML platform run.
99
+
100
+ This callback is currently experimental. The API may change in the future.
101
+
102
+ Args:
103
+ training_params: Dict[str, Any]: The parameter config from the training run
104
+ interval: Union[str, int, Time]: The interval describing how often eval runs should be
105
+ launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
106
+ Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
107
+ :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
108
+ eval_run_config: Optional[Dict[str, Any]]: A subset of mcli run config values to use
109
+ for the eval run. If not specified, any fields from run config will be created
110
+ dynamically from the training run config and the interval. The following fields
111
+ are supported:
112
+ - ``image``: Image of the eval run. Default: same as training run
113
+ - ``command``: Command to run for the eval run. Default: calls
114
+ `composer scripts/eval/eval.py $PARAMETERS`. If custom setup is needed,
115
+ the command should include calling the eval script with $PARAMETERS
116
+ - ``compute``: Compute to use for the eval run. Default: same cluster as
117
+ the training run and a single node (8 GPUs)
118
+ - ``scheduling``: Scheduling to use for the eval run. Default: same as training run
119
+
120
+ All fields are optional, but if specified, must be valid for a mcli run config. We
121
+ provide this optional config to give you the most flexibility in customizing the eval
122
+ run, but it is recommended to use the default values unless you have a specific use case
123
+ """
124
+
125
+ def __init__(self, training_params: Dict[str, Any], interval: Union[str, int, Time], eval_run_config: Optional[Dict[str, Any]]=None):
126
+ for required in ('save_interval', 'save_folder'):
127
+ if required not in training_params:
128
+ raise ValueError(f'{required} required for async eval')
129
+ if '/' in training_params.get('save_filename', ''):
130
+ raise ValueError('AsyncEval not supported for save_filename that includes a path')
131
+ self.checkpoint_save_folder = training_params['save_folder']
132
+ self.training_params = training_params
133
+ self.eval_run_config = validate_eval_run_config(eval_run_config)
134
+ self.current_run = self._get_current_run()
135
+ get_eval_parameters(parameters=training_params, checkpoint='test', training_run_name=self.current_run.name)
136
+ self.interval = validate_interval(interval, self.training_params['save_interval'])
137
+ check_interval_value = max(self.interval.value // CHECKS_PER_INTERVAL, 1)
138
+ self.check_interval = Time(check_interval_value, self.interval.unit)
139
+ self.checkpoints_evaled: Dict[Time, Tuple[str, str]] = {}
140
+ self.is_at_check_interval = create_interval_scheduler(self.check_interval, include_end_of_training=False)
141
+ log.info('Initialized AsyncEval callback. Will generate runs at ' + f'interval {interval}, checking at {self.check_interval}')
142
+
143
+ def state_dict(self) -> Dict[str, Any]:
144
+ checkpoints_evaled = []
145
+ for eval_ts, (checkpoint, run_name) in self.checkpoints_evaled.items():
146
+ eval_ts_dict = {'value': eval_ts.value, 'unit': eval_ts.unit.value}
147
+ checkpoints_evaled.append((eval_ts_dict, checkpoint, run_name))
148
+ return {'checkpoints_evaled': checkpoints_evaled}
149
+
150
+ def load_state_dict(self, state_dict: Dict[str, Any]):
151
+ previous_checkpoints_evaled = state_dict.get('checkpoints_evaled', [])
152
+ if previous_checkpoints_evaled:
153
+ for eval_ts, checkpoint, run_name in previous_checkpoints_evaled:
154
+ eval_ts = Time(eval_ts['value'], TimeUnit(eval_ts['unit']))
155
+ self.checkpoints_evaled[eval_ts] = (checkpoint, run_name)
156
+ log.info(f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}')
157
+
158
+ @staticmethod
159
+ def _get_ready_sharded_checkpoints(checkpointer_checkpoints: Dict[str, Timestamp], remote_files: List[str]) -> Dict[str, Timestamp]:
160
+ """Identify checkpoints ready to be evaled based on remote files.
161
+
162
+ This has special logic for sharded checkpoints to consider checkpoints composed
163
+ of multiple shards (one per gpu) and metadata
164
+
165
+ Args:
166
+ checkpointer_checkpoints: All checkpoints from the checkpointer state
167
+ remote_files: List of remote files in the save folder
168
+
169
+ Returns:
170
+ Dict of checkpoints that are complete and ready to be evaled
171
+ """
172
+ remote_file_group_counts = Counter()
173
+ for f in remote_files:
174
+ checkpoint_ts_path = Path(f).parts[-2]
175
+ remote_file_group_counts[checkpoint_ts_path] += 1
176
+ checkpoints_to_eval = {}
177
+ for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
178
+ checkpoint_ts_path = Path(checkpoint).parts[-2]
179
+ expected_shard_count = dist.get_world_size() + 1
180
+ if remote_file_group_counts[checkpoint_ts_path] != expected_shard_count:
181
+ log.debug(f'Checkpoint {checkpoint} not fully uploaded (missing shards ' + f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping')
182
+ continue
183
+ checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
184
+ return checkpoints_to_eval
185
+
186
+ @staticmethod
187
+ def _get_ready_single_checkpoints(checkpointer_checkpoints: Dict[str, Timestamp], remote_checkpoints: List[str]) -> Dict[str, Timestamp]:
188
+ """Identify checkpoints ready to be evaled based on remote checkpoints.
189
+
190
+ This is much simpler than the sharded case, because there is only one file
191
+
192
+ Args:
193
+ checkpointer_checkpoints: All checkpoints from the checkpointer state
194
+ remote_checkpoints: List of remote checkpoints in the save folder
195
+
196
+ Returns:
197
+ Dict of checkpoints that are complete and ready to be evaled
198
+ """
199
+ unique_remote_checkpoints = set(remote_checkpoints)
200
+ checkpoints_to_eval = {}
201
+ for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
202
+ checkpoint_ts_path = Path(checkpoint).parts[-1]
203
+ if checkpoint not in unique_remote_checkpoints:
204
+ log.debug(f'Checkpoint {checkpoint} not fully uploaded, skipping')
205
+ continue
206
+ checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
207
+ return checkpoints_to_eval
208
+
209
+ def _get_checkpoints_and_launch_runs(self, state: State):
210
+ """Get the latest checkpoint from the training run.
211
+
212
+ Args:
213
+ state: The current state of the training run
214
+
215
+ Returns:
216
+ Returns checkpoints that have not been evaled
217
+ """
218
+ checkpointer = None
219
+ for callback in state.callbacks:
220
+ if isinstance(callback, CheckpointSaver):
221
+ if checkpointer is None:
222
+ checkpointer = callback
223
+ else:
224
+ log.warning('Multiple checkpoint savers found. Using the first one')
225
+ if not checkpointer:
226
+ warnings.warn('No checkpoint saver callback found. Skipping eval')
227
+ return
228
+ if not checkpointer.all_saved_checkpoints_to_timestamp:
229
+ log.debug('No saved checkpoints found on the checkpointer. Skipping eval')
230
+ return
231
+ log.debug(f'Found {len(checkpointer.all_saved_checkpoints_to_timestamp)} ' + f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}')
232
+ remote_checkpoints = list_remote_objects(self.checkpoint_save_folder)
233
+ if not remote_checkpoints:
234
+ log.debug('No saved checkpoints found yet on remote. Skipping eval')
235
+ return
236
+ if state.fsdp_sharded_state_dict_enabled:
237
+ checkpoints_to_eval = self._get_ready_sharded_checkpoints(checkpointer.all_saved_checkpoints_to_timestamp, remote_checkpoints)
238
+ else:
239
+ checkpoints_to_eval = self._get_ready_single_checkpoints(checkpointer.all_saved_checkpoints_to_timestamp, remote_checkpoints)
240
+ for checkpoint_interval_path, checkpoint_timestamp in checkpoints_to_eval.items():
241
+ checkpoint_ts = checkpoint_timestamp.get(self.interval.unit)
242
+ if checkpoint_ts.value % self.interval.value != 0:
243
+ log.debug(f'Checkpoint {checkpoint_interval_path} ({checkpoint_ts}) is ' + f'not at an eval interval ({self.interval}), skipping')
244
+ continue
245
+ if checkpoint_ts in self.checkpoints_evaled:
246
+ continue
247
+ full_checkpoint_path = f'{self.checkpoint_save_folder}/{checkpoint_interval_path}'
248
+ eval_run = self.launch_run(full_checkpoint_path, checkpoint_ts)
249
+ self.checkpoints_evaled[checkpoint_ts] = (full_checkpoint_path, eval_run.name)
250
+
251
+ def run_event(self, event: Event, state: State, logger: Logger) -> None:
252
+ del logger
253
+ should_launch_run = all([state.get_elapsed_duration() is not None, self.is_at_check_interval(state, event), dist.get_global_rank() == 0])
254
+ if should_launch_run:
255
+ self._get_checkpoints_and_launch_runs(state)
256
+
257
+ def close(self, state: State, logger: Logger) -> None:
258
+ del logger
259
+ if dist.get_global_rank() != 0:
260
+ return
261
+ self._get_checkpoints_and_launch_runs(state)
262
+ latest_timestamp = state.timestamp.get(self.interval.unit)
263
+ if latest_timestamp not in self.checkpoints_evaled:
264
+ save_latest_filename = self.training_params.get('save_latest_filename', None)
265
+ if not save_latest_filename:
266
+ rank = dist.get_global_rank()
267
+ save_latest_filename = f'latest-rank{rank}.pt'
268
+ checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
269
+ eval_run = self.launch_run(checkpoint, latest_timestamp)
270
+ self.checkpoints_evaled[latest_timestamp] = (checkpoint, eval_run.name)
271
+ log.info(f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:')
272
+ for checkpoint_ts, (checkpoint, run_name) in self.checkpoints_evaled.items():
273
+ log.info(f' {checkpoint_ts}: {checkpoint}, {run_name}')
274
+
275
+ def _get_current_run(self) -> Run:
276
+ if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'false':
277
+ raise RuntimeError('AsyncEval callback is only supported when running on the MosaicML platform')
278
+ run_name = os.environ.get(RUN_NAME_ENV_VAR, None)
279
+ if not run_name:
280
+ raise RuntimeError('RUN_NAME environment variable must be set to use the AsyncEval callback')
281
+ return get_run(run_name, include_details=True)
282
+
283
+ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
284
+ """Launch a new eval run.
285
+
286
+ Args:
287
+ checkpoint: The checkpoint to eval
288
+ current_interval: The interval of the checkpoint
289
+
290
+ Returns:
291
+ The launched run (mcli.Run type)
292
+ """
293
+ log.info(f'Launching eval run for {checkpoint} at {current_interval}')
294
+ cfg = self.current_run.submitted_config
295
+ default_compute = {'gpus': 8, 'cluster': self.current_run.cluster}
296
+ run_name = get_run_name(self.current_run.name, str(current_interval))
297
+ params = get_eval_parameters(parameters=self.training_params, checkpoint=checkpoint, training_run_name=self.current_run.name)
298
+ params['run_name'] = run_name
299
+ integrations = cfg.integrations
300
+ found_llm_foundry, installation_path = (False, 'llm-foundry')
301
+ for i in integrations:
302
+ if i['integration_type'] != 'git_repo':
303
+ continue
304
+ if not i['git_repo'].endswith('llm-foundry'):
305
+ continue
306
+ found_llm_foundry = True
307
+ if i.get('path'):
308
+ installation_path = i['path']
309
+ if not found_llm_foundry:
310
+ from .llmfoundry import __version__ as latest_foundry_version
311
+ version = f'v{latest_foundry_version}'
312
+ log.warning('No github integration found for llm-foundry. Adding installation ' + f'to eval run for latest foundry release ({version}). ' + 'To use a fork, custom branch, or custom version, configure ' + 'llm-foundry installation through a github integration')
313
+ integrations.append({'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', 'git_branch': version, 'pip_install': '-e .[gpu]', 'ssh_clone': False})
314
+ metadata = cfg.metadata
315
+ metadata['eval_timestamp'] = current_interval.value
316
+ metadata['eval_timestamp_unit'] = current_interval.unit.value
317
+ default_command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS'
318
+ run_config = RunConfig(name=run_name, image=self.eval_run_config.get('image', self.current_run.image), command=self.eval_run_config.get('command', default_command), compute=self.eval_run_config.get('compute', default_compute), scheduling=self.eval_run_config.get('scheduling', self.current_run.submitted_config.scheduling), integrations=integrations, env_variables=cfg.env_variables, metadata=cfg.metadata, parameters=params)
319
+ log.info(f'Creating new run with config: \n{run_config}')
320
+ new_run = create_run(run_config)
321
+ log.info(f'Launched new run {new_run.name} inside eval loop')
322
+ return new_run
attention.py CHANGED
@@ -31,9 +31,6 @@ def is_transformers_version_gte(hf_version: str) -> bool:
31
 
32
  def check_alibi_support(attention_impl: str) -> bool:
33
  return attention_impl != 'flash' or is_flash_v2_installed(v2_version='v2.4.2')
34
- if is_flash_v1_installed():
35
- import transformers
36
- transformers.utils.is_flash_attn_available = lambda : False
37
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
38
 
39
  def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
@@ -53,7 +50,7 @@ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
53
  """
54
  if n_rep == 1:
55
  return hidden
56
- (b, s, kv_n_heads, d) = hidden.shape
57
  hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
58
  return hidden.reshape(b, s, kv_n_heads * n_rep, d)
59
 
@@ -66,7 +63,7 @@ def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tenso
66
  k = torch.cat([past_key_value[0], k], dim=3)
67
  v = torch.cat([past_key_value[1], v], dim=2)
68
  past_key_value = (k, v)
69
- (b, _, s_q, d) = q.shape
70
  s_k = k.size(-1)
71
  if kv_n_heads > 1 and kv_n_heads < n_heads:
72
  k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
@@ -130,7 +127,7 @@ def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n
130
  past_key_value = (key, value)
131
  if attn_bias is not None:
132
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
133
- (batch_size, seqlen) = query.shape[:2]
134
  indices_q = flash_attn_padding_info['indices_q']
135
  indices_k = flash_attn_padding_info['indices_k']
136
  indices_v = flash_attn_padding_info['indices_v']
@@ -169,65 +166,17 @@ def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n
169
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
170
  return (output, None, past_key_value)
171
 
172
- def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
173
- try:
174
- from .flash_attn_triton import flash_attn_func
175
- except:
176
- _installed = False
177
- if version.parse(torch.__version__) < version.parse('2.0.0'):
178
- _installed = True
179
- try:
180
- from flash_attn.flash_attn_triton import flash_attn_func
181
- except:
182
- _installed = False
183
- if not _installed:
184
- raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU ' + 'and `pip install .[gpu]` if installing from llm-foundry source or ' + '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` ' + 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). ' + 'Note: (1) requires you have CMake and PyTorch already installed.')
185
- check_valid_inputs(query, key, value)
186
- if past_key_value is not None:
187
- if len(past_key_value) != 0:
188
- key = torch.cat([past_key_value[0], key], dim=1)
189
- value = torch.cat([past_key_value[1], value], dim=1)
190
- past_key_value = (key, value)
191
- if attn_bias is not None:
192
- _s_q = max(0, attn_bias.size(2) - query.size(1))
193
- _s_k = max(0, attn_bias.size(3) - key.size(1))
194
- attn_bias = attn_bias[:, :, _s_q:, _s_k:]
195
- if dropout_p:
196
- raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
197
- dropout_p = dropout_p if training else 0.0
198
- if needs_weights:
199
- raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
200
- if key_padding_mask is not None:
201
- warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
202
- (b_size, s_k) = key_padding_mask.shape[:2]
203
- if attn_bias is None:
204
- attn_bias = query.new_zeros(b_size, 1, 1, s_k)
205
- attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
206
- query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
207
- key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
208
- value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
209
- if kv_n_heads == 1:
210
- key = key.repeat(1, 1, n_heads, 1)
211
- value = value.repeat(1, 1, n_heads, 1)
212
- elif kv_n_heads < n_heads:
213
- key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
214
- value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
215
- reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
216
- attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
217
- output = attn_output.view(*attn_output.shape[:2], -1)
218
- return (output, None, past_key_value)
219
-
220
  class GroupedQueryAttention(nn.Module):
221
  """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
222
 
223
  and Multi-query attention (MQA).
224
 
225
  This allows the user to set a variable of number of kv_n_heads, rather than
226
- just n_heads or 1, as in MHA and MQA. Using torch or triton attention
227
  implementation enables user to also use additive bias.
228
  """
229
 
230
- def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
231
  super().__init__()
232
  self.attn_impl = attn_impl
233
  self.clip_qkv = clip_qkv
@@ -251,8 +200,7 @@ class GroupedQueryAttention(nn.Module):
251
  self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
252
  self.attn_dropout_p = attn_pdrop
253
  fc_kwargs: dict[str, Any] = {'bias': bias}
254
- if fc_type != 'te':
255
- fc_kwargs['device'] = device
256
  self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
257
  fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
258
  self.Wqkv._fused = (0, fuse_splits)
@@ -265,8 +213,6 @@ class GroupedQueryAttention(nn.Module):
265
  self.k_ln = norm_class(norm_size, device=device)
266
  if self.attn_impl == 'flash':
267
  self.attn_fn = flash_attn_fn
268
- elif self.attn_impl == 'triton':
269
- self.attn_fn = triton_flash_attn_fn
270
  elif self.attn_impl == 'torch':
271
  self.attn_fn = scaled_multihead_dot_product_attention
272
  else:
@@ -278,12 +224,12 @@ class GroupedQueryAttention(nn.Module):
278
  qkv = self.Wqkv(x)
279
  if self.clip_qkv:
280
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
281
- (query, key, value) = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
282
  key_padding_mask = attention_mask
283
  if self.qk_ln or self.qk_gn:
284
- (q_shape, k_shape) = (query.shape, key.shape)
285
  if self.qk_gn:
286
- (b, s) = query.shape[:2]
287
  query = query.view(b, s, self.n_heads, -1)
288
  key = key.view(b, s, self.kv_n_heads, -1)
289
  dtype = query.dtype
@@ -293,23 +239,28 @@ class GroupedQueryAttention(nn.Module):
293
  rotary_emb = rotary_emb_w_meta_info['rotary_emb']
294
  seq_len = rotary_emb_w_meta_info['seq_len']
295
  offset_info = rotary_emb_w_meta_info['offset_info']
296
- (bsz, seqlen) = query.shape[:2]
297
  query = query.view(bsz, seqlen, -1, self.head_dim)
298
  key = key.view(bsz, seqlen, -1, self.head_dim)
299
  if rotary_emb_w_meta_info['impl'] == 'dail':
300
  value = value.view(bsz, seqlen, -1, self.head_dim)
301
  kv = torch.stack([key, value], dim=2)
302
- (query, kv) = rotary_emb(query, kv, seqlen_offset=offset_info, max_seqlen=seq_len)
303
  [key, value] = torch.unbind(kv, dim=2)
304
  value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
305
  elif rotary_emb_w_meta_info['impl'] == 'hf':
306
- (cos, sin) = rotary_emb(value, seq_len)
307
- if is_transformers_version_gte('4.36'):
308
- (query, key) = apply_rotary_pos_emb(query, key, cos, sin, offset_info, unsqueeze_dim=2)
 
 
 
 
 
309
  else:
310
  query = query.transpose(1, 2)
311
  key = key.transpose(1, 2)
312
- (query, key) = apply_rotary_pos_emb(query, key, cos, sin, offset_info)
313
  query = query.transpose(1, 2)
314
  key = key.transpose(1, 2)
315
  query = query.view(bsz, seqlen, self.d_model)
@@ -318,38 +269,36 @@ class GroupedQueryAttention(nn.Module):
318
  if self.attn_impl == 'flash':
319
  key_padding_mask = None
320
  extra_attn_kwargs = {'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info}
321
- (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, **extra_attn_kwargs)
322
  return (self.out_proj(context), attn_weights, past_key_value)
323
 
324
  class MultiheadAttention(GroupedQueryAttention):
325
  """Multi-head self attention.
326
 
327
- Using torch or triton attention implementation enables user to also use
328
- additive bias.
329
  """
330
 
331
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
332
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
333
 
334
  class MultiQueryAttention(GroupedQueryAttention):
335
  """Multi-Query self attention.
336
 
337
- Using torch or triton attention implementation enables user to also use
338
- additive bias.
339
  """
340
 
341
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
342
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
343
 
344
- def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
345
  if attn_impl == 'flash':
346
  return None
347
- elif attn_impl in ['torch', 'triton']:
348
  if alibi:
349
- if (prefix_lm or not causal) or use_sequence_id:
350
  return (1, n_heads, seq_len, seq_len)
351
  return (1, n_heads, 1, seq_len)
352
- elif prefix_lm or use_sequence_id:
353
  return (1, 1, seq_len, seq_len)
354
  return None
355
  else:
@@ -358,9 +307,9 @@ def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, pre
358
  def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
359
  if attn_impl == 'flash':
360
  return None
361
- elif attn_impl in ['torch', 'triton']:
362
  if alibi:
363
- (device, dtype) = (attn_bias.device, attn_bias.dtype)
364
  attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
365
  return attn_bias
366
  else:
 
31
 
32
  def check_alibi_support(attention_impl: str) -> bool:
33
  return attention_impl != 'flash' or is_flash_v2_installed(v2_version='v2.4.2')
 
 
 
34
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
35
 
36
  def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
 
50
  """
51
  if n_rep == 1:
52
  return hidden
53
+ b, s, kv_n_heads, d = hidden.shape
54
  hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
55
  return hidden.reshape(b, s, kv_n_heads * n_rep, d)
56
 
 
63
  k = torch.cat([past_key_value[0], k], dim=3)
64
  v = torch.cat([past_key_value[1], v], dim=2)
65
  past_key_value = (k, v)
66
+ b, _, s_q, d = q.shape
67
  s_k = k.size(-1)
68
  if kv_n_heads > 1 and kv_n_heads < n_heads:
69
  k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
 
127
  past_key_value = (key, value)
128
  if attn_bias is not None:
129
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
130
+ batch_size, seqlen = query.shape[:2]
131
  indices_q = flash_attn_padding_info['indices_q']
132
  indices_k = flash_attn_padding_info['indices_k']
133
  indices_v = flash_attn_padding_info['indices_v']
 
166
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
167
  return (output, None, past_key_value)
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  class GroupedQueryAttention(nn.Module):
170
  """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
171
 
172
  and Multi-query attention (MQA).
173
 
174
  This allows the user to set a variable of number of kv_n_heads, rather than
175
+ just n_heads or 1, as in MHA and MQA. Using torch attention
176
  implementation enables user to also use additive bias.
177
  """
178
 
179
+ def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
180
  super().__init__()
181
  self.attn_impl = attn_impl
182
  self.clip_qkv = clip_qkv
 
200
  self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
201
  self.attn_dropout_p = attn_pdrop
202
  fc_kwargs: dict[str, Any] = {'bias': bias}
203
+ fc_kwargs['device'] = device
 
204
  self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
205
  fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
206
  self.Wqkv._fused = (0, fuse_splits)
 
213
  self.k_ln = norm_class(norm_size, device=device)
214
  if self.attn_impl == 'flash':
215
  self.attn_fn = flash_attn_fn
 
 
216
  elif self.attn_impl == 'torch':
217
  self.attn_fn = scaled_multihead_dot_product_attention
218
  else:
 
224
  qkv = self.Wqkv(x)
225
  if self.clip_qkv:
226
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
227
+ query, key, value = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
228
  key_padding_mask = attention_mask
229
  if self.qk_ln or self.qk_gn:
230
+ q_shape, k_shape = (query.shape, key.shape)
231
  if self.qk_gn:
232
+ b, s = query.shape[:2]
233
  query = query.view(b, s, self.n_heads, -1)
234
  key = key.view(b, s, self.kv_n_heads, -1)
235
  dtype = query.dtype
 
239
  rotary_emb = rotary_emb_w_meta_info['rotary_emb']
240
  seq_len = rotary_emb_w_meta_info['seq_len']
241
  offset_info = rotary_emb_w_meta_info['offset_info']
242
+ bsz, seqlen = query.shape[:2]
243
  query = query.view(bsz, seqlen, -1, self.head_dim)
244
  key = key.view(bsz, seqlen, -1, self.head_dim)
245
  if rotary_emb_w_meta_info['impl'] == 'dail':
246
  value = value.view(bsz, seqlen, -1, self.head_dim)
247
  kv = torch.stack([key, value], dim=2)
248
+ query, kv = rotary_emb(query, kv, seqlen_offset=offset_info, max_seqlen=seq_len)
249
  [key, value] = torch.unbind(kv, dim=2)
250
  value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
251
  elif rotary_emb_w_meta_info['impl'] == 'hf':
252
+ if is_transformers_version_gte('4.38'):
253
+ cos, sin = rotary_emb(x=value, position_ids=offset_info, seq_len=None)
254
+ else:
255
+ cos, sin = rotary_emb(x=value, seq_len=seq_len)
256
+ if is_transformers_version_gte('4.38'):
257
+ query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=None, unsqueeze_dim=2)
258
+ elif is_transformers_version_gte('4.36'):
259
+ query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=offset_info, unsqueeze_dim=2)
260
  else:
261
  query = query.transpose(1, 2)
262
  key = key.transpose(1, 2)
263
+ query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=offset_info)
264
  query = query.transpose(1, 2)
265
  key = key.transpose(1, 2)
266
  query = query.view(bsz, seqlen, self.d_model)
 
269
  if self.attn_impl == 'flash':
270
  key_padding_mask = None
271
  extra_attn_kwargs = {'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info}
272
+ context, attn_weights, past_key_value = self.attn_fn(query, key, value, self.n_heads, self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, **extra_attn_kwargs)
273
  return (self.out_proj(context), attn_weights, past_key_value)
274
 
275
  class MultiheadAttention(GroupedQueryAttention):
276
  """Multi-head self attention.
277
 
278
+ Using torch attention implementation enables user to also use additive bias.
 
279
  """
280
 
281
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
282
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
283
 
284
  class MultiQueryAttention(GroupedQueryAttention):
285
  """Multi-Query self attention.
286
 
287
+ Using torch attention implementation enables user to also use additive bias.
 
288
  """
289
 
290
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
291
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
292
 
293
+ def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, causal: bool, use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
294
  if attn_impl == 'flash':
295
  return None
296
+ elif attn_impl == 'torch':
297
  if alibi:
298
+ if not causal or use_sequence_id:
299
  return (1, n_heads, seq_len, seq_len)
300
  return (1, n_heads, 1, seq_len)
301
+ elif use_sequence_id:
302
  return (1, 1, seq_len, seq_len)
303
  return None
304
  else:
 
307
  def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
308
  if attn_impl == 'flash':
309
  return None
310
+ elif attn_impl == 'torch':
311
  if alibi:
312
+ device, dtype = (attn_bias.device, attn_bias.dtype)
313
  attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
314
  return attn_bias
315
  else:
blocks.py CHANGED
@@ -8,8 +8,8 @@ from .norm import NORM_CLASS_REGISTRY
8
  try:
9
  from flash_attn.bert_padding import unpad_input, pad_input
10
  except:
11
- (unpad_input, pad_input) = (None, None)
12
- attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'qk_gn': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'sliding_window_size': -1, 'alibi': False, 'alibi_bias_max': 8, 'rope': False, 'rope_theta': 10000, 'rope_impl': 'dail', 'rope_dail_config': {'type': 'original', 'pos_idx_in_fp32': True, 'xpos_scale_base': 512}, 'rope_hf_config': {'type': 'no_scaling', 'factor': 1.0}}
13
 
14
  class MPTBlock(nn.Module):
15
 
@@ -23,8 +23,8 @@ class MPTBlock(nn.Module):
23
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
24
  assert isinstance(attn_config['attn_type'], str)
25
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
26
- args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', 'rope_hf_config'}
27
- attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
28
  self.norm_1 = norm_class(d_model, device=device)
29
  self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
30
  self.norm_2 = None
@@ -37,16 +37,16 @@ class MPTBlock(nn.Module):
37
 
38
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[Dict]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
39
  a = self.norm_1(x)
40
- (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
41
  x = x + self.resid_attn_dropout(b)
42
  m = x
43
  if self.norm_2 is not None:
44
  m = self.norm_2(x)
45
- (batch_size, seq_len) = m.size()[:2]
46
  indices = None
47
  if not self.use_pad_tok_in_ffn:
48
  assert unpad_input is not None
49
- (m, indices, _, _) = unpad_input(m, attention_mask)
50
  n = self.ffn(m)
51
  if not self.use_pad_tok_in_ffn:
52
  assert pad_input is not None
 
8
  try:
9
  from flash_attn.bert_padding import unpad_input, pad_input
10
  except:
11
+ unpad_input, pad_input = (None, None)
12
+ attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'flash', 'qk_ln': False, 'qk_gn': False, 'clip_qkv': None, 'softmax_scale': None, 'attn_uses_sequence_id': False, 'sliding_window_size': -1, 'alibi': False, 'alibi_bias_max': 8, 'rope': False, 'rope_theta': 10000, 'rope_impl': 'dail', 'rope_dail_config': {'type': 'original', 'pos_idx_in_fp32': True, 'xpos_scale_base': 512}, 'rope_hf_config': {'type': 'no_scaling', 'factor': 1.0}}
13
 
14
  class MPTBlock(nn.Module):
15
 
 
23
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
24
  assert isinstance(attn_config['attn_type'], str)
25
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
26
+ args_to_exclude_in_attn_class = {'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', 'rope_hf_config'}
27
+ attn_config_subset_for_attn_class = {k: v for k, v in attn_config.items() if k not in args_to_exclude_in_attn_class}
28
  self.norm_1 = norm_class(d_model, device=device)
29
  self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
30
  self.norm_2 = None
 
37
 
38
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[Dict]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
39
  a = self.norm_1(x)
40
+ b, attn_weights, past_key_value = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
41
  x = x + self.resid_attn_dropout(b)
42
  m = x
43
  if self.norm_2 is not None:
44
  m = self.norm_2(x)
45
+ batch_size, seq_len = m.size()[:2]
46
  indices = None
47
  if not self.use_pad_tok_in_ffn:
48
  assert unpad_input is not None
49
+ m, indices, _, _ = unpad_input(m, attention_mask)
50
  n = self.ffn(m)
51
  if not self.use_pad_tok_in_ffn:
52
  assert pad_input is not None
builders.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import functools
3
+ import logging
4
+ import os
5
+ import re
6
+ from collections import OrderedDict
7
+ from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union
8
+ import torch
9
+ from torch.optim.optimizer import Optimizer
10
+ from torchmetrics import Metric
11
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
12
+ from .llmfoundry import registry
13
+ from .callbacks import EvalGauntlet
14
+ from .dataloader import build_dataloader
15
+ from .tiktoken import TiktokenTokenizerWrapper
16
+ from .registry_utils import construct_from_registry
17
+ log = logging.getLogger(__name__)
18
+
19
+ def build_evaluators(eval_loader_config: Optional[Union[DictConfig, ListConfig]], icl_tasks_config: Optional[Union[str, ListConfig]], eval_gauntlet_config: Optional[Union[str, DictConfig]], *, tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
20
+ evaluators = []
21
+ if eval_loader_config is not None:
22
+ evaluators = build_eval_loaders(eval_loader_config, tokenizer, device_eval_batch_size)
23
+ logger_keys = []
24
+ eval_gauntlet_callback = None
25
+ if icl_tasks_config is not None:
26
+ icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(icl_tasks_config, eval_gauntlet_config, tokenizer, device_eval_batch_size, icl_seq_len, icl_subset_num_batches)
27
+ evaluators.extend(icl_evaluators)
28
+ return (evaluators, logger_keys, eval_gauntlet_callback)
29
+
30
+ def build_eval_loaders(eval_loader_config: Union[DictConfig, ListConfig], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int) -> List[Evaluator]:
31
+ evaluators: List[Evaluator] = []
32
+ if isinstance(eval_loader_config, ListConfig):
33
+ eval_configs: ListConfig = eval_loader_config
34
+ is_multi_eval = True
35
+ else:
36
+ eval_configs = ListConfig([eval_loader_config])
37
+ is_multi_eval = False
38
+ for eval_config in eval_configs:
39
+ eval_dataloader = build_dataloader(eval_config, tokenizer, device_eval_batch_size)
40
+ eval_loader: Evaluator = Evaluator(label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, metric_names=[])
41
+ evaluators.append(eval_loader)
42
+ return evaluators
43
+
44
+ def add_metrics_to_eval_loaders(evaluators: List[Evaluator], metric_names: List[str]) -> List[Evaluator]:
45
+ eval_loaders, other_evaluators = ([], [])
46
+ for evaluator in evaluators:
47
+ if evaluator.metric_names == []:
48
+ evaluator.metric_names = metric_names
49
+ eval_loaders.append(evaluator)
50
+ else:
51
+ other_evaluators.append(evaluator)
52
+ return eval_loaders + other_evaluators
53
+
54
+ def build_icl_data_and_gauntlet(icl_tasks_config: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
55
+ icl_evaluators, logger_keys = build_icl_evaluators(icl_tasks_config, tokenizer, icl_seq_len, device_eval_batch_size, icl_subset_num_batches=icl_subset_num_batches)
56
+ eval_gauntlet_cb = None
57
+ if eval_gauntlet_config is not None:
58
+ if isinstance(eval_gauntlet_config, str):
59
+ with open(eval_gauntlet_config, 'r') as icl_f:
60
+ eval_gauntlet_cfg = om.load(icl_f)
61
+ eval_gauntlet = eval_gauntlet_cfg.eval_gauntlet
62
+ elif isinstance(eval_gauntlet_config, DictConfig):
63
+ eval_gauntlet = eval_gauntlet_config
64
+ else:
65
+ raise ValueError(f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}')
66
+ eval_gauntlet.logger_keys = logger_keys
67
+ eval_gauntlet.benchmark_sizes = {e.label: e.dataloader.num_samples for e in icl_evaluators}
68
+ eval_gauntlet_cb = EvalGauntlet(**eval_gauntlet)
69
+ return (icl_evaluators, logger_keys, eval_gauntlet_cb)
70
+
71
+ def build_composer_model(name: str, cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, init_context: Optional[ContextManager]=None, master_weights_dtype: Optional[str]=None) -> ComposerModel:
72
+ """Builds a ComposerModel from the registry.
73
+
74
+ Args:
75
+ name (str): Name of the model to build.
76
+ cfg (DictConfig): Configuration for the model.
77
+ tokenizer (PreTrainedTokenizerBase): Tokenizer to use.
78
+ init_context (Optional[ContextManager], optional): Context manager to use for initialization. Defaults to None.
79
+ master_weights_dtype (Optional[str], optional): Master weights dtype. Defaults to None.
80
+
81
+ Returns:
82
+ ComposerModel: _description_
83
+ """
84
+ if init_context is None:
85
+ init_context = contextlib.nullcontext()
86
+ with init_context:
87
+ model = construct_from_registry(name=name, registry=registry.models, pre_validation_function=ComposerModel, post_validation_function=None, kwargs={'om_model_config': cfg, 'tokenizer': tokenizer})
88
+ str_dtype_to_torch_dtype = {'f16': torch.float16, 'float16': torch.float16, 'bf16': torch.bfloat16, 'bfloat16': torch.bfloat16}
89
+ if master_weights_dtype is not None:
90
+ if master_weights_dtype not in str_dtype_to_torch_dtype:
91
+ raise ValueError(f'Invalid master_weights_dtype: {master_weights_dtype}. ' + f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.')
92
+ dtype = str_dtype_to_torch_dtype[master_weights_dtype]
93
+ model = model.to(dtype=dtype)
94
+ return model
95
+
96
+ def build_callback(name: str, kwargs: Optional[Dict[str, Any]]=None, config: Any=None) -> Callback:
97
+ """Builds a callback from the registry."""
98
+ registry_to_use = registry.callbacks
99
+ if name in registry.callbacks_with_config:
100
+ if kwargs is None:
101
+ kwargs = {}
102
+ if 'config' in kwargs:
103
+ raise ValueError(f'`config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.')
104
+ kwargs['config'] = config
105
+ registry_to_use = registry.callbacks_with_config
106
+ return construct_from_registry(name=name, registry=registry_to_use, partial_function=True, pre_validation_function=Callback, post_validation_function=None, kwargs=kwargs)
107
+
108
+ def build_logger(name: str, kwargs: Optional[Dict[str, Any]]=None) -> LoggerDestination:
109
+ """Builds a logger from the registry."""
110
+ return construct_from_registry(name=name, registry=registry.loggers, partial_function=True, pre_validation_function=LoggerDestination, post_validation_function=None, kwargs=kwargs)
111
+
112
+ def build_algorithm(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Algorithm:
113
+ """Builds an algorithm from the registry."""
114
+ return construct_from_registry(name=name, registry=registry.algorithms, partial_function=True, pre_validation_function=Algorithm, post_validation_function=None, kwargs=kwargs)
115
+
116
+ def build_metric(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Metric:
117
+ """Builds a metric from the registry."""
118
+ return construct_from_registry(name=name, registry=registry.metrics, partial_function=True, pre_validation_function=Metric, post_validation_function=None, kwargs=kwargs)
119
+
120
+ def _extract_param_groups(model: torch.nn.Module, optimizer_config: Optional[Dict[str, Any]]=None) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
121
+ """Extracts parameter groups defined in the optimizer config.
122
+
123
+ The optimizer_config defines the optimizer args. It can additionally have key
124
+ `disable_grad` which is a string or list of strings. If a string matches a
125
+ parameter name, then that parameter will have `requires_grad=False`. This is
126
+ useful for freezing parameters. It can additionally have a key
127
+ `param_groups` which is a list of dicts. In this dict, key `param_str_match`
128
+ defines a string; if a parameter name contains this string, then it will be
129
+ in this parameter group. This is useful for grouping parameters together.
130
+ The dict can also contain any other key that is a valid optimizer arg.
131
+ Note: to handle name overlap conflicts, params are assigned to parameter
132
+ groups and added to `param_groups` in the order that `param_str_match` appear
133
+ in `param_groups`.
134
+
135
+ Usage
136
+ To disable gradient for all parameters that contain the string "norm" or "bias":
137
+ ```
138
+ optimizer_config: {
139
+ "name": "decoupled_lionw",
140
+ "lr": 1e-3,
141
+ "weight_decay": 1e-2,
142
+ "betas": [0.9, 0.999],
143
+ "eps": 1e-8,
144
+ "disable_grad": ["norm", "bias"]
145
+ }
146
+ ```
147
+
148
+ To create and modify the optimizer parameters for all parameters that contain
149
+ the string "norm" and "bias" separately:
150
+ ```
151
+ optimizer_config: {
152
+ "name": "decoupled_lionw",
153
+ "lr": 1e-3,
154
+ "weight_decay": 1e-2,
155
+ "betas": [0.9, 0.999],
156
+ "eps": 1e-8,
157
+ "param_groups": [
158
+ {
159
+ "param_str_match": "norm",
160
+ "lr": 1e-4,
161
+ "weight_decay": 0.0,
162
+ },
163
+ {
164
+ "param_str_match": "bias",
165
+ "lr": 5e-4,
166
+ "weight_decay": 0.0,
167
+ },
168
+ ],
169
+ }
170
+ ```
171
+
172
+ Args:
173
+ model (torch.nn.Module): model to extract parameters from
174
+ optimizer_config (Dict[str, Any]): optimizer config
175
+
176
+ Returns:
177
+ Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
178
+ torch.Tensor's or dict's. Specifies what Tensors should be optimized
179
+ and their param groupings.
180
+ """
181
+ if optimizer_config is None:
182
+ return model.parameters()
183
+ if 'disable_grad' in optimizer_config.keys():
184
+ str_matches = optimizer_config.pop('disable_grad')
185
+ if isinstance(str_matches, str):
186
+ str_matches = [str_matches]
187
+ for str_match in str_matches:
188
+ for n, p in model.named_parameters():
189
+ if re.search(str_match, n):
190
+ p.requires_grad = False
191
+ log.debug(f'Setting `{n}.requires_grad = False`.')
192
+ param_groups_config = optimizer_config.pop('param_groups', None)
193
+ if param_groups_config is not None:
194
+ params = []
195
+ param_dict = OrderedDict(((n, p) for n, p in model.named_parameters()))
196
+ log.debug(f'Default optimizer settings: {optimizer_config}.')
197
+ for param_group_config in param_groups_config:
198
+ str_match = param_group_config.pop('param_str_match')
199
+ filter_fn = functools.partial(re.search, str_match)
200
+ param_names = [n for n in param_dict.keys() if filter_fn(n)]
201
+ group_params = {'params': [param_dict.pop(n) for n in param_names]}
202
+ group_params.update(param_group_config)
203
+ log.debug(f'Creating optimizer param_group with parameters: {param_names} ' + f'(extracted using str_match={str_match!r}). The param_group optimizer ' + f'setting overrides are: {param_group_config}.')
204
+ params.append(group_params)
205
+ params.insert(0, {'params': param_dict.values()})
206
+ return params
207
+ return model.parameters()
208
+
209
+ def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Optional[Dict[str, Any]]=None) -> Optimizer:
210
+ params = _extract_param_groups(model, optimizer_config)
211
+ kwargs = optimizer_config
212
+ if kwargs is None:
213
+ kwargs = {}
214
+ if 'params' in kwargs:
215
+ raise ValueError('The `params` will be automatically extracted from the model and ' + 'optimizer config. Please remove it from the optimizer config kwargs.')
216
+ kwargs['params'] = params
217
+ return construct_from_registry(name=name, registry=registry.optimizers, partial_function=True, pre_validation_function=Optimizer, post_validation_function=None, kwargs=kwargs)
218
+
219
+ def build_scheduler(name: str, scheduler_config: Optional[Dict[str, Any]]=None) -> ComposerScheduler:
220
+ return construct_from_registry(name=name, registry=registry.schedulers, partial_function=True, pre_validation_function=ComposerScheduler, post_validation_function=None, kwargs=scheduler_config)
221
+
222
+ def build_tokenizer(tokenizer_name: str, tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase:
223
+ os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
224
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
225
+ signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
226
+ if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
227
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
228
+ pass
229
+ if tokenizer_name.startswith('tiktoken'):
230
+ tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
231
+ else:
232
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
233
+ tokenizer.model_max_length = tokenizer_kwargs.get('model_max_length', int(1e+30))
234
+ if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
235
+ raise ValueError(f'The tokenizer {tokenizer_name} must have an eos_token.')
236
+ if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
237
+ if dist.get_local_rank() == 0:
238
+ with open(signal_file_path, 'wb') as f:
239
+ f.write(b'local_rank0_completed_tokenizer_setup')
240
+ dist.barrier()
241
+ if dist.get_local_rank() == 0:
242
+ os.remove(signal_file_path)
243
+ return tokenizer
244
+
245
+ def build_icl_evaluators(icl_tasks: Union[str, ListConfig], tokenizer: PreTrainedTokenizerBase, default_max_seq_len: int, default_batch_size: int, destination_dir: Optional[str]=None, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str]]:
246
+ if destination_dir is None:
247
+ destination_dir = os.getcwd()
248
+ evaluators = []
249
+ logger_keys = []
250
+ icl_tasks_list = None
251
+ if isinstance(icl_tasks, str):
252
+ log.info(f'Extracting ICL task config from path: {icl_tasks}')
253
+ with open(icl_tasks, 'r') as icl_f:
254
+ icl_task_cfg = om.load(icl_f)
255
+ icl_tasks_list = icl_task_cfg.icl_tasks
256
+ else:
257
+ icl_tasks_list = icl_tasks
258
+
259
+ def _validate_cfg(icl_cfg: DictConfig):
260
+ assert 'label' in icl_cfg
261
+ assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None
262
+ assert 'icl_task_type' in icl_cfg
263
+ assert 'num_fewshot' in icl_cfg
264
+ if 'metric_names' not in icl_cfg:
265
+ if icl_cfg.icl_task_type == 'language_modeling':
266
+ icl_cfg.metric_names = ['InContextLearningLMAccuracy']
267
+ elif icl_cfg.icl_task_type == 'multiple_choice':
268
+ icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
269
+ elif icl_cfg.icl_task_type == 'schema':
270
+ icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
271
+ elif icl_cfg.icl_task_type == 'question_answering':
272
+ icl_cfg.metric_names = ['InContextLearningQAAccuracy']
273
+ elif icl_cfg.icl_task_type == 'code_evaluation':
274
+ icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
275
+ else:
276
+ raise ValueError(f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.')
277
+ if 'prompt_string' not in icl_cfg:
278
+ icl_cfg.prompt_string = ''
279
+ if 'example_delimiter' not in icl_cfg:
280
+ icl_cfg.example_delimiter = '\n'
281
+ if 'continuation_delimiter' not in icl_cfg:
282
+ icl_cfg.continuation_delimiter = ' '
283
+ if 'max_seq_len' not in icl_cfg:
284
+ icl_cfg.max_seq_len = default_max_seq_len
285
+ if 'batch_size' not in icl_cfg:
286
+ icl_cfg.batch_size = default_batch_size
287
+ if 'pass_at_k' not in icl_cfg:
288
+ icl_cfg.pass_at_k = 1
289
+ if 'fewshot_random_seed' not in icl_cfg:
290
+ icl_cfg.fewshot_random_seed = 1234
291
+ if 'generations_per_sample' not in icl_cfg:
292
+ icl_cfg.generations_per_sample = 1
293
+ if 'num_beams' in icl_cfg:
294
+ raise ValueError('num_beams is no longer supported as a top level icl_task parameter.' + 'Please use generation_kwargs.num_beams instead.')
295
+ for icl_cfg in icl_tasks_list:
296
+ assert isinstance(icl_cfg, DictConfig)
297
+ _validate_cfg(icl_cfg)
298
+ for num_fewshot in list(icl_cfg.num_fewshot):
299
+ if tokenizer.pad_token_id is None:
300
+ pad_tok_id = tokenizer.eos_token_id
301
+ else:
302
+ pad_tok_id = tokenizer.pad_token_id
303
+ label = f'{icl_cfg.label}/{num_fewshot}-shot'
304
+ metric_names = list(icl_cfg.metric_names)
305
+ destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
306
+ if dist.get_local_rank() == 0 and os.path.exists(destination_path):
307
+ os.remove(destination_path)
308
+ dist.barrier()
309
+ hf_parsing_map = icl_cfg.get('hf_parsing_map', {})
310
+ hf_loading_vars = icl_cfg.get('hf_loading_vars', {})
311
+ early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None)
312
+ if isinstance(early_stopping_criteria, ListConfig):
313
+ early_stopping_criteria = om.to_container(early_stopping_criteria)
314
+ assert early_stopping_criteria is None or isinstance(early_stopping_criteria, list)
315
+ dataloaders = get_icl_task_dataloader(icl_cfg.icl_task_type, icl_cfg.dataset_uri, tokenizer, batch_size=icl_cfg.batch_size, max_seq_len=icl_cfg.max_seq_len, pad_tok_id=pad_tok_id, num_fewshot=num_fewshot, prompt_string=icl_cfg.prompt_string, example_delimiter=icl_cfg.example_delimiter, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, continuation_delimiter=icl_cfg.continuation_delimiter, question_prelimiter=icl_cfg.get('question_prelimiter', ''), destination_path=destination_path, fewshot_random_seed=icl_cfg.fewshot_random_seed, pass_at_k=icl_cfg.pass_at_k, generations_per_sample=icl_cfg.generations_per_sample, has_categories=icl_cfg.get('has_categories', False), cot_delimiter=icl_cfg.get('cot_delimiter', ''), generation_kwargs=icl_cfg.get('generation_kwargs', {}), early_stopping_criteria=early_stopping_criteria, do_normalization=icl_cfg.get('do_normalization', True))
316
+ if hasattr(icl_cfg, 'has_categories') and icl_cfg.has_categories and isinstance(dataloaders, dict):
317
+ for category in dataloaders.keys():
318
+ logger_keys.extend([f'metrics/{label}/{category}/{m}' for m in metric_names])
319
+ evaluators.append(Evaluator(label=f'{label}/{category}', dataloader=dataloaders[category], metric_names=metric_names))
320
+ else:
321
+ logger_keys.extend([f'metrics/{label}/{m}' for m in metric_names])
322
+ evaluators.append(Evaluator(label=label, dataloader=dataloaders, metric_names=metric_names, subset_num_batches=icl_subset_num_batches))
323
+ return (evaluators, logger_keys)
callback_with_config.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Any
3
+
4
+ class CallbackWithConfig(Callback, abc.ABC):
5
+ """A callback that takes a config dictionary as an argument, in addition to.
6
+
7
+ its other kwargs.
8
+ """
9
+
10
+ def __init__(self, config: dict[str, Any], *args: Any, **kwargs: Any) -> None:
11
+ del config, args, kwargs
12
+ pass
callbacks.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .async_eval_callback import AsyncEval
2
+ from .curriculum_learning_callback import CurriculumLearning
3
+ from .eval_gauntlet_callback import EvalGauntlet
4
+ from .fdiff_callback import FDiffMetrics
5
+ from .hf_checkpointer import HuggingFaceCheckpointer
6
+ from .monolithic_ckpt_callback import MonolithicCheckpointSaver
7
+ from .resumption_callbacks import GlobalLRScaling, LayerFreezing
8
+ from .scheduled_gc_callback import ScheduledGarbageCollector
9
+ from .registry import callbacks, callbacks_with_config
10
+ callbacks.register('lr_monitor', func=LRMonitor)
11
+ callbacks.register('memory_monitor', func=MemoryMonitor)
12
+ callbacks.register('memory_snapshot', func=MemorySnapshot)
13
+ callbacks.register('speed_monitor', func=SpeedMonitor)
14
+ callbacks.register('runtime_estimator', func=RuntimeEstimator)
15
+ callbacks.register('optimizer_monitor', func=OptimizerMonitor)
16
+ callbacks.register('generate_callback', func=Generate)
17
+ callbacks.register('early_stopper', func=EarlyStopper)
18
+ callbacks.register('fdiff_metrics', func=FDiffMetrics)
19
+ callbacks.register('hf_checkpointer', func=HuggingFaceCheckpointer)
20
+ callbacks.register('global_lr_scaling', func=GlobalLRScaling)
21
+ callbacks.register('layer_freezing', func=LayerFreezing)
22
+ callbacks.register('mono_checkpoint_saver', func=MonolithicCheckpointSaver)
23
+ callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
24
+ callbacks.register('oom_observer', func=OOMObserver)
25
+ callbacks_with_config.register('async_eval', func=AsyncEval)
26
+ callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
checkpoint_conversion_helpers.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helper methods for the checkpoint conversion scripts.
2
+
3
+ The checkpoint conversion scripts are located in the
4
+ llmfoundry/scripts/inference/benchmarking/ folder. Users should run those
5
+ scripts directly to convert between checkpoints; this file contains only common
6
+ utility functions that are present in multiple scripts.
7
+ """
8
+ import json
9
+ import logging
10
+ import os
11
+ import random
12
+ import string
13
+ from pathlib import Path
14
+ from typing import Any, Dict, Optional, Tuple, Union
15
+ import numpy as np
16
+ import sentencepiece as spm
17
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
18
+ log = logging.getLogger(__name__)
19
+
20
+ def _get_weight_data_type(data_type: str):
21
+ if data_type == 'fp32':
22
+ return np.float32
23
+ elif data_type == 'fp16':
24
+ return np.float16
25
+ else:
26
+ raise RuntimeError('Unsupported data type: {data_type} for conversion.')
27
+
28
+ def get_hf_tokenizer_from_composer_state_dict(state_dict: Dict[str, Any], trust_remote_code: bool, tokenizer_save_dir: Optional[str]=None) -> Optional[PreTrainedTokenizer]:
29
+ if 'state' not in state_dict:
30
+ raise RuntimeError('Unexpected composer state dictionary. Did you pass in a full composer checkpoint?')
31
+ if 'integrations' not in state_dict['state'] or 'huggingface' not in state_dict['state']['integrations']:
32
+ raise RuntimeError('Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!')
33
+ hf_tokenizer_state = state_dict['state']['integrations']['huggingface']['tokenizer']
34
+ hf_tokenizer = None
35
+ if hf_tokenizer_state != {}:
36
+ if tokenizer_save_dir is None:
37
+ unique_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=6))
38
+ tokenizer_save_dir = os.path.join(os.getcwd(), f'tokenizer-save-dir-{unique_suffix}')
39
+ os.makedirs(tokenizer_save_dir, exist_ok=True)
40
+ for filename, saved_content in hf_tokenizer_state.items():
41
+ if filename.endswith(saved_content['file_extension']):
42
+ tokenizer_file_name = filename
43
+ else:
44
+ tokenizer_file_name = filename + saved_content['file_extension']
45
+ tokenizer_file_path = Path(tokenizer_save_dir) / tokenizer_file_name
46
+ if saved_content['file_extension'] == '.json':
47
+ with open(tokenizer_file_path, 'w') as _tmp_file:
48
+ json.dump(saved_content['content'], _tmp_file)
49
+ elif saved_content['file_extension'] == '.txt':
50
+ with open(tokenizer_file_path, 'w') as _tmp_file:
51
+ for line in saved_content['content']:
52
+ _tmp_file.write(line)
53
+ _tmp_file.write('\n')
54
+ elif saved_content['file_extension'] == '.py':
55
+ with open(tokenizer_file_path, 'w') as _tmp_file:
56
+ _tmp_file.write(saved_content['content'])
57
+ elif saved_content['file_extension'] == '.model':
58
+ s = spm.SentencePieceProcessor()
59
+ s.load_from_serialized_proto(saved_content['content'])
60
+ with open(tokenizer_file_path, 'wb') as _tmp_file:
61
+ _tmp_file.write(s.serialized_model_proto())
62
+ hf_tokenizer = load_tokenizer(tokenizer_save_dir, trust_remote_code=trust_remote_code)
63
+ hf_tokenizer.name_or_path = ''
64
+ hf_tokenizer.init_kwargs['name_or_path'] = ''
65
+ return hf_tokenizer
66
+
67
+ def load_tokenizer(tokenizer_save_dir: str, trust_remote_code: bool) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
68
+ try:
69
+ return AutoTokenizer.from_pretrained(tokenizer_save_dir, trust_remote_code=trust_remote_code)
70
+ except ValueError as e:
71
+ raise ValueError(f'Got error while loading tokenizer with trust_remote_code={trust_remote_code}: {e}. ' + 'If accessing a tokenizer defined outside of the transformers module,' + ' please use --trust_remote_code.')
72
+
73
+ def _write_zero_bias(weight_name: str, weight_file_path: str, bias_shape: Union[Tuple[int, ...], int], np_data_type: np.dtype) -> None:
74
+ """Write zeros for bias when converting MPT to FasterTransformer weights.
75
+
76
+ MPT model might not have bias while FT expects bias.
77
+
78
+ Args:
79
+ weight_name (str): Name of the weight tensor.
80
+ weight_file_path (str): Output path for storing the weight (NOT zero bias).
81
+ bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array.
82
+ np_data_type (np.dtype): The data type for bias.
83
+ """
84
+ if 'weight' not in weight_file_path:
85
+ raise RuntimeError(f'Cannot write zero bias for {weight_name}. Input is not a weight tensor')
86
+ log.debug(f'zero bias for weight: {weight_name}')
87
+ bias_file_path = weight_file_path.replace('.weight', '.bias')
88
+ bias = np.zeros(bias_shape, dtype=np_data_type)
89
+ bias.tofile(bias_file_path)
90
+
91
+ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, tensor_name: str, config: Dict[str, Any], data: np.ndarray, np_weight_data_type: np.dtype) -> None:
92
+ """Convert each MPT weight to a FasterTransformer compatible format.
93
+
94
+ Args:
95
+ save_dir (str): Path of the directory to save the weight in FT format. The directory must already exist.
96
+ infer_gpu_num (int): The number of gpus you are planning to use for inference.
97
+ tensor_name (str): Name of the weight tensor. Used in naming the output file.
98
+ config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters.
99
+ data (np.ndarray): Tensor data in np.ndarray format.
100
+
101
+ Returns:
102
+ None: Writes to a file in `save_dir`. File name is based on the `tensor_name`
103
+ """
104
+ if tensor_name.find('input_layernorm.weight') != -1 or tensor_name.find('input_layernorm.bias') != -1 or tensor_name.find('attention.dense.bias') != -1 or (tensor_name.find('post_attention_layernorm.weight') != -1) or (tensor_name.find('post_attention_layernorm.bias') != -1) or (tensor_name.find('mlp.dense_4h_to_h.bias') != -1) or (tensor_name.find('final_layernorm.weight') != -1) or (tensor_name.find('final_layernorm.bias') != -1):
105
+ save_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
106
+ data.tofile(save_path)
107
+ if 'weight' in tensor_name and config['no_bias']:
108
+ _write_zero_bias(tensor_name, save_path, data.shape[-1], np_weight_data_type)
109
+ elif tensor_name.find('attention.dense.weight') != -1:
110
+ assert data.shape == (config['d_model'], config['d_model']), f'unexpected dim for {tensor_name}'
111
+ data = data.T
112
+ split_vals = np.split(data, infer_gpu_num, axis=0)
113
+ for j in range(infer_gpu_num):
114
+ save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
115
+ split_vals[j].tofile(save_path)
116
+ if config['no_bias']:
117
+ fake_weight_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
118
+ _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], np_weight_data_type)
119
+ elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1:
120
+ assert data.shape == (config['d_model'], config['expansion_ratio'] * config['d_model']), f'unexpected dim for {tensor_name}'
121
+ data = data.T
122
+ split_vals = np.split(data, infer_gpu_num, axis=0)
123
+ for j in range(infer_gpu_num):
124
+ save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
125
+ split_vals[j].tofile(save_path)
126
+ if config['no_bias']:
127
+ fake_weight_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
128
+ _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], np_weight_data_type)
129
+ elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1:
130
+ assert data.shape == (config['expansion_ratio'] * config['d_model'], config['d_model']), f'unexpected dim for {tensor_name}'
131
+ data = data.T
132
+ split_vals = np.split(data, infer_gpu_num, axis=-1)
133
+ for j in range(infer_gpu_num):
134
+ save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
135
+ split_vals[j].tofile(save_path)
136
+ if config['no_bias']:
137
+ _write_zero_bias(tensor_name, save_path, split_vals[j].shape[-1], np_weight_data_type)
138
+ elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1:
139
+ assert data.shape == (config['expansion_ratio'] * config['d_model'],), f'unexpected dim for {tensor_name}'
140
+ split_vals = np.split(data, infer_gpu_num, axis=-1)
141
+ for j in range(infer_gpu_num):
142
+ save_path = os.path.join(save_dir + f'model.{tensor_name}.{j}.bin')
143
+ split_vals[j].tofile(save_path)
144
+ elif tensor_name.find('attention.query_key_value.bias') != -1:
145
+ assert data.shape == (3 * config['d_model'],), f'unexpected dim for {tensor_name}'
146
+ data = data.reshape(3, config['d_model'])
147
+ split_vals = np.split(data, infer_gpu_num, axis=-1)
148
+ for j in range(infer_gpu_num):
149
+ save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
150
+ split_vals[j].tofile(save_path)
151
+ elif tensor_name.find('attention.query_key_value.weight') != -1:
152
+ assert data.shape == (3 * config['d_model'], config['d_model']), f'unexpected dim for {tensor_name}'
153
+ data = data.T
154
+ data = data.reshape(config['d_model'], 3, config['d_model'])
155
+ split_vals = np.split(data, infer_gpu_num, axis=-1)
156
+ for j in range(infer_gpu_num):
157
+ save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
158
+ split_vals[j].tofile(save_path)
159
+ if config['no_bias']:
160
+ _write_zero_bias(tensor_name, save_path, (3, split_vals[j].shape[-1]), np_weight_data_type)
161
+ else:
162
+ raise RuntimeError(f'Tensor with name {tensor_name} is not handled')
163
+
164
+ def convert_and_save_ft_weights(named_params: dict, config: dict, infer_gpu_num: int=1, weight_data_type: str='fp32', save_dir: str='') -> None:
165
+ """Convert a Composer MPT checkpoint to a FasterTransformer format.
166
+
167
+ Args:
168
+ named_params (Dict[str, Parameter]): A dictionary containing the Composer MPT model's parameter names and data.
169
+ config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters.
170
+ infer_gpu_num (int): The number of gpus you are planning to use for inference.
171
+ weight_data_type (str): The dtype of the converted FasterTransformer model.
172
+ save_dir (str): Path of the directory to save the weight in FT format. The directory must already exist.
173
+
174
+ Returns:
175
+ None: Writes to the `save_dir` folder. File names within this folder are based on the model parameter names.
176
+ """
177
+ np_weight_data_type = _get_weight_data_type(weight_data_type)
178
+ param_remapping = {'norm_1.bias': 'input_layernorm.bias', 'norm_1.weight': 'input_layernorm.weight', 'attn.Wqkv.bias': 'attention.query_key_value.bias', 'attn.Wqkv.weight': 'attention.query_key_value.weight', 'attn.out_proj.bias': 'attention.dense.bias', 'attn.out_proj.weight': 'attention.dense.weight', 'norm_2.bias': 'post_attention_layernorm.bias', 'norm_2.weight': 'post_attention_layernorm.weight', 'ffn.up_proj.bias': 'mlp.dense_h_to_4h.bias', 'ffn.up_proj.weight': 'mlp.dense_h_to_4h.weight', 'ffn.down_proj.bias': 'mlp.dense_4h_to_h.bias', 'ffn.down_proj.weight': 'mlp.dense_4h_to_h.weight'}
179
+ for name, param in named_params.items():
180
+ log.debug(f'Working on parameter {name} ...')
181
+ data = param.detach().cpu().numpy().astype(np_weight_data_type)
182
+ if name.find('weight') == -1 and name.find('bias') == -1:
183
+ log.debug(f'found a parameter name that is not handled: {name}')
184
+ continue
185
+ if name == 'transformer.wpe.weight':
186
+ assert data.shape == (config['max_seq_len'], config['d_model']), f'unexpected dim for {name}'
187
+ data.tofile(os.path.join(save_dir, 'model.wpe.bin'))
188
+ elif name == 'transformer.wte.weight':
189
+ assert data.shape == (config['vocab_size'], config['d_model']), f'unexpected dim for {name}'
190
+ data.tofile(os.path.join(save_dir, 'model.wte.bin'))
191
+ elif name == 'transformer.norm_f.bias':
192
+ assert data.shape == (config['d_model'],), f'unexpected dim for {name}'
193
+ data.tofile(os.path.join(save_dir, 'model.final_layernorm.bias.bin'))
194
+ elif name == 'transformer.norm_f.weight':
195
+ assert data.shape == (config['d_model'],), f'unexpected dim for {name}'
196
+ save_path = os.path.join(save_dir, 'model.final_layernorm.weight.bin')
197
+ data.tofile(save_path)
198
+ if config['no_bias']:
199
+ _write_zero_bias(name, save_path, data.shape[-1], np_weight_data_type)
200
+ elif name == 'transformer.lm_head.weight':
201
+ data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin'))
202
+ else:
203
+ for mpt_pattern, ft_pattern in param_remapping.items():
204
+ if name.find(mpt_pattern) != -1:
205
+ new_name = name.replace('transformer.blocks.', 'layers.').replace(mpt_pattern, ft_pattern)
206
+ _convert_weight_to_ft_each(save_dir, infer_gpu_num, new_name, config, data, np_weight_data_type)
collator.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import warnings
3
+ from typing import Any, Dict, List, Optional, Union
4
+ import torch
5
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
6
+ log = logging.getLogger(__name__)
7
+ _HF_IGNORE_INDEX = -100
8
+ TokenizedExample = Dict[str, List[Dict[str, List[int]]]]
9
+
10
+ def ensure_list(x: Union[List, torch.Tensor]) -> List:
11
+ if isinstance(x, torch.Tensor):
12
+ x = list(x.flatten())
13
+ assert isinstance(x, list)
14
+ return x
15
+
16
+ def validate_target_settings(target_prompts: str, target_responses: str, decoder_only_format: bool):
17
+ """Raises an error if target settings are invalid."""
18
+ if not decoder_only_format and (target_prompts != 'none' or target_responses != 'last'):
19
+ raise ValueError(f'When using encoder_decoder format, you must use target_prompts="none" and target_responses="last".')
20
+ if target_responses not in {'all', 'last'}:
21
+ raise ValueError(f'target_responses must be either "last" or "all" but target_responses={target_responses!r}')
22
+ if target_prompts.startswith('length>='):
23
+ cutoff = target_prompts[8:]
24
+ if not cutoff.isdigit():
25
+ raise ValueError(f'target_prompts starts with "length>=" but the rest of the string is not digits (target_prompts={target_prompts!r}). ' + 'To use this configuration option, set target_prompts "length>=XX" where "XX" is a positive integer indicating ' + 'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.')
26
+ cutoff = int(cutoff)
27
+ if cutoff <= 0:
28
+ raise ValueError(f'You are trying to set the target_prompts length cutoff to a negative number cutoff={cutoff!r}. This is not allowed.')
29
+ elif target_prompts not in {'all', 'none'}:
30
+ raise ValueError(f'target_prompts must either be "all", "none" or "length>=XX" where "XX" is a positive integer, but target_prompts={target_prompts!r}')
31
+
32
+ def _sequence_to_labels_all(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
33
+ del is_last_turn, cutoff
34
+ return sequence
35
+
36
+ def _sequence_to_labels_none(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
37
+ del is_last_turn, cutoff
38
+ return [_HF_IGNORE_INDEX] * len(sequence)
39
+
40
+ def _sequence_to_labels_last(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
41
+ del cutoff
42
+ if is_last_turn:
43
+ return sequence
44
+ else:
45
+ return [_HF_IGNORE_INDEX] * len(sequence)
46
+
47
+ def _sequence_to_labels_cutoff(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
48
+ del is_last_turn
49
+ if cutoff is None:
50
+ raise ValueError('input ``cutoff`` must be provided')
51
+ if len(sequence) >= cutoff:
52
+ return sequence
53
+ else:
54
+ return [_HF_IGNORE_INDEX] * len(sequence)
55
+ _TARGET_POLICY_LOOKUP = {'all': _sequence_to_labels_all, 'none': _sequence_to_labels_none, 'last': _sequence_to_labels_last, 'length': _sequence_to_labels_cutoff}
56
+
57
+ def stitch_turns_decoder_only(example_turns: list[dict[str, list[int]]], target_prompts: str, target_responses: str, eos_token_id: Optional[int]=None, validate: bool=False) -> tuple[list[int], list[int]]:
58
+ target_prompts = target_prompts.lower()
59
+ target_responses = target_responses.lower()
60
+ if validate:
61
+ validate_target_settings(target_prompts, target_responses, decoder_only_format=True)
62
+ if target_prompts.startswith('length'):
63
+ prompt_cutoff = int(target_prompts.split('>=')[-1])
64
+ prompt_to_target = _TARGET_POLICY_LOOKUP['length']
65
+ else:
66
+ prompt_cutoff = None
67
+ prompt_to_target = _TARGET_POLICY_LOOKUP[target_prompts]
68
+ response_to_target = _TARGET_POLICY_LOOKUP[target_responses]
69
+ input_ids = []
70
+ labels = []
71
+ for idx, turn in enumerate(example_turns):
72
+ is_last_turn = idx + 1 == len(example_turns)
73
+ context = ensure_list(turn['input_ids'])
74
+ target = ensure_list(turn['labels'])
75
+ if is_last_turn and eos_token_id is not None:
76
+ if target[-1] != eos_token_id:
77
+ target = target + [eos_token_id]
78
+ input_ids += context
79
+ input_ids += target
80
+ labels += prompt_to_target(context, is_last_turn, prompt_cutoff)
81
+ labels += response_to_target(target, is_last_turn)
82
+ if len(input_ids) != len(labels):
83
+ raise ValueError(f'input_ids and labels should be the same length, len(input_ids)={len(input_ids)!r}, len(labels)={len(labels)!r}')
84
+ return (input_ids, labels)
85
+
86
+ def stitch_turns_encoder_decoder(example_turns: list[dict[str, list[int]]], eos_token_id: Optional[int]=None) -> tuple[list[int], list[int]]:
87
+ context = []
88
+ target = None
89
+ for idx, turn in enumerate(example_turns):
90
+ is_last_turn = idx + 1 == len(example_turns)
91
+ turn_context = ensure_list(turn['input_ids'])
92
+ turn_target = ensure_list(turn['labels'])
93
+ context += turn_context
94
+ if is_last_turn:
95
+ if eos_token_id is not None and turn_target[-1] != eos_token_id:
96
+ turn_target = turn_target + [eos_token_id]
97
+ target = turn_target
98
+ else:
99
+ context += turn_target
100
+ if target is None:
101
+ raise ValueError('target is still None but should be list[int]')
102
+ return (context, target)
103
+
104
+ class Seq2SeqFinetuningCollator:
105
+ """A general-purpose collator for sequence-to-sequence training/evaluation.
106
+
107
+ Args:
108
+ tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
109
+ max_seq_len (int): The maximum sequence length of the combined
110
+ context/target sequence (decoder-only format) or of each the
111
+ context sequence and target sequence (encoder-decoder format).
112
+ decoder_only_format (bool): Whether to format the batches for a
113
+ decoder-only model (if True) or an encoder-decoder model (if False).
114
+ target_responses (str): For multi-turn examples, this controls which
115
+ responses are treated as training targets (i.e. generate loss).
116
+ Options are:
117
+ "last": (Default) Only the final response is used as the training
118
+ target; non-terminal responses are only part of the context.
119
+ "all": All of the responses are used as training targets.
120
+ target_prompts (str): This controls which prompts are treated as
121
+ training targets (i.e. generate loss).
122
+ Options are:
123
+ "none": (Default) Prompts are never used as training targets.
124
+ "all": Prompts are always used as training targets.
125
+ "length>=XX": Prompt sequences are used as training targets when
126
+ they have length of at least XX tokens. For instance,
127
+ setting "length>=512" instructs the collator to use a prompt
128
+ sequence as a training target when it is at least 512 tokens long.
129
+ allow_pad_trimming (bool, optional): Whether to allow the collator
130
+ to trim padding, which may result in smaller but inconsistent batch
131
+ sizes. Default: ``False`` ensures that all sequences are max_seq_len.
132
+ batch_metadata (dict, optional): A dictionary of metadata which will be added
133
+ to the batch.
134
+ """
135
+
136
+ def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_seq_len: int, decoder_only_format: bool, target_responses: str='last', target_prompts: str='none', allow_pad_trimming: bool=False, batch_metadata: Optional[Dict[str, Any]]=None):
137
+ self.tokenizer = tokenizer
138
+ self.max_seq_len = max_seq_len
139
+ self.decoder_only_format = decoder_only_format
140
+ self.target_responses = target_responses.lower()
141
+ self.target_prompts = target_prompts.lower()
142
+ self.batch_metadata = batch_metadata or {}
143
+ self._allow_pad_trimming = allow_pad_trimming
144
+ self._seen_first_batch = False
145
+ illegal_keys = ['input_ids', 'labels', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask']
146
+ found_keys = []
147
+ for illegal_key in illegal_keys:
148
+ if illegal_key in self.batch_metadata:
149
+ found_keys.append(illegal_key)
150
+ if found_keys:
151
+ raise ValueError(f"The following keys are in batch_metadata but are not allowed: {', '.join(found_keys)}.\n" + f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' + f"{', '.join(illegal_keys)}")
152
+ if max_seq_len % 8 != 0:
153
+ log.warning('For performance, a max_seq_len as a multiple of 8 is recommended.')
154
+ if self.tokenizer.pad_token_id is None:
155
+ raise ValueError(f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None')
156
+ validate_target_settings(self.target_prompts, self.target_responses, self.decoder_only_format)
157
+ if self.target_prompts.startswith('length'):
158
+ self.prompt_cutoff = int(self.target_prompts.split('>=')[-1])
159
+ self.prompt_to_target = _TARGET_POLICY_LOOKUP['length']
160
+ else:
161
+ self.prompt_cutoff = None
162
+ self.prompt_to_target = _TARGET_POLICY_LOOKUP[self.target_prompts]
163
+ self.response_to_target = _TARGET_POLICY_LOOKUP[self.target_responses]
164
+ self._warned_truncated = False
165
+ self._warned_context = False
166
+ self._warned_target = False
167
+
168
+ def __call__(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
169
+ for check_key in ['input_ids', 'labels']:
170
+ if check_key not in examples[0]['turns'][0]:
171
+ raise KeyError(f'Examples returned by dataset do not include required key: {check_key}')
172
+ if self.decoder_only_format:
173
+ batch = self._process_and_batch_decoder_only(examples)
174
+ else:
175
+ batch = self._process_and_batch_encoder_decoder(examples)
176
+ batch_size = batch['input_ids'].shape[0]
177
+ batch.update({k: torch.tensor([v] * batch_size) for k, v in self.batch_metadata.items()})
178
+ return batch
179
+
180
+ def _process_and_batch_decoder_only(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
181
+ processed_examples = []
182
+ for example in examples:
183
+ input_ids, labels = stitch_turns_decoder_only(example_turns=example['turns'], target_prompts=self.target_prompts, target_responses=self.target_responses, eos_token_id=self.tokenizer.eos_token_id)
184
+ orig_size = len(input_ids)
185
+ if orig_size > self.max_seq_len:
186
+ input_ids = input_ids[:self.max_seq_len]
187
+ labels = labels[:self.max_seq_len]
188
+ if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0:
189
+ raise ValueError(f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' + f'Pre-truncation sequence length was {orig_size}. ' + 'This sample should have been filtered out before reaching the collator. If using ' + 'pre-tokenized streaming data, this may have resulted from using different ' + '``target_prompts``, ``target_responses``, or ``max_seq_len`` ' + 'settings when preparing the streaming dataset than what are currently being used.')
190
+ if not self._warned_truncated:
191
+ warnings.warn(f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' + f'If truncation is a problem, consider increasing max_seq_len.')
192
+ self._warned_truncated = True
193
+ attention_mask = [1] * len(input_ids)
194
+ n_total = len(input_ids)
195
+ i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
196
+ if self.tokenizer.padding_side == 'left':
197
+ labels = i_pad + labels
198
+ else:
199
+ labels = labels + i_pad
200
+ processed_example = {'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask}
201
+ processed_examples.append(processed_example)
202
+ batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
203
+ batch['sequence_id'] = batch['attention_mask'] - 1
204
+ if not (self._allow_pad_trimming and self._seen_first_batch):
205
+ self._seen_first_batch = True
206
+ return batch
207
+ self._seen_first_batch = True
208
+ multiple_of = 8
209
+ n_non_padding = batch['attention_mask'].sum(dim=1).max()
210
+ keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
211
+ for k, v in batch.items():
212
+ if len(v.shape) < 2:
213
+ continue
214
+ if self.tokenizer.padding_side == 'left':
215
+ batch[k] = v[:, -keep_tokens:].contiguous()
216
+ else:
217
+ batch[k] = v[:, :keep_tokens].contiguous()
218
+ return batch
219
+
220
+ def _process_and_batch_encoder_decoder(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
221
+ processed_examples = []
222
+ for example in examples:
223
+ context, target = stitch_turns_encoder_decoder(example_turns=example['turns'], eos_token_id=self.tokenizer.eos_token_id)
224
+ if len(target) < self.max_seq_len:
225
+ i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
226
+ target = target + i_pad
227
+ else:
228
+ if not self._warned_target:
229
+ warnings.warn(f'Truncating TARGET sequence of length={len(target)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
230
+ self._warned_target = True
231
+ target = target[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]
232
+ if len(context) > self.max_seq_len:
233
+ if not self._warned_context:
234
+ warnings.warn(f'Truncating CONTEXT sequence of length={len(context)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
235
+ self._warned_context = True
236
+ context = context[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]
237
+ processed_example = {'input_ids': context, 'labels': target, 'attention_mask': [1] * len(context)}
238
+ processed_examples.append(processed_example)
239
+ batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
240
+ batch['decoder_input_ids'] = torch.cat([torch.full((len(processed_examples), 1), self.tokenizer.pad_token_id), batch['labels'][:, :-1]], dim=1)
241
+ batch['decoder_input_ids'].masked_fill_(batch['decoder_input_ids'] == _HF_IGNORE_INDEX, self.tokenizer.pad_token_id)
242
+ batch['decoder_attention_mask'] = torch.not_equal(batch['labels'], _HF_IGNORE_INDEX)
243
+ if not (self._allow_pad_trimming and self._seen_first_batch):
244
+ self._seen_first_batch = True
245
+ return batch
246
+ self._seen_first_batch = True
247
+ multiple_of = 8
248
+ n_non_padding = batch['attention_mask'].sum(dim=1).max()
249
+ keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
250
+ for k in ['input_ids', 'attention_mask']:
251
+ batch[k] = batch[k][:, :keep_tokens].contiguous()
252
+ n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
253
+ keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
254
+ for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
255
+ batch[k] = batch[k][:, :keep_tokens].contiguous()
256
+ return batch
config_utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import logging
3
+ import math
4
+ import warnings
5
+ from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union
6
+ from .utils import init_empty_weights
7
+ log = logging.getLogger(__name__)
8
+
9
+ def pop_config(cfg: DictConfig, key: str, must_exist: bool=True, default_value: Any=None, convert: bool=False) -> Any:
10
+ """Pop a value from the main config file and return it.
11
+
12
+ If the key does not exist, return the default_value or raise a RuntimeError
13
+ depending on the must_exist flag. If the convert flag is set to True, then
14
+ we will convert the value to a python object using OmegaConf.to_container.
15
+ """
16
+ value = cfg.pop(key, None)
17
+ if value is not None and convert:
18
+ if not isinstance(value, DictConfig) and (not isinstance(value, ListConfig)):
19
+ raise ValueError(f'The key {key} has a value of type {type(value)} that cannot be converted to a dict or list. Please check your yaml.')
20
+ return om.to_container(value)
21
+ elif value is not None:
22
+ return value
23
+ elif must_exist:
24
+ raise NameError(f'The {key} parameter is missing and must exist for execution. Please check your yaml.')
25
+ else:
26
+ return default_value
27
+
28
+ def calculate_batch_size_info(global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']]) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]:
29
+ if global_batch_size % dist.get_world_size() != 0:
30
+ raise ValueError(f'Global batch size {global_batch_size} is not divisible by {dist.get_world_size()} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.')
31
+ device_batch_size = global_batch_size // dist.get_world_size()
32
+ if device_microbatch_size == 'auto':
33
+ device_grad_accum = 'auto'
34
+ elif isinstance(device_microbatch_size, int):
35
+ if device_microbatch_size > device_batch_size:
36
+ log.warn(f'device_microbatch_size > device_batch_size, ' + f'will be reduced from {device_microbatch_size} -> {device_batch_size}.')
37
+ device_microbatch_size = device_batch_size
38
+ device_grad_accum = math.ceil(device_batch_size / device_microbatch_size)
39
+ else:
40
+ raise ValueError(f'Not sure how to parse device_microbatch_size={device_microbatch_size!r}')
41
+ return (device_batch_size, device_microbatch_size, device_grad_accum)
42
+
43
+ def update_batch_size_info(cfg: DictConfig) -> DictConfig:
44
+ device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(cfg.global_train_batch_size, cfg.device_train_microbatch_size)
45
+ cfg.n_gpus = dist.get_world_size()
46
+ cfg.device_train_batch_size = device_train_batch_size
47
+ cfg.device_train_microbatch_size = device_train_microbatch_size
48
+ cfg.device_train_grad_accum = device_train_grad_accum
49
+ if 'device_eval_batch_size' not in cfg:
50
+ if cfg.device_train_microbatch_size == 'auto':
51
+ cfg.device_eval_batch_size = 1
52
+ else:
53
+ cfg.device_eval_batch_size = cfg.device_train_microbatch_size
54
+ return cfg
55
+
56
+ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
57
+ init_context = contextlib.nullcontext()
58
+ if 'init_device' in model_cfg:
59
+ assert model_cfg.init_device in ['meta', 'cpu', 'mixed']
60
+ if fsdp_config is None and model_cfg.init_device == 'meta':
61
+ warnings.warn("Using `cfg.model.init_device='meta'` is only valid when using FSDP! " + "Reverting to `cfg.model.init_device='cpu'`.")
62
+ model_cfg.init_device = 'cpu'
63
+ if model_cfg.init_device == 'meta':
64
+ init_context = init_empty_weights()
65
+ if model_cfg.init_device == 'mixed':
66
+ if fsdp_config is None:
67
+ raise NotImplementedError('Using init_device `mixed` is only supported with FSDP. ' + 'Please add a FSDP config.')
68
+ if not fsdp_config.get('sync_module_states', False):
69
+ warnings.warn('Setting `sync_module_states = True` for FSDP. This is required when using mixed initialization.')
70
+ fsdp_config['sync_module_states'] = True
71
+ fsdp_config.setdefault('use_orig_params', False)
72
+ fsdp_config.setdefault('load_monolith_rank0_only', True)
73
+ master_dtype = model_cfg.get('master_weights_dtype')
74
+ small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16')
75
+ if fsdp_config and master_dtype in small_dtypes:
76
+ reduce_dtype = None
77
+ buffer_dtype = None
78
+ mixed_precision = fsdp_config.get('mixed_precision')
79
+ if isinstance(mixed_precision, Mapping):
80
+ reduce_dtype = mixed_precision.get('reduce_dtype')
81
+ buffer_dtype = mixed_precision.get('buffer_dtype')
82
+ fsdp_config['mixed_precision'] = {'param_dtype': None, 'reduce_dtype': reduce_dtype, 'buffer_dtype': buffer_dtype, 'keep_low_precision_grads': True}
83
+ return init_context
84
+
85
+ def log_config(cfg: DictConfig) -> None:
86
+ """Logs the current config and updates the wandb and mlflow configs.
87
+
88
+ This function can be called multiple times to update the wandb and MLflow
89
+ config with different variables.
90
+ """
91
+ print(om.to_yaml(cfg))
92
+ if 'wandb' in cfg.get('loggers', {}):
93
+ try:
94
+ import wandb
95
+ except ImportError as e:
96
+ raise e
97
+ if wandb.run:
98
+ wandb.config.update(om.to_container(cfg, resolve=True))
99
+ if 'mlflow' in cfg.get('loggers', {}):
100
+ try:
101
+ import mlflow
102
+ except ImportError as e:
103
+ raise e
104
+ if mlflow.active_run():
105
+ mlflow.log_params(params=om.to_container(cfg, resolve=True))
configuration_mpt.py CHANGED
@@ -2,12 +2,11 @@
2
  import warnings
3
  from typing import Any, Dict, Optional, Union
4
  from transformers import PretrainedConfig
5
- from .attention import check_alibi_support, is_flash_v1_installed, is_flash_v2_installed
6
  from .blocks import attn_config_defaults
7
  from .fc import FC_CLASS_REGISTRY
8
  from .norm import LPLayerNorm
9
  from .ffn import FFN_CLASS_REGISTRY
10
- from .warnings import VersionedDeprecationWarning
11
  ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
12
  init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
13
 
@@ -30,16 +29,13 @@ class MPTConfig(PretrainedConfig):
30
  attn_config (Dict): A dictionary used to configure the model's attention module:
31
  attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
32
  attn_pdrop (float): The dropout probability for the attention layers.
33
- attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
34
  qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
35
  qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
36
  clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
37
  this value.
38
  softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
39
  use the default scale of ``1/sqrt(d_keys)``.
40
- prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
41
- extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
42
- can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
43
  attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
44
  When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
45
  which sub-sequence each token belongs to.
@@ -116,7 +112,7 @@ class MPTConfig(PretrainedConfig):
116
  self._validate_config()
117
 
118
  def _set_config_defaults(self, config: Dict[str, Any], config_defaults: Dict[str, Any]) -> Dict[str, Any]:
119
- for (k, v) in config_defaults.items():
120
  if k not in config:
121
  config[k] = v
122
  elif isinstance(v, dict):
@@ -131,18 +127,12 @@ class MPTConfig(PretrainedConfig):
131
  raise ValueError('d_model must be divisible by n_heads')
132
  if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
133
  raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
134
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
135
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
136
- if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
137
- raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
138
- if self.attn_config['attn_impl'] == 'flash' and is_flash_v1_installed():
139
- warnings.warn(VersionedDeprecationWarning('Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.', remove_version='0.6.0'))
140
- if self.attn_config['attn_impl'] == 'triton' and (not self.attn_config['prefix_lm']):
141
- warnings.warn(UserWarning('If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton".'))
142
  if self.attn_config['alibi'] and (not check_alibi_support(self.attn_config['attn_impl'])):
143
- raise NotImplementedError('alibi only implemented with torch, triton, and flash (v2.4.2 or higher) attention.')
144
- if self.attn_config['attn_uses_sequence_id'] and (not (self.attn_config['attn_impl'] in ['torch', 'triton'] or (self.attn_config['attn_impl'] == 'flash' and is_flash_v2_installed(v2_version='v2.1.2')))):
145
- raise NotImplementedError('attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.')
146
  if self.attn_config['rope'] and self.attn_config['rope_impl'] not in ['dail', 'hf']:
147
  raise ValueError('If rope is being used then rope_impl should be either "dail", or "hf".')
148
  if self.attn_config['rope'] and self.attn_config['rope_impl'] == 'hf' and (self.attn_config['rope_hf_config']['type'] not in ['no_scaling', 'linear', 'dynamic']):
 
2
  import warnings
3
  from typing import Any, Dict, Optional, Union
4
  from transformers import PretrainedConfig
5
+ from .attention import check_alibi_support, is_flash_v2_installed
6
  from .blocks import attn_config_defaults
7
  from .fc import FC_CLASS_REGISTRY
8
  from .norm import LPLayerNorm
9
  from .ffn import FFN_CLASS_REGISTRY
 
10
  ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
11
  init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
12
 
 
29
  attn_config (Dict): A dictionary used to configure the model's attention module:
30
  attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
31
  attn_pdrop (float): The dropout probability for the attention layers.
32
+ attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
33
  qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
34
  qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
35
  clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
36
  this value.
37
  softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
38
  use the default scale of ``1/sqrt(d_keys)``.
 
 
 
39
  attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
40
  When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
41
  which sub-sequence each token belongs to.
 
112
  self._validate_config()
113
 
114
  def _set_config_defaults(self, config: Dict[str, Any], config_defaults: Dict[str, Any]) -> Dict[str, Any]:
115
+ for k, v in config_defaults.items():
116
  if k not in config:
117
  config[k] = v
118
  elif isinstance(v, dict):
 
127
  raise ValueError('d_model must be divisible by n_heads')
128
  if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
129
  raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
130
+ if self.attn_config['attn_impl'] not in ['torch', 'flash']:
131
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
 
 
 
 
 
 
132
  if self.attn_config['alibi'] and (not check_alibi_support(self.attn_config['attn_impl'])):
133
+ raise NotImplementedError('alibi only implemented with torch and flash (v2.4.2 or higher) attention.')
134
+ if self.attn_config['attn_uses_sequence_id'] and (not (self.attn_config['attn_impl'] == 'torch' or (self.attn_config['attn_impl'] == 'flash' and is_flash_v2_installed(v2_version='v2.1.2')))):
135
+ raise NotImplementedError('attn_uses_sequence_id only implemented with torch and flash (v2.1.2 or higher) attention.')
136
  if self.attn_config['rope'] and self.attn_config['rope_impl'] not in ['dail', 'hf']:
137
  raise ValueError('If rope is being used then rope_impl should be either "dail", or "hf".')
138
  if self.attn_config['rope'] and self.attn_config['rope_impl'] == 'hf' and (self.attn_config['rope_hf_config']['type'] not in ['no_scaling', 'linear', 'dynamic']):
curriculum_learning_callback.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Enable curriculum learning by resuming with a different dataset.
2
+
3
+ This callback is currently experimental. The API may change without warning in
4
+ the future.
5
+ """
6
+ import logging
7
+ from typing import Any, Dict
8
+ from streaming import StreamingDataset
9
+ from torch.utils.data import DataLoader
10
+ from .interfaces import CallbackWithConfig
11
+ from .warnings import experimental_class
12
+ log = logging.getLogger(__name__)
13
+
14
+ @experimental_class('CurriculumLearning callback')
15
+ class CurriculumLearning(CallbackWithConfig):
16
+ """Starts an epoch with a different dataset when resuming from a checkpoint.
17
+
18
+ Args:
19
+ dataset_index (int): The index of the dataset currently being used.
20
+ current_dataset_config (Dict): The configuration of the dataset currently
21
+ being used.
22
+ """
23
+
24
+ def __init__(self, dataset_index: int, train_config: Dict):
25
+ self.dataset_index = dataset_index
26
+ self.saved_dataset_index = 0
27
+ self.all_dataset_configs = []
28
+ self.current_dataset_state = {}
29
+ self.current_dataset_config = train_config['dataloader']
30
+
31
+ def before_load(self, state: State, logger: Logger):
32
+ del logger
33
+ train_loader = state.train_dataloader
34
+ if not isinstance(train_loader, DataLoader):
35
+ raise ValueError(f'CurriculumLearning callback can only be used with a train ', f'dataloader of type DataLoader, but got {type(train_loader)}.')
36
+ dataset = train_loader.dataset
37
+ if not isinstance(dataset, StreamingDataset):
38
+ raise ValueError(f'CurriculumLearning callback only supports StreamingDataset ', f'because it requires loading and saving dataset state. ', f'Instead, got a dataset of type {type(dataset)}')
39
+ assert isinstance(dataset, StreamingDataset)
40
+ self.current_dataset_state = dataset.state_dict(num_samples=0, from_beginning=False)
41
+
42
+ def after_load(self, state: State, logger: Logger):
43
+ del logger
44
+ train_loader = state._train_dataloader
45
+ assert isinstance(train_loader, DataLoader), 'CurriculumLearning callback requires a DataLoader.'
46
+ dataset = train_loader.dataset
47
+ assert isinstance(dataset, StreamingDataset), 'CurriculumLearning callback requires a StreamingDataset.'
48
+ if self.saved_dataset_index < self.dataset_index:
49
+ if self.current_dataset_state['epoch'] < 0:
50
+ self.current_dataset_state['epoch'] = 0
51
+ dataset.load_state_dict(self.current_dataset_state)
52
+ state.timestamp = state.timestamp.to_next_epoch()
53
+ self.all_dataset_configs.append(self.current_dataset_config)
54
+ elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0:
55
+ self.all_dataset_configs.append(self.current_dataset_config)
56
+
57
+ def state_dict(self):
58
+ return {'dataset_index': self.dataset_index, 'all_dataset_configs': self.all_dataset_configs}
59
+
60
+ def load_state_dict(self, state: Dict[str, Any]):
61
+ self.saved_dataset_index = state.get('dataset_index', 0)
62
+ self.all_dataset_configs = state.get('all_dataset_configs', [])
data.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Datasets for converting to MDS Shards."""
2
+ import os
3
+ import warnings
4
+ from typing import Dict, Iterable, Union
5
+ import datasets as hf_datasets
6
+ import numpy as np
7
+ from torch.utils.data import IterableDataset
8
+ from transformers import PreTrainedTokenizerBase
9
+
10
+ class NoConcatDataset(IterableDataset):
11
+ """An IterableDataset that returns text samples for MDSWriter.
12
+
13
+ Returns dicts of {'text': bytes}
14
+ """
15
+
16
+ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset]):
17
+ self.hf_dataset = hf_dataset
18
+
19
+ def __iter__(self) -> Iterable[Dict[str, bytes]]:
20
+ for sample in self.hf_dataset:
21
+ yield {'text': sample['text'].encode('utf-8')}
22
+
23
+ class ConcatTokensDataset(IterableDataset):
24
+ """An IterableDataset that returns token samples for MDSWriter.
25
+
26
+ Returns dicts of {'tokens': bytes}
27
+
28
+ To use data created by this class and written to MDS format:
29
+
30
+ ```python
31
+ import torch
32
+ from streaming.base import StreamingDataset
33
+ from transformers import AutoTokenizer
34
+
35
+ tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
36
+ ds = StreamingDataset(local='mds-data-folder', split='val')
37
+
38
+ # note, you need to copy the numpy array because the original is non-writeable
39
+ # and torch does not support non-writeable tensors, so you get a scary warning and
40
+ # if you do try to write to the tensor you get undefined behavior
41
+ tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
42
+ print(tokenizer.decode(tokens))
43
+ ```
44
+ """
45
+
46
+ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], tokenizer: PreTrainedTokenizerBase, max_length: int, bos_text: str, eos_text: str, no_wrap: bool):
47
+ self.hf_dataset = hf_dataset
48
+ self.tokenizer = tokenizer
49
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
50
+ self.max_length = max_length
51
+ self.bos_text = bos_text
52
+ self.eos_text = eos_text
53
+ self.should_wrap = not no_wrap
54
+ self.bos_tokens = self.tokenizer(self.bos_text, truncation=False, padding=False, add_special_tokens=False)['input_ids']
55
+ if len(self.bos_tokens) > 1:
56
+ warnings.warn(f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token , instead we got {self.bos_tokens}. Quit if this was in error.')
57
+ self.eos_tokens = self.tokenizer(self.eos_text, truncation=False, padding=False, add_special_tokens=False)['input_ids']
58
+ if len(self.eos_tokens) > 1:
59
+ warnings.warn(f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token , instead we got {self.eos_tokens}. Quit if this was in error.')
60
+ eos_text_provided = self.eos_text != ''
61
+ bos_text_provided = self.bos_text != ''
62
+ test_text = self.tokenizer('')
63
+ if len(test_text['input_ids']) > 0 and (eos_text_provided or bos_text_provided):
64
+ message = 'both eos and bos' if eos_text_provided and bos_text_provided else 'eos_text' if eos_text_provided else 'bos_text'
65
+ warnings.warn(f'The provided tokenizer adds special tokens, but you also specified {message}. This may result ' + 'in duplicated special tokens. Please be sure this is what you intend.')
66
+
67
+ def __iter__(self) -> Iterable[Dict[str, bytes]]:
68
+ buffer = []
69
+ for sample in self.hf_dataset:
70
+ encoded = self.tokenizer(sample['text'], truncation=False, padding=False)
71
+ iids = encoded['input_ids']
72
+ buffer = buffer + self.bos_tokens + iids + self.eos_tokens
73
+ while len(buffer) >= self.max_length:
74
+ concat_sample = buffer[:self.max_length]
75
+ buffer = buffer[self.max_length:] if self.should_wrap else []
76
+ yield {'tokens': np.asarray(concat_sample).tobytes()}
data_prep_utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from glob import glob
4
+ from typing import List, Optional
5
+
6
+ def with_id(basename: str, shard_id: int) -> str:
7
+ """Get a new basename with the given shard_id.
8
+
9
+ From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset_conversion.ipynb.
10
+
11
+ Args:
12
+ basename (str): Old basename of file.
13
+ shard_id (int): New shard ID.
14
+
15
+ Returns:
16
+ str: New basename of file.
17
+ """
18
+ parts = basename.split('.')
19
+ parts[1] = f'{shard_id:05}'
20
+ return '.'.join(parts)
21
+
22
+ def merge_shard_groups(root: str) -> None:
23
+ """Merge ephemeral sub-datasets created in parallel into one dataset.
24
+
25
+ From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset
26
+ _conversion.ipynb.
27
+
28
+ Args:
29
+ root (str): Root directory.
30
+ """
31
+ pattern = os.path.join(root, '*')
32
+ subdirs = sorted(glob(pattern))
33
+ shard_id = 0
34
+ infos = []
35
+ for subdir in subdirs:
36
+ index_filename = os.path.join(subdir, 'index.json')
37
+ with open(index_filename) as index_file:
38
+ obj = json.load(index_file)
39
+ for info in obj['shards']:
40
+ old_basename = info['raw_data']['basename']
41
+ new_basename = with_id(old_basename, shard_id)
42
+ info['raw_data']['basename'] = new_basename
43
+ if info['zip_data'] is not None:
44
+ old_basename = info['zip_data']['basename']
45
+ new_basename = with_id(old_basename, shard_id)
46
+ info['zip_data']['basename'] = new_basename
47
+ old_filename = os.path.join(subdir, old_basename)
48
+ new_filename = os.path.join(root, new_basename)
49
+ os.rename(old_filename, new_filename)
50
+ shard_id += 1
51
+ infos.append(info)
52
+ os.remove(index_filename)
53
+ os.rmdir(subdir)
54
+ index_filename = os.path.join(root, 'index.json')
55
+ obj = {'version': 2, 'shards': infos}
56
+ text = json.dumps(obj, sort_keys=True)
57
+ with open(index_filename, 'w') as out:
58
+ out.write(text)
59
+
60
+ class DownloadingIterable:
61
+
62
+ def __init__(self, object_names: List[str], output_folder: str, object_store: Optional[ObjectStore]):
63
+ """Iterable that downloads files from an object store before yielding.
64
+
65
+ If object_store is None, input_folder_prefix is treated as a local path.
66
+
67
+ Args:
68
+ object_names (List[str]): Names of objects to download
69
+ output_folder (str): Local folder to write downloaded files to
70
+ object_store (Optiona[ObjectStore]): Object store to download from
71
+ """
72
+ self.object_names = object_names
73
+ self.object_store = object_store
74
+ self.output_folder = output_folder
75
+
76
+ def __iter__(self):
77
+ for object_name in self.object_names:
78
+ output_filename = object_name
79
+ if self.object_store is not None:
80
+ output_filename = os.path.join(self.output_folder, object_name.strip('/'))
81
+ self.object_store.download_object(object_name=object_name, filename=output_filename, overwrite=True)
82
+ with open(output_filename) as _txt_file:
83
+ txt = _txt_file.read()
84
+ yield {'text': txt}
dataloader.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Tuple, Union
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from transformers import PreTrainedTokenizerBase
7
+ from .collator import Seq2SeqFinetuningCollator, validate_target_settings
8
+ from .tasks import DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, dataset_constructor
9
+ from .packing import BinPackCollator, auto_packing_ratio
10
+ from .text_data import build_streams, get_tokens_per_batch_func
11
+ from .exceptions import MissingHuggingFaceURLSplitError, NotEnoughDatasetSamplesError
12
+ log = logging.getLogger(__name__)
13
+ _HF_IGNORE_INDEX = -100
14
+ _DEFAULT_TARGET_RESPONSES = 'last'
15
+ _DEFAULT_TARGET_PROMPTS = 'none'
16
+
17
+ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataSpec:
18
+ """Builds a finetuning dataloader for training or evaluating.
19
+
20
+ The underlying dataset can be built through one of two code paths:
21
+ 1. As a HuggingFace dataset, via `datasets.load_dataset(...)`
22
+ 2. As a streaming dataset
23
+ You will need to set slightly different dataset config fields depending
24
+ on which you intend to use, as explained below.
25
+
26
+ Args:
27
+ cfg (DictConfig): An omegaconf dictionary used to configure the loader:
28
+ cfg.name (str): The type of dataloader to build. Must = "finetuning".
29
+ ---
30
+ *** HuggingFace dataset config fields ***
31
+ cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset
32
+ to use. Can also be a remote http(s) directory or object store bucket
33
+ containing the file {split}.jsonl in the format (prompt, response),
34
+ in which case the builder will create a HuggingFace dataset.
35
+ cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to
36
+ pass to `datasets.load_dataset`, which can be used to load
37
+ a dataset from local files.
38
+ cfg.dataset.preprocessing_fn (str, optional): The name/import path of
39
+ the preprocessing function to use for formatting the data examples.
40
+ If ``None`` (default), the builder will use the preprocessing function
41
+ registered under `hf_name` (see `tasks.py`), if one exists,
42
+ otherwise it will skip preprocessing.
43
+ If `preprocessing_fn` corresponds to a registered preprocessing
44
+ function in `tasks.py`, the builder will use that.
45
+ Otherwise, it will interpret `preprocessing_fn` as a
46
+ "import.path:function_name" import path; e.g., it will call
47
+ `from import.path import function_name` and use the imported
48
+ function as the preprocessing function.
49
+ *** Streaming dataset config fields ***
50
+ cfg.dataset.remote (str, optional): Location of a MDS-formatted
51
+ streaming dataset to use. Setting this will tell the builder
52
+ to create a streaming dataset rather than a HuggingFace dataset.
53
+ cfg.dataset.local (str, optional): Local path where remote data
54
+ will be streamed to. Only valid if `cfg.dataset.remote` has
55
+ also been set.
56
+ *** Shared dataset configs fields ***
57
+ cfg.dataset.max_seq_len (int): The maximum length of sequences
58
+ in the batch. See :class:`Seq2SeqFinetuningCollator` docstring
59
+ for details.
60
+ cfg.dataset.decoder_only_format (bool): Whether to format the
61
+ examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator`
62
+ docstring for details.
63
+ cfg.dataset.target_responses (str): Which responses are used as training targets.
64
+ Defaults to "last", meaning only the final response in multi-turn examples
65
+ will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for
66
+ details.
67
+ cfg.dataset.target_prompts (str): Which prompts are used as training targets.
68
+ Defaults to "none", meaning prompts are never used as training targets.
69
+ See :class:`Seq2SeqFinetuningCollator` docstring for details.
70
+ cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow
71
+ the collator to trim padding. See :class:`Seq2SeqFinetuningCollator`
72
+ docstring for details. Default: ``False``.
73
+ cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
74
+ a collator wrapper that packs device_batch_size*packing_ratio
75
+ raw examples into device_batch_size packed examples. This helps
76
+ minimize padding while preserving sequence integrity.
77
+ This adds `sequence_id` to the batch, which indicates which unique
78
+ sequence each token belongs to.
79
+
80
+ If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
81
+ zero waste is selected.
82
+ In practice, this may result in > 0 waste because profiling is done on only a portion
83
+ of the dataset.
84
+
85
+ Note: Using this feature will not change device_batch_size but it
86
+ will determine the number of raw examples consumed by the dataloader
87
+ per batch. Some examples may be discarded if they do not fit when
88
+ packing.
89
+ Select packing_ratio **carefully** based on the dataset
90
+ statistics, max_seq_len, and tolerance for discarding samples!
91
+ The script `scripts/misc/profile_packing.py` can help
92
+ you choose the best packing_ratio.
93
+ cfg.dataset.shuffle (bool): Whether to shuffle the dataset.
94
+ ___
95
+ See :class:`StreamingFinetuningDataset` for info on other standard config
96
+ options within `cfg.dataset` that will be passed as kwargs if
97
+ using the streaming codepath.
98
+ ---
99
+ See :class:`DataLoader` for standard argument options to the pytorch
100
+ dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
101
+ tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
102
+ prepare the data from raw text. Any missing sentinel tokens will
103
+ be added by the collator.
104
+ device_batch_size (int): The size of the batches (number of examples)
105
+ that the dataloader will produce.
106
+
107
+ Returns:
108
+ A pytorch dataloader
109
+
110
+ Note:
111
+ You can run the script inside `scripts/misc/profile_packing.py` to quickly test the
112
+ padding/waste rates for different `cfg.dataset.packing_ratio` choices,
113
+ given a starting workload YAML.
114
+ """
115
+ _validate_config(cfg.dataset)
116
+ if tokenizer.pad_token is None:
117
+ tokenizer.pad_token = tokenizer.eos_token
118
+ collate_fn, dataloader_batch_size = _build_collate_fn(cfg, tokenizer, device_batch_size)
119
+ dataset = None
120
+ sampler = None
121
+ if cfg.dataset.get('remote') is not None or cfg.dataset.get('streams') is not None:
122
+ streams = build_streams(cfg.dataset)
123
+ dataset = dataset_constructor.build_from_streaming(tokenizer=tokenizer, streams=streams, local=cfg.dataset.get('local', None), remote=cfg.dataset.get('remote', None), split=cfg.dataset.get('split', None), download_retry=cfg.dataset.get('download_retry', 2), download_timeout=cfg.dataset.get('download_timeout', 60), validate_hash=cfg.dataset.get('validate_hash', None), keep_zip=cfg.dataset.get('keep_zip', False), epoch_size=cfg.dataset.get('epoch_size', None), predownload=cfg.dataset.get('predownload', None), cache_limit=cfg.dataset.get('cache_limit', None), partition_algo=cfg.dataset.get('partition_algo', 'relaxed'), num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None), batch_size=device_batch_size, shuffle=cfg.dataset.get('shuffle', False), shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'), shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), shuffle_block_size=cfg.dataset.get('shuffle_block_size', None), sampling_method=cfg.dataset.get('sampling_method', 'balanced'), sampling_granularity=cfg.dataset.get('sampling_granularity', 1), batching_method=cfg.dataset.get('batching_method', 'random'), max_seq_len=cfg.dataset.max_seq_len)
124
+ else:
125
+ dataset_name_or_path = cfg.dataset.hf_name
126
+ split = cfg.dataset.get('split')
127
+ if split is None:
128
+ raise MissingHuggingFaceURLSplitError()
129
+ backend, _, _ = parse_uri(dataset_name_or_path)
130
+ if backend not in ['', None]:
131
+ dataset_name_or_path = _download_remote_hf_dataset(remote_path=dataset_name_or_path, split=split)
132
+ split = split.replace('-', '_')
133
+ proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn')
134
+ if isinstance(proto_preprocessing_fn, (dict, DictConfig)):
135
+ preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(dict(proto_preprocessing_fn))
136
+ else:
137
+ preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(proto_preprocessing_fn, dataset_name_or_path)
138
+ dataset = dataset_constructor.build_from_hf(dataset_name=dataset_name_or_path, split=split, safe_load=cfg.dataset.get('safe_load', False), max_seq_len=cfg.dataset.max_seq_len, preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, target_prompts=cfg.dataset.get('target_prompts', _DEFAULT_TARGET_PROMPTS), target_responses=cfg.dataset.get('target_responses', _DEFAULT_TARGET_RESPONSES), decoder_only_format=cfg.dataset.decoder_only_format, hf_kwargs=cfg.dataset.get('hf_kwargs', {}))
139
+ if cfg.drop_last:
140
+ world_size = dist.get_world_size()
141
+ minimum_dataset_size = world_size * dataloader_batch_size
142
+ if hasattr(dataset, '__len__'):
143
+ full_dataset_size = len(dataset)
144
+ if full_dataset_size < minimum_dataset_size:
145
+ raise NotEnoughDatasetSamplesError(dataset_name=cfg.dataset.hf_name, split=split, dataloader_batch_size=dataloader_batch_size, world_size=world_size, full_dataset_size=full_dataset_size, minimum_dataset_size=minimum_dataset_size)
146
+ sampler = dist.get_sampler(dataset, drop_last=cfg.drop_last, shuffle=cfg.dataset.shuffle)
147
+ assert dataset is not None
148
+ dl = DataLoader(dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, drop_last=cfg.drop_last, sampler=sampler, num_workers=cfg.num_workers, pin_memory=cfg.get('pin_memory', True), prefetch_factor=cfg.get('prefetch_factor', 2), persistent_workers=cfg.get('persistent_workers', True), timeout=cfg.get('timeout', 0))
149
+ token_counting_func = get_tokens_per_batch_func()
150
+ return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
151
+
152
+ def _validate_config(dataset_cfg: DictConfig) -> None:
153
+ """Validates the dataset configuration.
154
+
155
+ Makes sure that the dataset is properly configured for either
156
+ a HuggingFace dataset or a streaming dataset. Must be valid for one or
157
+ the other.
158
+
159
+ Args:
160
+ dataset_cfg (DictConfig): The dataset configuration to be validated.
161
+
162
+ Raises:
163
+ ValueError: If the dataset configuration does not meet the requirements.
164
+ """
165
+ if dataset_cfg.get('hf_name') is not None:
166
+ illegal_keys = ['local', 'remote']
167
+ discovered_illegal_keys = []
168
+ for key in illegal_keys:
169
+ if dataset_cfg.get(key) is not None:
170
+ discovered_illegal_keys.append('`' + key + '`')
171
+ if discovered_illegal_keys:
172
+ raise ValueError('The dataset config sets a value for `hf_name` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Those keys are used when building from a streaming dataset, but ' + 'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.')
173
+ elif dataset_cfg.get('remote') is not None:
174
+ illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
175
+ discovered_illegal_keys = []
176
+ for key in illegal_keys:
177
+ if dataset_cfg.get(key) is not None:
178
+ discovered_illegal_keys.append('`' + key + '`')
179
+ if discovered_illegal_keys:
180
+ raise ValueError('The dataset config sets a value for `remote` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Those keys are used when building from a HuggingFace dataset, but ' + 'setting `remote` instructs the dataset to build from a streaming dataset.')
181
+ if dataset_cfg.get('local') is None:
182
+ raise ValueError('Using a streaming dataset requires setting both `remote` and `local`, ' + 'but dataset.local is None.')
183
+ elif dataset_cfg.get('streams') is not None:
184
+ illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
185
+ discovered_illegal_keys = []
186
+ for key in illegal_keys:
187
+ if dataset_cfg.get(key) is not None:
188
+ discovered_illegal_keys.append('`' + key + '`')
189
+ if discovered_illegal_keys:
190
+ raise ValueError('The dataset config sets a value for `streams` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Those keys are used when building from a HuggingFace dataset, but ' + 'setting `streams` instructs the dataset to build from a streaming dataset.')
191
+ illegal_keys = ['remote', 'local']
192
+ discovered_illegal_keys = []
193
+ for key in illegal_keys:
194
+ if dataset_cfg.get(key) is not None:
195
+ discovered_illegal_keys.append('`' + key + '`')
196
+ if discovered_illegal_keys:
197
+ raise ValueError('The dataset config sets a value for `streams` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Please either use single stream (set remote/local only) ' + 'or put remote/local under streams')
198
+ else:
199
+ raise ValueError('In the dataset config, you must set `hf_name` to use a HuggingFace ' + 'dataset, or set `remote` to use a streaming dataset, or set ' + '`streams` to use multiple streaming datasets, but all were None.')
200
+ if dataset_cfg.get('max_seq_len') is None:
201
+ raise ValueError('In the dataset config, you must set the `max_seq_len`')
202
+ target_responses = str(dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES)).lower()
203
+ target_prompts = str(dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS)).lower()
204
+ decoder_only_format = dataset_cfg.decoder_only_format
205
+ validate_target_settings(target_prompts, target_responses, decoder_only_format)
206
+
207
+ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
208
+ """Downloads a dataset from a remote object store.
209
+
210
+ This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
211
+ the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
212
+ dataset.
213
+
214
+ The function also ensures synchronicity across multiple processes during the file download. It creates a signal
215
+ file that is used to synchronize the start of the download across different processes. Once the download is
216
+ completed, the function removes the signal file.
217
+
218
+ Args:
219
+ hf_name (str): The path of the HuggingFace dataset to download.
220
+ split (str): The dataset split to download (e.g., 'train', 'validation', 'test').
221
+
222
+ Returns:
223
+ A local directory path where the dataset files are stored.
224
+
225
+ Raises:
226
+ FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
227
+ """
228
+ hf_formatted_split = split.replace('-', '_')
229
+ finetune_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_formatted_split if hf_formatted_split != 'data' else 'data_not')
230
+ os.makedirs(finetune_dir, exist_ok=True)
231
+ for extension in SUPPORTED_EXTENSIONS:
232
+ name = f"{remote_path.strip('/')}/{split}{extension}"
233
+ destination = str(os.path.abspath(os.path.join(finetune_dir, 'data', f'{hf_formatted_split}-00000-of-00001{extension}')))
234
+ signal_file_path = os.path.join(finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
235
+ if dist.get_local_rank() == 0:
236
+ try:
237
+ get_file(path=name, destination=destination, overwrite=True)
238
+ except FileNotFoundError as e:
239
+ if extension == SUPPORTED_EXTENSIONS[-1]:
240
+ files_searched = [f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}' for ext in SUPPORTED_EXTENSIONS]
241
+ raise FileNotFoundError(f'Could not find a file with any of ' + f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + f'at {files_searched}') from e
242
+ else:
243
+ log.debug(f'Could not find {name}, looking for another extension')
244
+ continue
245
+ os.makedirs(os.path.dirname(signal_file_path), exist_ok=True)
246
+ with open(signal_file_path, 'wb') as f:
247
+ f.write(b'local_rank0_completed_download')
248
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
249
+ dist.barrier()
250
+ if dist.get_local_rank() == 0:
251
+ os.remove(signal_file_path)
252
+ dist.barrier()
253
+ break
254
+ return finetune_dir
255
+
256
+ def _build_collate_fn(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
257
+ dataset_cfg = dataloader_cfg.dataset
258
+ max_seq_len = dataset_cfg.max_seq_len
259
+ collate_fn = Seq2SeqFinetuningCollator(tokenizer=tokenizer, max_seq_len=max_seq_len, decoder_only_format=dataset_cfg.decoder_only_format, target_responses=dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES), target_prompts=dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS), allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False))
260
+ packing_ratio = dataset_cfg.get('packing_ratio')
261
+ max_leftover_bins_to_keep = dataset_cfg.get('max_leftover_bins_to_keep')
262
+ if packing_ratio is None:
263
+ if max_leftover_bins_to_keep is not None:
264
+ raise ValueError('dataset.max_leftover_bins_to_keep has been defined, ' + 'but dataset.packing_ratio has not been set. Please set ' + 'the latter to turn on packing or remove the former from the config.')
265
+ return (collate_fn, device_batch_size)
266
+ if packing_ratio == 'auto':
267
+ packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer, device_batch_size)
268
+ if isinstance(packing_ratio, str):
269
+ raise ValueError('dataset.packing_ratio must be a float or "auto", but it was set to ' + f'{packing_ratio}.')
270
+ log.info(f'Using packing ratio {packing_ratio}')
271
+ if packing_ratio == 1.0:
272
+ return (collate_fn, device_batch_size)
273
+ elif packing_ratio < 1.0:
274
+ raise ValueError('packing_ratio must be >= 1, if supplied')
275
+ if not dataset_cfg.decoder_only_format:
276
+ raise NotImplementedError('On-the-fly packing is currently only supported for decoder-only formats.')
277
+ collate_fn = BinPackCollator(collator=collate_fn, target_batch_size=device_batch_size, max_seq_len=max_seq_len, pad_token_id=tokenizer.pad_token_id, padding_side=tokenizer.padding_side, max_leftover_bins_to_keep=max_leftover_bins_to_keep)
278
+ n_examples_to_pack = int(device_batch_size * packing_ratio)
279
+ return (collate_fn, n_examples_to_pack)
280
+ if __name__ == '__main__':
281
+ import torch
282
+ from .utils import build_tokenizer
283
+ cfg = om.create({'dataset': {'hf_name': 'tatsu-lab/alpaca', 'preprocessing_fn': 'llmfoundry.data.finetuning.tasks:alpaca_preprocessing_function', 'split': 'train', 'packing_ratio': 18.0, 'max_seq_len': 2048, 'decoder_only_format': True, 'allow_pad_trimming': False, 'num_canonical_nodes': 472, 'shuffle': True, 'target_responses': 'last', 'target_prompts': 'none'}, 'drop_last': False, 'num_workers': 0, 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, 'timeout': 0})
284
+ tokenizer_name = 'EleutherAI/gpt-neox-20b'
285
+ tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
286
+ tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
287
+ device_batch_size = 1
288
+ dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
289
+ packing = cfg.dataset.get('packing_ratio') is not None
290
+ for i, batch in enumerate(dataloader):
291
+ if i >= 5:
292
+ break
293
+ print(f'-----Batch {i}-----')
294
+ for k, v in batch.items():
295
+ if isinstance(v, torch.Tensor):
296
+ print(k, v.shape)
297
+ else:
298
+ print(k, v)
299
+ for j in range(device_batch_size):
300
+ print(f'--- Sample {j} ---')
301
+ if cfg.dataset.decoder_only_format:
302
+ if packing:
303
+ for subseq in range(int(batch['sequence_id'][j].max()) + 1):
304
+ is_subseq = batch['sequence_id'][j] == subseq
305
+ print('\x1b[93m{}\x1b[00m\n'.format('INPUT IDS:'), tokenizer.decode(batch['input_ids'][j, torch.logical_and(is_subseq, batch['attention_mask'][j] == 1)], skip_special_tokens=False, clean_up_tokenization_spaces=True))
306
+ print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['input_ids'][j, torch.logical_and(is_subseq, batch['labels'][j] != _HF_IGNORE_INDEX)], skip_special_tokens=False, clean_up_tokenization_spaces=True))
307
+ else:
308
+ print('\x1b[93m{}\x1b[00m\n'.format('INPUT IDS:'), tokenizer.decode(batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True))
309
+ print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['input_ids'][j, batch['labels'][j] != _HF_IGNORE_INDEX], skip_special_tokens=False, clean_up_tokenization_spaces=True))
310
+ else:
311
+ print('\x1b[92m{}\x1b[00m\n'.format('CONTEXT: '), tokenizer.decode(batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True))
312
+ print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['labels'][j, batch['decoder_attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True))
313
+ print(' ')
eval_gauntlet_callback.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Aggregate ICL evals into composite scores."""
2
+ import logging
3
+ import math
4
+ from enum import Enum
5
+ from typing import Dict, Optional
6
+ log = logging.getLogger(__name__)
7
+
8
+ class Weighting(Enum):
9
+ EQUAL = 1
10
+ SAMPLE_SZ = 2
11
+ LOG_SAMPLE_SZ = 3
12
+
13
+ def calculate_named_averages(average_names: Dict[str, list], category_scores: Dict[str, float]):
14
+ """Calculates the named averages based off the raw category scores.
15
+
16
+ For each named average, take a simple average of all the category scores associated with that named average.
17
+
18
+ Args:
19
+ average_names (dict[str, list]): Contains a mapping of named averages to which category scores that average should consist of.
20
+ category_scores (dict[str, float]): Contains the raw scores corresponding to each category.
21
+ """
22
+ average_scores = {}
23
+ for avg_name, category_list in average_names.items():
24
+ composite_subset = {category: score for category, score in category_scores.items() if category in category_list}
25
+ if len(composite_subset.values()) > 0:
26
+ average_scores[avg_name] = sum(composite_subset.values()) / len(composite_subset.values())
27
+ else:
28
+ average_scores[avg_name] = 0
29
+ return average_scores
30
+
31
+ class EvalGauntlet(Callback):
32
+ """The EvalGauntlet aggregates ICL eval results.
33
+
34
+ After `eval_end`, this callback inspects the logger for different ICL metrics and aggregates the scores according to the aggregation
35
+ specification provided in the constructor.
36
+
37
+ Args:
38
+ logger_keys (list): These are the exact keys that the individual benchmark metrics will be
39
+ logged under in the logger after eval
40
+ categories (dict): This contains the list of categories, as well as the subtasks within them, the
41
+ random baseline accuracy of each subtask, and the number of fewshot examples
42
+ used for the task. See `llmfoundry/scripts/eval/yamls/eval_gauntlet_v0.2.yaml` to see the structure.
43
+ weighting (Weighting): The weighting scheme used to balance different tasks within each category.
44
+ Either assign them all equal weight, assign them weight proportional
45
+ to the dataset size, or assign them weight proportional to the log2 of the dataset size.
46
+ Options are 'EQUAL', 'SAMPLE_SZ', and 'LOG_SAMPLE_SZ'.
47
+ subtract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy
48
+ from the performance on each individual benchmark before aggregating.
49
+ rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark
50
+ by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0.
51
+ benchmark_sizes (Optional[dict]): Optional data on benchmark sizes, used when not relying on equal weighting.
52
+ averages (Optional[dict]): Optional dictionary specifying a mapping from a average names to lists of categories used produce each named average.
53
+ """
54
+
55
+ def __init__(self, logger_keys: list, categories: dict, weighting: str='EQUAL', subtract_random_baseline: bool=True, rescale_accuracy: bool=True, benchmark_sizes: Optional[dict]=None, averages: Optional[dict]=None):
56
+ if isinstance(logger_keys, dict):
57
+ raise ValueError('logger_keys now requires a list type as input, not a dict')
58
+ if weighting != Weighting.EQUAL and benchmark_sizes is None:
59
+ raise Exception('When not using equal weighting, you must provide the benchmark sizes.')
60
+ if rescale_accuracy and (not subtract_random_baseline):
61
+ raise Exception('Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.')
62
+ self.categories = categories
63
+ self.category_names = [conf.get('name') for conf in self.categories]
64
+ self.weighting = Weighting[weighting]
65
+ self.subtract_random_baseline = subtract_random_baseline
66
+ self.rescale_accuracy = rescale_accuracy
67
+ self.logger_keys = logger_keys
68
+ for category in self.categories:
69
+ for benchmark in category['benchmarks']:
70
+ bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
71
+ if self.weighting != Weighting.EQUAL:
72
+ assert benchmark_sizes is not None
73
+ cumulative_samples = max(sum((count for name, count in benchmark_sizes.items() if name.startswith(bench_name))), 1)
74
+ else:
75
+ cumulative_samples = -1
76
+ weight = None
77
+ if self.weighting == Weighting.EQUAL:
78
+ weight = 1
79
+ elif self.weighting == Weighting.SAMPLE_SZ:
80
+ weight = cumulative_samples
81
+ elif self.weighting == Weighting.LOG_SAMPLE_SZ:
82
+ weight = max(math.log2(cumulative_samples), 1)
83
+ assert weight is not None
84
+ benchmark['weighting'] = weight
85
+ self.averages = {}
86
+ if averages is not None:
87
+ self.averages = averages
88
+ else:
89
+ self.averages['default_average'] = self.category_names
90
+ for avg_name in self.averages:
91
+ if avg_name in self.category_names:
92
+ raise ValueError(f'Found average name `{avg_name}` used as category name. Average names and category names must be non-overlapping.')
93
+
94
+ def extract_metrics_from_state(self, state: State) -> Dict[str, float]:
95
+ results = {}
96
+ for key in self.logger_keys:
97
+ dl_name, metric_name = (key.split('/')[1:-1], key.split('/')[-1])
98
+ if 'Accuracy' not in metric_name:
99
+ continue
100
+ metric = state.eval_metrics.get('/'.join(dl_name), {}).get(metric_name, None)
101
+ if metric is None:
102
+ continue
103
+ val = metric.compute().item()
104
+ key = '/'.join(dl_name[0:2])
105
+ if key not in results:
106
+ results[key] = []
107
+ results[key].append(val)
108
+ return {k: sum(v) / len(v) for k, v in results.items()}
109
+
110
+ def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
111
+ computed_metrics = self.extract_metrics_from_state(state)
112
+ if len(computed_metrics) == 0:
113
+ return {}
114
+ category_scores = {}
115
+ for category in self.categories:
116
+ missing_metrics = []
117
+ category_scores[category['name']] = []
118
+ for benchmark in category['benchmarks']:
119
+ key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
120
+ if key not in computed_metrics:
121
+ log.warning(f'Could not find results for benchmark: {benchmark}.')
122
+ missing_metrics.append(key)
123
+ else:
124
+ score = computed_metrics[key]
125
+ if self.subtract_random_baseline:
126
+ score -= benchmark['random_baseline']
127
+ if self.rescale_accuracy and self.subtract_random_baseline:
128
+ score /= 1.0 - benchmark['random_baseline']
129
+ category_scores[category['name']].append({'name': benchmark['name'], 'score': score, 'weighting': benchmark['weighting']})
130
+ if len(missing_metrics) > 0:
131
+ log.warning(f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}")
132
+ del category_scores[category['name']]
133
+ continue
134
+ total_weight = sum((k['weighting'] for k in category_scores[category['name']]))
135
+ category_scores[category['name']] = sum((k['score'] * (k['weighting'] / total_weight) for k in category_scores[category['name']]))
136
+ named_averages = calculate_named_averages(self.averages, category_scores)
137
+ category_scores.update(named_averages)
138
+ category_scores = {f'icl/metrics/eval_gauntlet/{k}': v for k, v in category_scores.items()}
139
+ if logger is not None:
140
+ logger.log_metrics(category_scores)
141
+ return category_scores
exceptions.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom exceptions for the LLMFoundry."""
2
+ from collections.abc import Mapping
3
+ from typing import Any, Dict, List
4
+
5
+ class MissingHuggingFaceURLSplitError(ValueError):
6
+ """Error thrown when there's no split used in HF dataset config."""
7
+
8
+ def __init__(self) -> None:
9
+ message = 'When using a HuggingFace dataset from a URL, you must set the ' + '`split` key in the dataset config.'
10
+ super().__init__(message)
11
+
12
+ class NotEnoughDatasetSamplesError(ValueError):
13
+ """Error thrown when there is not enough data to train a model."""
14
+
15
+ def __init__(self, dataset_name: str, split: str, dataloader_batch_size: int, world_size: int, full_dataset_size: int, minimum_dataset_size: int) -> None:
16
+ self.dataset_name = dataset_name
17
+ self.split = split
18
+ self.dataloader_batch_size = dataloader_batch_size
19
+ self.world_size = world_size
20
+ self.full_dataset_size = full_dataset_size
21
+ self.minimum_dataset_size = minimum_dataset_size
22
+ message = f'Your dataset (name={dataset_name}, split={split}) ' + f'has {full_dataset_size} samples, but your minimum batch size ' + f'is {minimum_dataset_size} because you are running on {world_size} gpus and ' + f'your per device batch size is {dataloader_batch_size}. Please increase the number ' + f'of samples in your dataset to at least {minimum_dataset_size}.'
23
+ super().__init__(message)
24
+
25
+ class UnknownExampleTypeError(KeyError):
26
+ """Error thrown when an unknown example type is used in a task."""
27
+
28
+ def __init__(self, example: Mapping) -> None:
29
+ message = f'Unknown example type example={example!r}'
30
+ super().__init__(message)
31
+
32
+ class TooManyKeysInExampleError(ValueError):
33
+ """Error thrown when a data sample has too many keys."""
34
+
35
+ def __init__(self, desired_keys: set[str], keys: set[str]) -> None:
36
+ message = f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} Please specify exactly one. Provided keys: {keys}'
37
+ super().__init__(message)
38
+
39
+ class NotEnoughChatDataError(ValueError):
40
+ """Error thrown when there is not enough chat data to train a model."""
41
+
42
+ def __init__(self) -> None:
43
+ message = 'Chat example must have at least two messages'
44
+ super().__init__(message)
45
+
46
+ class ConsecutiveRepeatedChatRolesError(ValueError):
47
+ """Error thrown when there are consecutive repeated chat roles."""
48
+
49
+ def __init__(self, repeated_role: str) -> None:
50
+ self.repeated_role = repeated_role
51
+ message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.'
52
+ super().__init__(message)
53
+
54
+ class InvalidLastChatMessageRoleError(ValueError):
55
+ """Error thrown when the last message role in a chat example is invalid."""
56
+
57
+ def __init__(self, last_role: str, expected_roles: set[str]) -> None:
58
+ message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}'
59
+ super().__init__(message)
60
+
61
+ class IncorrectMessageKeyQuantityError(ValueError):
62
+ """Error thrown when a message has an incorrect number of keys."""
63
+
64
+ def __init__(self, keys: List[str]) -> None:
65
+ self.keys = keys
66
+ message = f'Expected 2 keys in message, but found {len(keys)}'
67
+ super().__init__(message)
68
+
69
+ class InvalidRoleError(ValueError):
70
+ """Error thrown when a role is invalid."""
71
+
72
+ def __init__(self, role: str, valid_roles: set[str]) -> None:
73
+ self.role = role
74
+ self.valid_roles = valid_roles
75
+ message = f'Expected role to be one of {valid_roles} but found: {role}'
76
+ super().__init__(message)
77
+
78
+ class InvalidContentTypeError(TypeError):
79
+ """Error thrown when the content type is invalid."""
80
+
81
+ def __init__(self, content_type: type) -> None:
82
+ self.content_type = content_type
83
+ message = f'Expected content to be a string, but found {content_type}'
84
+ super().__init__(message)
85
+
86
+ class InvalidPromptTypeError(TypeError):
87
+ """Error thrown when the prompt type is invalid."""
88
+
89
+ def __init__(self, prompt_type: type) -> None:
90
+ self.prompt_type = prompt_type
91
+ message = f'Expected prompt to be a string, but found {prompt_type}'
92
+ super().__init__(message)
93
+
94
+ class InvalidResponseTypeError(TypeError):
95
+ """Error thrown when the response type is invalid."""
96
+
97
+ def __init__(self, response_type: type) -> None:
98
+ self.response_type = response_type
99
+ message = f'Expected response to be a string, but found {response_type}'
100
+ super().__init__(message)
101
+
102
+ class InvalidPromptResponseKeysError(ValueError):
103
+ """Error thrown when missing expected prompt and response keys."""
104
+
105
+ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
106
+ self.example = example
107
+ message = f'Expected mapping={mapping!r} to have keys "prompt" and "response".'
108
+ super().__init__(message)
109
+
110
+ class InvalidFileExtensionError(FileNotFoundError):
111
+ """Error thrown when a file extension is not a safe extension."""
112
+
113
+ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
114
+ self.dataset_name = dataset_name
115
+ self.valid_extensions = valid_extensions
116
+ message = f'safe_load is set to True. No data files with safe extensions {valid_extensions} ' + f'found for dataset at local path {dataset_name}.'
117
+ super().__init__(message)
118
+
119
+ class UnableToProcessPromptResponseError(ValueError):
120
+ """Error thrown when a prompt and response cannot be processed."""
121
+
122
+ def __init__(self, input: Dict) -> None:
123
+ message = f'Unable to extract prompt/response from {input}'
124
+ super().__init__(message)
125
+
126
+ class ClusterDoesNotExistError(ValueError):
127
+ """Error thrown when the cluster does not exist."""
128
+
129
+ def __init__(self, cluster_id: str) -> None:
130
+ self.cluster_id = cluster_id
131
+ message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!'
132
+ super().__init__(message)
133
+
134
+ class FailedToCreateSQLConnectionError(RuntimeError):
135
+ """Error thrown when client can't sql connect to Databricks."""
136
+
137
+ def __init__(self) -> None:
138
+ message = 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
139
+ super().__init__(message)
140
+
141
+ class FailedToConnectToDatabricksError(RuntimeError):
142
+ """Error thrown when the client fails to connect to Databricks."""
143
+
144
+ def __init__(self) -> None:
145
+ message = 'Failed to create databricks connection. Check hostname and access token!'
146
+ super().__init__(message)
147
+
148
+ class InputFolderMissingDataError(ValueError):
149
+ """Error thrown when the input folder is missing data."""
150
+
151
+ def __init__(self, input_folder: str) -> None:
152
+ self.input_folder = input_folder
153
+ message = f'No text files were found at {input_folder}.'
154
+ super().__init__(message)
155
+
156
+ class OutputFolderNotEmptyError(FileExistsError):
157
+ """Error thrown when the output folder is not empty."""
158
+
159
+ def __init__(self, output_folder: str) -> None:
160
+ self.output_folder = output_folder
161
+ message = f'{output_folder} is not empty. Please remove or empty it and retry.'
162
+ super().__init__(message)
fdiff_callback.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Monitor rate of change of loss."""
2
+ from __future__ import annotations
3
+ import torch
4
+
5
+ class FDiffMetrics(Callback):
6
+ """Rate of change of metrics.
7
+
8
+ tracks and plots the rate of change of metrics effectively taking the
9
+ numerical derivative of the metrics
10
+ """
11
+
12
+ def __init__(self, diff_train_metrics: bool=False, diff_eval_metrics: bool=True):
13
+ self.diff_train_metrics = diff_train_metrics
14
+ self.diff_eval_metrics = diff_eval_metrics
15
+ self.train_prev_loss = None
16
+ self.train_prev_metric = {}
17
+ self.eval_prev_metric = {}
18
+
19
+ def batch_end(self, state: State, logger: Logger) -> None:
20
+ if self.diff_train_metrics:
21
+ if not isinstance(state.loss, torch.Tensor):
22
+ raise NotImplementedError('Multiple losses not supported yet')
23
+ loss = state.loss.item()
24
+ if self.train_prev_loss:
25
+ logger.log_metrics({'loss/train/total_fdiff': loss - self.train_prev_loss})
26
+ self.train_prev_loss = loss
27
+ for k in self.train_prev_metric.keys():
28
+ logger.log_metrics({f'metrics/train/{k}_fdiff': state.train_metric_values[k] - self.train_prev_metric[k]})
29
+ for k in state.train_metric_values.keys():
30
+ value = state.train_metric_values[k]
31
+ self.train_prev_metric[k] = value
32
+
33
+ def eval_end(self, state: State, logger: Logger) -> None:
34
+ if self.diff_eval_metrics:
35
+ evaluator = state.dataloader_label
36
+ assert evaluator is not None, 'dataloader should have been set'
37
+ metrics = list(state.eval_metrics[evaluator].keys())
38
+ for k in metrics:
39
+ mkey = '/'.join(['metrics', evaluator, k])
40
+ if mkey in self.eval_prev_metric.keys():
41
+ logger.log_metrics({f'{mkey}_fdiff': state.eval_metric_values[k] - self.eval_prev_metric[mkey]})
42
+ for k in metrics:
43
+ mkey = '/'.join(['metrics', evaluator, k])
44
+ self.eval_prev_metric[mkey] = state.eval_metric_values[k]
ffn.py CHANGED
@@ -59,8 +59,7 @@ class MPTMLP(nn.Module):
59
  super().__init__()
60
  ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
61
  self.fc_kwargs: dict[str, Any] = {'bias': bias}
62
- if fc_type != 'te':
63
- self.fc_kwargs['device'] = device
64
  self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, ffn_hidden_size, **self.fc_kwargs)
65
  self.act = act_fn
66
  self.down_proj = FC_CLASS_REGISTRY[fc_type](ffn_hidden_size, d_model, **self.fc_kwargs)
@@ -75,6 +74,7 @@ class MPTGLU(MPTMLP):
75
  super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
76
  self.gate_proj = FC_CLASS_REGISTRY[fc_type](d_model, self.up_proj.out_features, **self.fc_kwargs)
77
 
 
78
  def forward(self, x: torch.Tensor) -> torch.Tensor:
79
  return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
80
  FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP, 'mptglu': MPTGLU}
 
59
  super().__init__()
60
  ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
61
  self.fc_kwargs: dict[str, Any] = {'bias': bias}
62
+ self.fc_kwargs['device'] = device
 
63
  self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, ffn_hidden_size, **self.fc_kwargs)
64
  self.act = act_fn
65
  self.down_proj = FC_CLASS_REGISTRY[fc_type](ffn_hidden_size, d_model, **self.fc_kwargs)
 
74
  super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
75
  self.gate_proj = FC_CLASS_REGISTRY[fc_type](d_model, self.up_proj.out_features, **self.fc_kwargs)
76
 
77
+ @torch.compile
78
  def forward(self, x: torch.Tensor) -> torch.Tensor:
79
  return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
80
  FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP, 'mptglu': MPTGLU}
finetuning.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .collator import Seq2SeqFinetuningCollator
2
+ from .dataloader import build_finetuning_dataloader
hf.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .hf_causal_lm import ComposerHFCausalLM
2
+ from .hf_fsdp import prepare_hf_causal_lm_model_for_fsdp, prepare_hf_enc_dec_model_for_fsdp, prepare_hf_model_for_fsdp
3
+ from .hf_t5 import ComposerHFT5
hf_causal_lm.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
2
+ import logging
3
+ import os
4
+ import warnings
5
+ from typing import TYPE_CHECKING, Any, Dict, Mapping
6
+ from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase
7
+ from .hf_fsdp import hf_get_init_device
8
+ from .model_wrapper import HuggingFaceModelWithFSDP
9
+ from .attention import is_flash_v2_installed
10
+ from .utils import init_empty_weights
11
+ from .config_utils import pop_config
12
+ if TYPE_CHECKING:
13
+ from peft import PeftConfig
14
+ log = logging.getLogger(__name__)
hf_checkpointer.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import copy
3
+ import logging
4
+ import math
5
+ import os
6
+ import re
7
+ import tempfile
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional, Sequence, Union
10
+ import torch
11
+ from mlflow.transformers import _fetch_model_card, _write_license_information
12
+ from transformers import PreTrainedModel, PreTrainedTokenizerBase
13
+ from .mpt import MPTConfig, MPTForCausalLM
14
+ from .utils import init_empty_weights
15
+ from .huggingface_hub_utils import edit_files_for_hf_compatibility
16
+ log = logging.getLogger(__name__)
17
+ _LICENSE_FILE_PATTERN = re.compile('license(\\.[a-z]+|$)', re.IGNORECASE)
18
+
19
+ def _maybe_get_license_filename(local_dir: str, pretrained_model_name: Optional[str]=None) -> Optional[str]:
20
+ """Returns the name of the license file if it exists in the local_dir.
21
+
22
+ Note: This is intended to be consistent with the code in MLflow.
23
+ https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
24
+
25
+ Since LLM Foundry supports local model files being used rather than fetching the files from the Hugging Face Hub,
26
+ MLflow's logic to fetch and write the license information on model save is not applicable; it will try to search for
27
+ a Hugging Face repo named after the local path. However, the user can provide the original pretrained model name,
28
+ in which case this function will use that to fetch the correct license information.
29
+
30
+ If the license file does not exist, returns None.
31
+ """
32
+ try:
33
+ license_filename = next((file for file in os.listdir(local_dir) if _LICENSE_FILE_PATTERN.search(file)))
34
+ if pretrained_model_name is not None:
35
+ log.info(f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub')
36
+ os.remove(os.path.join(local_dir, license_filename))
37
+ model_card = _fetch_model_card(pretrained_model_name)
38
+ local_dir_path = Path(local_dir).absolute()
39
+ _write_license_information(pretrained_model_name, model_card, local_dir_path)
40
+ license_filename = next((file for file in os.listdir(local_dir) if _LICENSE_FILE_PATTERN.search(file)))
41
+ return license_filename
42
+ except StopIteration:
43
+ return None
44
+
45
+ class HuggingFaceCheckpointer(Callback):
46
+ """Save a huggingface formatted checkpoint during training.
47
+
48
+ Args:
49
+ save_folder (str): Top level folder to save checkpoints to (can be a
50
+ URI). It is likely that this would be the same as your save_folder.
51
+ save_interval: Union[str, int, Time]: The interval describing how often
52
+ checkpoints should be saved. If an integer, it will be assumed to be
53
+ in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either
54
+ :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
55
+ :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
56
+ huggingface_folder_name (str): Folder to save each checkpoint under (can
57
+ be a format string). Default is ``ba{batch}``.
58
+ precision: The precision to save the model in. Default is ``float32``.
59
+ Options are ``bfloat16``, ``float16``, or ``float32``.
60
+ overwrite (bool): Whether to overwrite previous checkpoints.
61
+ mlflow_registered_model_name (Optional[str]): The name to register the
62
+ model under in the MLflow model registry. If ``None``, the model
63
+ will not be registered. Default is ``None``.
64
+ mlflow_logging_config (Optional[dict]): A dictionary of config arguments
65
+ that will get passed along to the MLflow ``save_model`` call.
66
+ Expected to contain ``metadata`` and ``task`` keys. If either is
67
+ unspecified, the defaults are ``'text-generation'`` and
68
+ ``{'task': 'llm/v1/completions'}`` respectively. A default input example
69
+ and signature intended for text generation is also included under the
70
+ keys ``input_example`` and ``signature``.
71
+ flatten_imports (Sequence[str]): A sequence of import prefixes that will
72
+ be flattened when editing MPT files.
73
+ """
74
+
75
+ def __init__(self, save_folder: str, save_interval: Union[str, int, Time], huggingface_folder_name: str='ba{batch}', precision: str='float32', overwrite: bool=True, mlflow_registered_model_name: Optional[str]=None, mlflow_logging_config: Optional[dict]=None, flatten_imports: Sequence[str]=('llmfoundry',)):
76
+ _, _, self.save_dir_format_str = parse_uri(save_folder)
77
+ self.overwrite = overwrite
78
+ self.precision = precision
79
+ self.dtype = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}[precision]
80
+ self.flatten_imports = flatten_imports
81
+ self.mlflow_registered_model_name = mlflow_registered_model_name
82
+ if mlflow_logging_config is None:
83
+ mlflow_logging_config = {}
84
+ if self.mlflow_registered_model_name is not None:
85
+ import numpy as np
86
+ passed_metadata = mlflow_logging_config.get('metadata', {})
87
+ mlflow_logging_config['metadata'] = passed_metadata
88
+ mlflow_logging_config.setdefault('task', 'llm/v1/completions')
89
+ default_input_example = {'prompt': np.array(['What is Machine Learning?'])}
90
+ is_chat = mlflow_logging_config['task'].endswith('chat') or mlflow_logging_config['metadata'].get('task', '').endswith('chat')
91
+ if is_chat:
92
+ default_input_example = {'messages': np.array([{'role': 'user', 'content': 'What is Machine Learning?'}])}
93
+ mlflow_logging_config.setdefault('example_no_conversion', True)
94
+ mlflow_logging_config.setdefault('input_example', default_input_example)
95
+ self.mlflow_logging_config = mlflow_logging_config
96
+ self.huggingface_folder_name_fstr = os.path.join('huggingface', huggingface_folder_name)
97
+ self.save_interval: Time = Time.from_input(save_interval, TimeUnit.EPOCH)
98
+ self.check_interval = create_interval_scheduler(self.save_interval, include_end_of_training=True)
99
+ self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers=[])
100
+ if self.remote_ud is not None:
101
+ self.remote_ud._num_concurrent_uploads = 4
102
+ self.last_checkpoint_batch: Optional[Time] = None
103
+ self.mlflow_loggers = []
104
+
105
+ def run_event(self, event: Event, state: State, logger: Logger) -> None:
106
+ if state.get_elapsed_duration() is not None and self.check_interval(state, event) and (self.last_checkpoint_batch != state.timestamp.batch):
107
+ self._save_checkpoint(state, logger)
108
+ elif event == Event.INIT:
109
+ if not isinstance(state.model, HuggingFaceModel):
110
+ raise ValueError(f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. ' + f'Got {type(state.model)} instead.')
111
+ if self.remote_ud is not None:
112
+ self.remote_ud.init(state, logger)
113
+ state.callbacks.append(self.remote_ud)
114
+ if self.mlflow_registered_model_name is not None:
115
+ self.mlflow_loggers = [logger_destination for logger_destination in logger.destinations if isinstance(logger_destination, MLFlowLogger)]
116
+ if len(self.mlflow_loggers) == 0:
117
+ raise ValueError(f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. ' + 'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.')
118
+ import mlflow
119
+ mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set('5GB')
120
+
121
+ def _is_last_batch(self, state: State):
122
+ elapsed_duration = state.get_elapsed_duration()
123
+ if elapsed_duration is not None and elapsed_duration >= 1.0:
124
+ return True
125
+ assert state.max_duration is not None
126
+ if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and (state.max_duration.unit == TimeUnit.EPOCH):
127
+ assert state.dataloader_len is not None
128
+ return int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0
129
+ return False
130
+
131
+ def _save_checkpoint(self, state: State, logger: Logger):
132
+ del logger
133
+ self.last_checkpoint_batch = state.timestamp.batch
134
+ log.info('Saving HuggingFace formatted checkpoint')
135
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
136
+ CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
137
+ MPTConfig.register_for_auto_class()
138
+ MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
139
+ save_dir = format_name_with_dist_and_time(str(Path(self.save_dir_format_str) / self.huggingface_folder_name_fstr), state.run_name, state.timestamp)
140
+ dir_context_mgr = tempfile.TemporaryDirectory() if self.remote_ud is not None else contextlib.nullcontext(enter_result=save_dir)
141
+ with dir_context_mgr as temp_save_dir:
142
+ assert isinstance(temp_save_dir, str)
143
+ log.debug('Gathering state dict')
144
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
145
+ if state.is_model_ddp:
146
+ composer_model = state.model.module
147
+ original_model: PreTrainedModel = state.model.module.model
148
+ state_dict_model = state.model.module.model
149
+ original_tokenizer = state.model.module.tokenizer
150
+ elif isinstance(state.model.model, FSDP):
151
+ composer_model = state.model
152
+ original_model: PreTrainedModel = state.model.model.module
153
+ state_dict_model = state.model.model
154
+ original_tokenizer = state.model.tokenizer
155
+ else:
156
+ composer_model = state.model
157
+ original_model: PreTrainedModel = state.model.model
158
+ state_dict_model = state.model.model
159
+ original_tokenizer = state.model.tokenizer
160
+ state_dict_context = fsdp_state_dict_type_context(original_model, state_dict_type='full') if not state.is_model_ddp and isinstance(state_dict_model, FSDP) else contextlib.nullcontext()
161
+ with state_dict_context:
162
+ state_dict = state_dict_model.state_dict()
163
+ for k, v in state_dict.items():
164
+ if isinstance(v, torch.Tensor):
165
+ state_dict[k] = v.to(dtype=self.dtype)
166
+ if dist.get_global_rank() == 0:
167
+ log.debug('Saving Hugging Face checkpoint in global rank 0')
168
+ copied_config = copy.deepcopy(original_model.config)
169
+ if copied_config.model_type == 'mpt':
170
+ copied_config.attn_config['attn_impl'] = 'torch'
171
+ copied_config.init_device = 'cpu'
172
+ log.debug(f'Creating new model instance')
173
+ if composer_model.using_peft:
174
+ active_adapter = original_model.active_adapter
175
+ base_model = original_model.get_base_model()
176
+ new_base_model_instance = type(base_model)(copied_config)
177
+ new_model_instance = type(original_model)(new_base_model_instance, original_model.peft_config[active_adapter])
178
+ new_model_instance.to(dtype=self.dtype)
179
+ else:
180
+ with init_empty_weights():
181
+ new_model_instance = type(original_model)(copied_config)
182
+ new_model_instance.load_state_dict(state_dict, assign=True)
183
+ del state_dict
184
+ log.debug('Saving Hugging Face checkpoint to disk')
185
+ new_model_instance.save_pretrained(temp_save_dir)
186
+ if original_tokenizer is not None:
187
+ assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
188
+ original_tokenizer.save_pretrained(temp_save_dir)
189
+ if original_model.config.model_type == 'mpt':
190
+ log.debug('Editing MPT files for HuggingFace compatibility')
191
+ edit_files_for_hf_compatibility(temp_save_dir, self.flatten_imports)
192
+ if self.remote_ud is not None:
193
+ for filename in os.listdir(temp_save_dir):
194
+ remote_file_name = os.path.join(save_dir, filename)
195
+ remote_file_uri = self.remote_ud.remote_backend.get_uri(remote_file_name)
196
+ log.info(f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}')
197
+ self.remote_ud.upload_file(state=state, remote_file_name=remote_file_name, file_path=Path(os.path.join(temp_save_dir, filename)), overwrite=self.overwrite)
198
+ if self.mlflow_registered_model_name and self._is_last_batch(state):
199
+ components = {'model': new_model_instance}
200
+ if original_tokenizer is not None:
201
+ components['tokenizer'] = original_tokenizer
202
+ log.debug('Logging Hugging Face model to MLFlow')
203
+ for i, mlflow_logger in enumerate(self.mlflow_loggers):
204
+ log.debug(f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}')
205
+ local_save_path = str(Path(temp_save_dir) / f'mlflow_save_{i}')
206
+ import mlflow
207
+ mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
208
+ model_saving_kwargs: Dict[str, Any] = {'path': local_save_path}
209
+ if composer_model.using_peft:
210
+ model_saving_kwargs['flavor'] = 'peft'
211
+ model_saving_kwargs['save_pretrained_dir'] = temp_save_dir
212
+ model_saving_kwargs['metadata'] = self.mlflow_logging_config['metadata']
213
+ else:
214
+ model_saving_kwargs['flavor'] = 'transformers'
215
+ model_saving_kwargs['transformers_model'] = components
216
+ model_saving_kwargs.update(self.mlflow_logging_config)
217
+ mlflow_logger.save_model(**model_saving_kwargs)
218
+ license_filename = _maybe_get_license_filename(local_save_path, self.mlflow_logging_config['metadata'].get('pretrained_model_name', None))
219
+ if license_filename is not None:
220
+ mlflow_logger._mlflow_client.log_artifact(mlflow_logger._run_id, os.path.join(local_save_path, license_filename))
221
+ mlflow_logger.register_model_with_run_id(model_uri=local_save_path, name=self.mlflow_registered_model_name, await_creation_for=3600)
hf_fsdp.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
3
+ from transformers import PreTrainedModel
4
+ from transformers.models.opt.modeling_opt import OPTDecoder
5
+ if TYPE_CHECKING:
6
+ from peft import PeftModel
7
+
8
+ def rhasattr(obj: Any, attr: str) -> bool:
9
+ """A chain-able attribute version of hasattr.
10
+
11
+ For example, to check if
12
+ `obj` has the attribute `foo.bar.baz`, you can use:
13
+ `rhasattr(obj, "foo.bar.baz")`
14
+ Reference: https://stackoverflow.com/a/67303315
15
+ """
16
+ _nested_attrs = attr.split('.')
17
+ _curr_obj = obj
18
+ for _a in _nested_attrs[:-1]:
19
+ if hasattr(_curr_obj, _a):
20
+ _curr_obj = getattr(_curr_obj, _a)
21
+ else:
22
+ return False
23
+ return hasattr(_curr_obj, _nested_attrs[-1])
24
+
25
+ def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any:
26
+ """A chain-able attribute version of getattr.
27
+
28
+ For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
29
+ `rgetattr(obj, "foo.bar.baz")`
30
+ Reference: https://stackoverflow.com/a/31174427
31
+ """
32
+
33
+ def _getattr(obj: Any, attr: str):
34
+ return getattr(obj, attr, *args)
35
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
36
+
37
+ def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
38
+ for attr in attrs:
39
+ if rhasattr(obj, attr):
40
+ return rgetattr(obj, attr)
41
+ return None
42
+
43
+ def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
44
+ """Returns the causal decoder backbone of the specified HuggingFace model.
45
+
46
+ Newer HF models have a `self.get_decoder()` method. Older models do not.
47
+
48
+ NOTE: Different model configurations have different causal decoder attribute
49
+ names.
50
+ - transformer: (GPT2LMHeadModel, GPTJConfig)
51
+ - model.decoder: (OPTConfig, BloomConfig)
52
+ - gpt_neox: (GPTNeoXConfig)
53
+ """
54
+ if hasattr(model, 'get_decoder'):
55
+ return model.get_decoder()
56
+ decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox', 'model.transformer')
57
+ causal_base_model = findattr(model, decoder_attrs)
58
+ if causal_base_model is None:
59
+ raise ValueError(f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.')
60
+ return causal_base_model
61
+
62
+ def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
63
+ """Returns the hidden layers of the specified model.
64
+
65
+ Expects to receive the causal decoder backbone, not he XXForCausalLM wrapper.
66
+
67
+ NOTE: Different model configurations have different hidden layer attribute names.
68
+ - h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
69
+ - decoder.layers: (OPTForCausalLM)
70
+ - layers: (GPTNeoXForCausalLM, LlaMaForCausalLM)
71
+ - blocks: (MPTForCausalLM)
72
+ """
73
+ hidden_layers_attrs = ('h', 'decoder.layers', 'layers', 'block', 'blocks')
74
+ layers = findattr(model, hidden_layers_attrs)
75
+ if layers is None:
76
+ raise ValueError(f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}')
77
+ return layers
78
+
79
+ def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
80
+ """Returns the appropriate device to initialize models."""
81
+ if init_device == 'mixed':
82
+ if dist.get_local_rank() == 0:
83
+ return 'cpu'
84
+ return 'meta'
85
+ return init_device
86
+
87
+ def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
88
+ """FSDP wrap a HuggingFace model.
89
+
90
+ Call specific functions
91
+ """
92
+ if model.config.is_encoder_decoder:
93
+ prepare_hf_enc_dec_model_for_fsdp(model, init_device)
94
+ else:
95
+ prepare_hf_causal_lm_model_for_fsdp(model, init_device)
96
+
97
+ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, 'PeftModel'], init_device: Optional[str]) -> None:
98
+ """FSDP wrap a HuggingFace decoder.
99
+
100
+ Wrap any model for FSDP which follows one of the 3 existing conventions from
101
+ HuggingFace for decoder-only LLMs.
102
+ """
103
+ causal_base_model = hf_get_causal_base_model(model)
104
+ if isinstance(causal_base_model, OPTDecoder) or model.config.model_type == 'olmo':
105
+ underlying_model = maybe_get_underlying_model(model)
106
+ underlying_model.model._fsdp_wrap = False
107
+ model_block = hf_get_hidden_layers(causal_base_model)
108
+ lm_head = model.get_output_embeddings()
109
+ try:
110
+ tied_embeddings = causal_base_model.get_input_embeddings()
111
+ except:
112
+ tied_embeddings = model.get_input_embeddings()
113
+ modules = {'base_model': causal_base_model, 'model_block': model_block, 'lm_head': lm_head, 'tied_embeddings': tied_embeddings}
114
+ for mod_name, module in modules.items():
115
+ if module is None:
116
+ raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')
117
+ block_type = type(model_block[0])
118
+ if model.config.tie_word_embeddings:
119
+ causal_base_model._fsdp_wrap = False
120
+ tied_embeddings._fsdp_wrap = False
121
+ lm_head._fsdp_wrap = False
122
+ if hasattr(model, 'peft_type') and model.peft_type is not None:
123
+ peft_type = model.peft_type.lower()
124
+ active_adapters = [adapter.lower() for adapter in model.active_adapters]
125
+ for name, module in model.named_modules():
126
+ if peft_type in name.lower() and any((adapter in name.lower() for adapter in active_adapters)):
127
+ has_parameters = next(module.parameters(), None) is not None
128
+ has_buffers = next(module.buffers(), None) is not None
129
+ if has_parameters or has_buffers:
130
+ module._fsdp_wrap = True
131
+ model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
132
+ model.activation_checkpointing_fn = lambda module: isinstance(module, block_type)
133
+
134
+ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
135
+ """Wrap an encoder/decoder HF model.
136
+
137
+ This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet)
138
+ You have model.shared, model.encoder, model.decoder and model.lm_head, where
139
+ model.shared are the embeddings which are tied to model.lm_head, and
140
+ model.shared == model.encoder.embed_tokens and model.shared ==
141
+ model.decoder.embed_tokens
142
+ """
143
+ tied_embeddings = model.get_input_embeddings()
144
+ encoder = model.get_encoder()
145
+ decoder = model.get_decoder()
146
+ lm_head = model.get_output_embeddings()
147
+ encoder_block = hf_get_hidden_layers(encoder)
148
+ decoder_block = hf_get_hidden_layers(decoder)
149
+ modules = {'encoder': encoder, 'decoder': decoder, 'encoder_block': encoder_block, 'decoder_block': decoder_block, 'lm_head': lm_head, 'tied_embeddings': tied_embeddings}
150
+ for mod_name, module in modules.items():
151
+ if module is None:
152
+ raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')
153
+ decoder_block_type = type(decoder_block[0])
154
+ encoder_block_type = type(encoder_block[0])
155
+ if model.config.tie_word_embeddings:
156
+ tied_embeddings._fsdp_wrap = False
157
+ encoder._fsdp_wrap = False
158
+ decoder._fsdp_wrap = False
159
+ lm_head._fsdp_wrap = False
160
+ model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
161
+ model.activation_checkpointing_fn = lambda module: isinstance(module, decoder_block_type)
162
+ if encoder_block_type == decoder_block_type:
163
+ return
164
+ model.fsdp_wrap_fn = lambda module: isinstance(module, encoder_block_type)
165
+ model.activation_checkpointing_fn = lambda module: isinstance(module, encoder_block_type)
hf_t5.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Implements a Hugging Face T5 wrapped inside a :class:`.ComposerModel`."""
2
+ from __future__ import annotations
3
+ from typing import Mapping
4
+ from transformers import AutoConfig, PreTrainedTokenizerBase, T5ForConditionalGeneration
5
+ from .hf_fsdp import hf_get_init_device
6
+ from .model_wrapper import HuggingFaceModelWithFSDP
7
+ from .utils import init_empty_weights
8
+ from .warnings import experimental_class
huggingface_hub_utils.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import importlib
3
+ import os
4
+ from typing import Optional, Sequence
5
+
6
+ class DeleteSpecificNodes(ast.NodeTransformer):
7
+
8
+ def __init__(self, nodes_to_remove: list[ast.AST]):
9
+ self.nodes_to_remove = nodes_to_remove
10
+
11
+ def visit(self, node: ast.AST) -> Optional[ast.AST]:
12
+ if node in self.nodes_to_remove:
13
+ return None
14
+ return super().visit(node)
15
+
16
+ def convert_to_relative_import(module_name: str, original_parent_module_name: Optional[str]) -> str:
17
+ parts = module_name.split('.')
18
+ if parts[-1] == original_parent_module_name:
19
+ return '.'
20
+ return '.' + parts[-1]
21
+
22
+ def find_module_file(module_name: str) -> str:
23
+ if not module_name:
24
+ raise ValueError(f'Invalid input: module_name={module_name!r}')
25
+ module = importlib.import_module(module_name)
26
+ module_file = module.__file__
27
+ if module_file is None:
28
+ raise ValueError(f'Could not find file for module: {module_name}')
29
+ return module_file
30
+
31
+ def _flatten_import(node: ast.ImportFrom, flatten_imports_prefix: Sequence[str]) -> bool:
32
+ """Returns True if import should be flattened.
33
+
34
+ Checks whether the node starts the same as any of the imports in
35
+ flatten_imports_prefix.
36
+ """
37
+ for import_prefix in flatten_imports_prefix:
38
+ if node.module is not None and node.module.startswith(import_prefix):
39
+ return True
40
+ return False
41
+
42
+ def _remove_import(node: ast.ImportFrom, remove_imports_prefix: Sequence[str]) -> bool:
43
+ """Returns True if import should be removed.
44
+
45
+ Checks whether the node starts the same as any of the imports in
46
+ remove_imports_prefix.
47
+ """
48
+ for import_prefix in remove_imports_prefix:
49
+ if node.module is not None and node.module.startswith(import_prefix):
50
+ return True
51
+ return False
52
+
53
+ def process_file(file_path: str, folder_path: str, flatten_imports_prefix: Sequence[str], remove_imports_prefix: Sequence[str]) -> list[str]:
54
+ with open(file_path, 'r', encoding='utf-8') as f:
55
+ source = f.read()
56
+ parent_module_name = None
57
+ if os.path.basename(file_path) == '__init__.py':
58
+ parent_module_name = os.path.basename(os.path.dirname(file_path))
59
+ tree = ast.parse(source)
60
+ new_files_to_process = []
61
+ nodes_to_remove = []
62
+ for node in ast.walk(tree):
63
+ if isinstance(node, ast.ImportFrom) and node.module is not None and _remove_import(node, remove_imports_prefix):
64
+ nodes_to_remove.append(node)
65
+ elif isinstance(node, ast.ImportFrom) and node.module is not None and _flatten_import(node, flatten_imports_prefix):
66
+ module_path = find_module_file(node.module)
67
+ node.module = convert_to_relative_import(node.module, parent_module_name)
68
+ new_files_to_process.append(module_path)
69
+ elif isinstance(node, ast.ClassDef) and node.name.startswith('Composer'):
70
+ nodes_to_remove.append(node)
71
+ elif isinstance(node, ast.Assign) and len(node.targets) == 1 and isinstance(node.targets[0], ast.Name) and (node.targets[0].id == '__all__'):
72
+ nodes_to_remove.append(node)
73
+ transformer = DeleteSpecificNodes(nodes_to_remove)
74
+ new_tree = transformer.visit(tree)
75
+ new_filename = os.path.basename(file_path)
76
+ if new_filename == '__init__.py':
77
+ new_filename = file_path.split('/')[-2] + '.py'
78
+ new_file_path = os.path.join(folder_path, new_filename)
79
+ with open(new_file_path, 'w', encoding='utf-8') as f:
80
+ assert new_tree is not None
81
+ f.write(ast.unparse(new_tree))
82
+ return new_files_to_process
83
+
84
+ def edit_files_for_hf_compatibility(folder: str, flatten_imports_prefix: Sequence[str]=('llmfoundry',), remove_imports_prefix: Sequence[str]=('composer', 'omegaconf', 'llmfoundry.metrics')) -> None:
85
+ """Edit files to be compatible with Hugging Face Hub.
86
+
87
+ Args:
88
+ folder (str): The folder to process.
89
+ flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',).
90
+ remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
91
+ Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics').
92
+ """
93
+ files_to_process = [os.path.join(folder, filename) for filename in os.listdir(folder) if filename.endswith('.py')]
94
+ files_processed_and_queued = set(files_to_process)
95
+ while len(files_to_process) > 0:
96
+ to_process = files_to_process.pop()
97
+ if os.path.isfile(to_process) and to_process.endswith('.py'):
98
+ to_add = process_file(to_process, folder, flatten_imports_prefix, remove_imports_prefix)
99
+ for file in to_add:
100
+ if file not in files_processed_and_queued:
101
+ files_to_process.append(file)
102
+ files_processed_and_queued.add(file)
interfaces.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .callback_with_config import CallbackWithConfig
llmfoundry.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings('ignore', category=UserWarning, module='bitsandbytes')
3
+ import logging
4
+ from .logging_utils import SpecificWarningFilter
5
+ hf_dynamic_modules_logger = logging.getLogger('transformers.dynamic_module_utils')
6
+ new_files_warning_filter = SpecificWarningFilter('A new version of the following files was downloaded from')
7
+ hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
8
+ from . import algorithms, callbacks, loggers, optim, registry, utils
9
+ from .data import ConcatTokensDataset, NoConcatDataset, Seq2SeqFinetuningCollator, build_finetuning_dataloader
10
+ from .hf import ComposerHFCausalLM, ComposerHFT5
11
+ from .attention import MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, scaled_multihead_dot_product_attention
12
+ from .blocks import MPTBlock
13
+ from .ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
14
+ from .mpt import ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel
15
+ from .tokenizers import TiktokenTokenizerWrapper
16
+ __version__ = '0.7.0'
logging_utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ class SpecificWarningFilter(logging.Filter):
5
+
6
+ def __init__(self, message_to_suppress: str):
7
+ """Filter out a specific warning message based on its content.
8
+
9
+ This can be useful for filtering out specific warning messages from third party packages.
10
+
11
+ Args:
12
+ message_to_suppress (str): The warning message to suppress.
13
+ """
14
+ super().__init__()
15
+ self.message_to_suppress = message_to_suppress
16
+
17
+ def filter(self, record: logging.LogRecord) -> bool:
18
+ return self.message_to_suppress not in record.getMessage()
19
+
20
+ def get_mosaicml_logger():
21
+ if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR):
22
+ return MosaicMLLogger()
23
+ else:
24
+ return None
meta_init_context.py CHANGED
@@ -95,5 +95,5 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
95
  nn.Module.register_parameter = old_register_parameter
96
  if include_buffers:
97
  nn.Module.register_buffer = old_register_buffer
98
- for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
99
  setattr(torch, torch_function_name, old_torch_function)
 
95
  nn.Module.register_parameter = old_register_parameter
96
  if include_buffers:
97
  nn.Module.register_buffer = old_register_buffer
98
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
99
  setattr(torch, torch_function_name, old_torch_function)
model_download_utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for downloading models."""
2
+ import copy
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import subprocess
7
+ import time
8
+ import warnings
9
+ from http import HTTPStatus
10
+ from typing import Optional
11
+ from urllib.parse import urljoin
12
+ import huggingface_hub as hf_hub
13
+ import requests
14
+ import tenacity
15
+ import yaml
16
+ from bs4 import BeautifulSoup
17
+ from requests.packages.urllib3.exceptions import InsecureRequestWarning
18
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
19
+ from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
20
+ from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
21
+ DEFAULT_IGNORE_PATTERNS = ['*.ckpt', '*.h5', '*.msgpack']
22
+ PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
23
+ SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
24
+ TOKENIZER_FILES = ['special_tokens_map.json', 'tokenizer.json', 'tokenizer.model', 'tokenizer_config.json']
25
+ ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
26
+ ORAS_CLI = 'oras'
27
+ log = logging.getLogger(__name__)
28
+
29
+ @tenacity.retry(retry=tenacity.retry_if_not_exception_type((ValueError, hf_hub.utils.RepositoryNotFoundError)), stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10))
30
+ def download_from_hf_hub(model: str, save_dir: str, prefer_safetensors: bool=True, tokenizer_only: bool=False, token: Optional[str]=None):
31
+ """Downloads model files from a Hugging Face Hub model repo.
32
+
33
+ Only supports models stored in Safetensors and PyTorch formats for now. If both formats are available, only the
34
+ Safetensors weights will be downloaded unless `prefer_safetensors` is set to False.
35
+
36
+ Args:
37
+ repo_id (str): The Hugging Face Hub repo ID.
38
+ save_dir (str, optional): The local path to the directory where the model files will be downloaded.
39
+ prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
40
+ available. Defaults to True.
41
+ tokenizer_only (bool): If true, only download tokenizer files.
42
+ token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
43
+ `HUGGING_FACE_HUB_TOKEN` environment variable.
44
+
45
+ Raises:
46
+ RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized.
47
+ ValueError: If the model repo doesn't contain any supported model weights.
48
+ """
49
+ repo_files = set(hf_hub.list_repo_files(model))
50
+ ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS)
51
+ safetensors_available = SAFE_WEIGHTS_NAME in repo_files or SAFE_WEIGHTS_INDEX_NAME in repo_files
52
+ pytorch_available = PYTORCH_WEIGHTS_NAME in repo_files or PYTORCH_WEIGHTS_INDEX_NAME in repo_files
53
+ if safetensors_available and pytorch_available:
54
+ if prefer_safetensors:
55
+ log.info('Safetensors available and preferred. Excluding pytorch weights.')
56
+ ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN)
57
+ else:
58
+ log.info('Pytorch available and preferred. Excluding safetensors weights.')
59
+ ignore_patterns.append(SAFE_WEIGHTS_PATTERN)
60
+ elif safetensors_available:
61
+ log.info('Only safetensors available. Ignoring weights preference.')
62
+ elif pytorch_available:
63
+ log.info('Only pytorch available. Ignoring weights preference.')
64
+ else:
65
+ raise ValueError(f'No supported model weights found in repo {model}.' + ' Please make sure the repo contains either safetensors or pytorch weights.')
66
+ allow_patterns = TOKENIZER_FILES if tokenizer_only else None
67
+ download_start = time.time()
68
+ hf_hub.snapshot_download(model, local_dir=save_dir, local_dir_use_symlinks=False, ignore_patterns=ignore_patterns, allow_patterns=allow_patterns, token=token)
69
+ download_duration = time.time() - download_start
70
+ log.info(f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds')
71
+
72
+ def _extract_links_from_html(html: str):
73
+ """Extracts links from HTML content.
74
+
75
+ Args:
76
+ html (str): The HTML content
77
+
78
+ Returns:
79
+ list[str]: A list of links to download.
80
+ """
81
+ soup = BeautifulSoup(html, 'html.parser')
82
+ links = [a['href'] for a in soup.find_all('a')]
83
+ return links
84
+
85
+ def _recursive_download(session: requests.Session, base_url: str, path: str, save_dir: str, ignore_cert: bool=False):
86
+ """Downloads all files/subdirectories from a directory on a remote server.
87
+
88
+ Args:
89
+ session: A requests.Session through which to make requests to the remote server.
90
+ url (str): The base URL where the files are located.
91
+ path (str): The path from the base URL to the files to download. The full URL for the download is equal to
92
+ '<base_url>/<path>'.
93
+ save_dir (str): The directory to save downloaded files to.
94
+ ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
95
+ Defaults to False.
96
+ WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
97
+
98
+ Raises:
99
+ PermissionError: If the remote server returns a 401 Unauthorized status code.
100
+ ValueError: If the remote server returns a 404 Not Found status code.
101
+ RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
102
+ """
103
+ url = urljoin(base_url, path)
104
+ print(url)
105
+ response = session.get(url, verify=not ignore_cert)
106
+ if response.status_code == HTTPStatus.UNAUTHORIZED:
107
+ raise PermissionError(f'Not authorized to download file from {url}. Received status code {response.status_code}. ')
108
+ elif response.status_code == HTTPStatus.NOT_FOUND:
109
+ raise ValueError(f'Could not find file at {url}. Received status code {response.status_code}')
110
+ elif response.status_code != HTTPStatus.OK:
111
+ raise RuntimeError(f'Could not download file from {url}. Received unexpected status code {response.status_code}')
112
+ if not url.endswith('/'):
113
+ save_path = os.path.join(save_dir, path)
114
+ parent_dir = os.path.dirname(save_path)
115
+ if not os.path.exists(parent_dir):
116
+ os.makedirs(parent_dir)
117
+ with open(save_path, 'wb') as f:
118
+ f.write(response.content)
119
+ log.info(f'Downloaded file {save_path}')
120
+ return
121
+ child_links = _extract_links_from_html(response.content.decode())
122
+ print(child_links)
123
+ for child_link in child_links:
124
+ _recursive_download(session, base_url, urljoin(path, child_link), save_dir, ignore_cert=ignore_cert)
125
+
126
+ @tenacity.retry(retry=tenacity.retry_if_not_exception_type((PermissionError, ValueError)), stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10))
127
+ def download_from_http_fileserver(url: str, save_dir: str, ignore_cert: bool=False):
128
+ """Downloads files from a remote HTTP file server.
129
+
130
+ Args:
131
+ url (str): The base URL where the files are located.
132
+ save_dir (str): The directory to save downloaded files to.
133
+ ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
134
+ Defaults to False.
135
+ WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
136
+ """
137
+ with requests.Session() as session:
138
+ with warnings.catch_warnings():
139
+ if ignore_cert:
140
+ warnings.simplefilter('ignore', category=InsecureRequestWarning)
141
+ _recursive_download(session, url, '', save_dir, ignore_cert=ignore_cert)
142
+
143
+ def download_from_oras(model: str, config_file: str, credentials_dir: str, save_dir: str, tokenizer_only: bool=False, concurrency: int=10):
144
+ """Download from an OCI-compliant registry using oras.
145
+
146
+ Args:
147
+ model (str): The name of the model to download.
148
+ config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
149
+ credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
150
+ files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
151
+ save_dir (str): Path to the directory where files will be downloaded.
152
+ tokenizer_only (bool): If true, only download the tokenzier files.
153
+ concurrency (int): The number of concurrent downloads to run.
154
+ """
155
+ if shutil.which(ORAS_CLI) is None:
156
+ raise Exception(f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ')
157
+
158
+ def _read_secrets_file(secret_file_path: str):
159
+ try:
160
+ with open(secret_file_path, encoding='utf-8') as f:
161
+ return f.read().strip()
162
+ except Exception as error:
163
+ raise ValueError(f'secrets file {secret_file_path} failed to be read') from error
164
+ secrets = {}
165
+ for secret in ['username', 'password', 'registry']:
166
+ secrets[secret] = _read_secrets_file(os.path.join(credentials_dir, secret))
167
+ with open(config_file, 'r', encoding='utf-8') as f:
168
+ configs = yaml.safe_load(f.read())
169
+ config_type = 'tokenizers' if tokenizer_only else 'models'
170
+ path = configs[config_type][model]
171
+ registry = secrets['registry']
172
+
173
+ def get_oras_cmd(username: Optional[str]=None, password: Optional[str]=None):
174
+ cmd = [ORAS_CLI, 'pull', f'{registry}/{path}', '-o', save_dir, '--verbose', '--concurrency', str(concurrency)]
175
+ if username is not None:
176
+ cmd.extend(['--username', username])
177
+ if password is not None:
178
+ cmd.extend(['--password', password])
179
+ return cmd
180
+ cmd_without_creds = get_oras_cmd()
181
+ log.info(f"CMD for oras cli to run: {' '.join(cmd_without_creds)}")
182
+ cmd_to_run = get_oras_cmd(username=secrets['username'], password=secrets['password'])
183
+ try:
184
+ subprocess.run(cmd_to_run, check=True)
185
+ except subprocess.CalledProcessError as e:
186
+ raise subprocess.CalledProcessError(e.returncode, cmd_without_creds, e.output, e.stderr)
model_wrapper.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Re-usable :class:`.ComposerModel` for LLM HF Models."""
2
+ from __future__ import annotations
3
+ from collections import UserDict
4
+ from typing import TYPE_CHECKING, List, Mapping, Optional
5
+ import transformers
6
+ from torchmetrics import Metric
7
+ from transformers import PreTrainedTokenizerBase
8
+ from transformers.utils.generic import ModelOutput
9
+ from .hf_fsdp import prepare_hf_model_for_fsdp
10
+ if TYPE_CHECKING:
11
+ from peft import PeftConfig
12
+ _HF_IGNORE_INDEX = -100
13
+
14
+ class HuggingFaceModelWithFSDP(HuggingFaceModel):
15
+ """Wrapper around HuggingFaceModel.
16
+
17
+ Handles preparation for FSDP wrapping.
18
+ """
19
+
20
+ def __init__(self, model: transformers.PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase]=None, metrics: Optional[List[Metric]]=None, eval_metrics: Optional[List[Metric]]=None, shift_labels: bool=False, init_device: Optional[str]=None, peft_config: Optional['PeftConfig']=None):
21
+ super().__init__(model, tokenizer, use_logits=True, metrics=metrics, eval_metrics=eval_metrics, shift_labels=shift_labels, peft_config=peft_config, should_save_peft_only=True)
22
+ prepare_hf_model_for_fsdp(self.model, init_device)
23
+ self.model.param_init_fn = lambda module: self.model._init_weights(module)
24
+
25
+ def forward(self, batch: Mapping):
26
+ if isinstance(batch, dict) or isinstance(batch, UserDict):
27
+ batch = {k: v for k, v in batch.items() if k in self.model_forward_args}
28
+ output = self.model(**batch)
29
+ else:
30
+ raise ValueError('Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model')
31
+ return output
32
+
33
+ def loss(self, outputs: ModelOutput, batch: Mapping):
34
+ if self.config.use_return_dict:
35
+ return outputs['loss']
36
+ return outputs[:2]
modeling_mpt.py CHANGED
@@ -9,40 +9,27 @@ from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Un
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
- from .attention import is_flash_v1_installed, is_flash_v2_installed
 
13
  if is_flash_v2_installed():
14
  try:
15
  from flash_attn import bert_padding
16
  from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
17
  except Exception as e:
18
  raise e
19
- if is_flash_v1_installed():
20
- try:
21
- from flash_attn import bert_padding
22
- except Exception as e:
23
- raise e
24
  from transformers import PreTrainedModel, PreTrainedTokenizerBase
25
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
26
  from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
27
  from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
28
  from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFRotaryEmbedding
29
- from .attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
30
  from .blocks import MPTBlock
31
  from .custom_embedding import SharedEmbedding
32
- from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
33
- from .ffn import FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
34
- from .ffn import MPTMLP as MPTMLP
35
  from .ffn import build_ffn as build_ffn
36
- from .norm import NORM_CLASS_REGISTRY
37
  from .configuration_mpt import MPTConfig
38
- from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
39
- from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
40
  from .meta_init_context import init_empty_weights
41
  from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
42
- try:
43
- from .flash_attn_triton import flash_attn_func as flash_attn_func
44
- except:
45
- pass
46
  import logging
47
  log = logging.getLogger(__name__)
48
 
@@ -140,9 +127,9 @@ def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: tor
140
  key_padding_mask = attention_mask_in_length
141
  query_padding_mask = attention_mask_in_length
142
  unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
143
- (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
144
- (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
145
- (_, indices_v, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
146
  flash_attn_padding_info['indices_q'] = indices_q
147
  flash_attn_padding_info['indices_k'] = indices_k
148
  flash_attn_padding_info['indices_v'] = indices_v
@@ -176,7 +163,6 @@ class MPTModel(MPTPreTrainedModel):
176
  config._validate_config()
177
  super().__init__(config)
178
  self.attn_impl = config.attn_config['attn_impl']
179
- self.prefix_lm = config.attn_config['prefix_lm']
180
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
181
  self.alibi = config.attn_config['alibi']
182
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
@@ -196,6 +182,10 @@ class MPTModel(MPTPreTrainedModel):
196
  self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
197
  self.emb_drop = nn.Dropout(config.emb_pdrop)
198
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
 
 
 
 
199
  self.norm_f = norm_class(config.d_model, device=config.init_device)
200
  self.rope = config.attn_config['rope']
201
  self.rope_impl = None
@@ -205,10 +195,10 @@ class MPTModel(MPTPreTrainedModel):
205
  if config.init_device != 'meta':
206
  log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
207
  self.apply(self.param_init_fn)
208
- self.is_causal = not self.prefix_lm
209
  self._attn_bias_initialized = False
210
  self.attn_bias = None
211
- self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
212
  if config.no_bias:
213
  for module in self.modules():
214
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
@@ -227,7 +217,7 @@ class MPTModel(MPTPreTrainedModel):
227
  self.wte = value
228
 
229
  @torch.no_grad()
230
- def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
231
  if not self._attn_bias_initialized:
232
  if self.attn_bias_shape:
233
  self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
@@ -238,10 +228,6 @@ class MPTModel(MPTPreTrainedModel):
238
  if self.attn_bias is not None:
239
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
240
  attn_bias = self.attn_bias
241
- if self.prefix_lm:
242
- assert isinstance(attn_bias, torch.Tensor)
243
- assert isinstance(prefix_mask, torch.Tensor)
244
- attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
245
  if self.attn_uses_sequence_id and sequence_id is not None:
246
  assert isinstance(attn_bias, torch.Tensor)
247
  attn_bias = apply_sequence_id(attn_bias, sequence_id, self.config.max_seq_len)
@@ -252,43 +238,22 @@ class MPTModel(MPTPreTrainedModel):
252
  else:
253
  _s_k = max(0, attn_bias.size(-1) - s_k)
254
  attn_bias = attn_bias[:, :, :, _s_k:]
255
- if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
256
- raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
257
  min_val = torch.finfo(attn_bias.dtype).min
258
  attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
259
  return (attn_bias, attention_mask)
260
 
261
- def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor:
262
- (s_k, s_q) = attn_bias.shape[-2:]
263
- if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
264
- raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
265
- seq_len = prefix_mask.shape[-1]
266
- if seq_len > self.config.max_seq_len:
267
- raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
268
- attn_bias = attn_bias[..., :seq_len, :seq_len]
269
- causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
270
- prefix = prefix_mask.view(-1, 1, 1, seq_len)
271
- cannot_attend = ~torch.logical_or(causal, prefix.bool())
272
- min_val = torch.finfo(attn_bias.dtype).min
273
- attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
274
- return attn_bias
275
-
276
- def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
277
  return_dict = return_dict if return_dict is not None else self.config.return_dict
278
  use_cache = use_cache if use_cache is not None else self.config.use_cache
279
  if attention_mask is not None:
280
  attention_mask = attention_mask.bool()
281
- if prefix_mask is not None:
282
- prefix_mask = prefix_mask.bool()
283
  if not return_dict:
284
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
285
  if output_attentions:
286
  if self.attn_impl != 'torch':
287
- raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
288
  if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
289
  raise NotImplementedError('MPT does not support training with left padding.')
290
- if self.prefix_lm and prefix_mask is None:
291
- raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
292
  if self.training:
293
  if self.attn_uses_sequence_id and sequence_id is None:
294
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
@@ -336,7 +301,7 @@ class MPTModel(MPTPreTrainedModel):
336
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
337
  assert isinstance(self.emb_drop, nn.Module)
338
  x = self.emb_drop(x_shrunk)
339
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
340
  attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask)
341
  alibi_slopes = None
342
  if self.alibi and self.attn_impl == 'flash':
@@ -349,12 +314,12 @@ class MPTModel(MPTPreTrainedModel):
349
  flash_attn_padding_info = {}
350
  if self.attn_impl == 'flash':
351
  flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
352
- for (b_idx, block) in enumerate(self.blocks):
353
  if output_hidden_states:
354
  assert all_hidden_states is not None
355
  all_hidden_states = all_hidden_states + (x,)
356
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
357
- (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
358
  if presents is not None:
359
  presents += (present,)
360
  if output_attentions:
@@ -422,7 +387,8 @@ class MPTForCausalLM(MPTPreTrainedModel):
422
  self.transformer.set_input_embeddings(new_embeddings)
423
 
424
  def tie_weights(self) -> None:
425
- self.lm_head = None
 
426
 
427
  def set_decoder(self, decoder: MPTModel) -> None:
428
  self.transformer = decoder
@@ -430,10 +396,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
430
  def get_decoder(self) -> MPTModel:
431
  return self.transformer
432
 
433
- def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None) -> CausalLMOutputWithPast:
434
  return_dict = return_dict if return_dict is not None else self.config.return_dict
435
  use_cache = use_cache if use_cache is not None else self.config.use_cache
436
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
437
  if self.lm_head is not None:
438
  logits = self.lm_head(outputs.last_hidden_state)
439
  else:
@@ -459,29 +425,48 @@ class MPTForCausalLM(MPTPreTrainedModel):
459
  return _fsdp_wrap_fn(self, module)
460
 
461
  def activation_checkpointing_fn(self, module: nn.Module) -> bool:
462
- act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', None) or ['MPTBlock']
463
- if isinstance(act_ckpt_list, str):
464
- act_ckpt_list = [act_ckpt_list]
465
- elif not isinstance(act_ckpt_list, list):
466
- raise ValueError(f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}')
467
- if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
468
- if len(act_ckpt_list) > 1:
469
- log.info('Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).')
470
- return isinstance(module, MPTBlock)
471
- mod_types = ()
472
- for mod_name in act_ckpt_list:
473
- if mod_name.lower() == 'mptblock':
474
- mod_types += (MPTBlock,)
475
- elif mod_name in ATTN_CLASS_REGISTRY:
476
- mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
477
- elif mod_name in FFN_CLASS_REGISTRY:
478
- mod_types += (FFN_CLASS_REGISTRY[mod_name],)
479
- elif mod_name in NORM_CLASS_REGISTRY:
480
- mod_types += (NORM_CLASS_REGISTRY[mod_name],)
481
- else:
482
- msg = ', '.join(list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
483
- raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
484
- return isinstance(module, mod_types)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> Dict[str, Any]:
487
  attention_mask = kwargs['attention_mask'].bool()
@@ -493,17 +478,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
493
  sequence_id = None
494
  if past_key_values is not None:
495
  input_ids = input_ids[:, -1].unsqueeze(-1)
496
- if self.transformer.prefix_lm:
497
- prefix_mask = torch.ones_like(attention_mask)
498
- if kwargs.get('use_cache') == False:
499
- raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
500
- else:
501
- prefix_mask = None
502
  if inputs_embeds is not None and past_key_values is None:
503
  model_inputs = {'inputs_embeds': inputs_embeds}
504
  else:
505
  model_inputs = {'input_ids': input_ids}
506
- model_inputs.update({'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)})
507
  return model_inputs
508
 
509
  @staticmethod
 
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
+ from .attention import is_flash_v2_installed
13
+ from .norm import NORM_CLASS_REGISTRY
14
  if is_flash_v2_installed():
15
  try:
16
  from flash_attn import bert_padding
17
  from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
18
  except Exception as e:
19
  raise e
 
 
 
 
 
20
  from transformers import PreTrainedModel, PreTrainedTokenizerBase
21
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
22
  from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
23
  from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
24
  from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFRotaryEmbedding
25
+ from .attention import attn_bias_shape, build_attn_bias, gen_slopes
26
  from .blocks import MPTBlock
27
  from .custom_embedding import SharedEmbedding
 
 
 
28
  from .ffn import build_ffn as build_ffn
 
29
  from .configuration_mpt import MPTConfig
 
 
30
  from .meta_init_context import init_empty_weights
31
  from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
32
+ from .act_ckpt import pass_on_block_idx, build_act_ckpt_mod_to_blocks, check_mapping_blocks_overlap
 
 
 
33
  import logging
34
  log = logging.getLogger(__name__)
35
 
 
127
  key_padding_mask = attention_mask_in_length
128
  query_padding_mask = attention_mask_in_length
129
  unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
130
+ _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
131
+ _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
132
+ _, indices_v, _, _ = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
133
  flash_attn_padding_info['indices_q'] = indices_q
134
  flash_attn_padding_info['indices_k'] = indices_k
135
  flash_attn_padding_info['indices_v'] = indices_v
 
163
  config._validate_config()
164
  super().__init__(config)
165
  self.attn_impl = config.attn_config['attn_impl']
 
166
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
167
  self.alibi = config.attn_config['alibi']
168
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
 
182
  self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
183
  self.emb_drop = nn.Dropout(config.emb_pdrop)
184
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
185
+ for i, block in enumerate(self.blocks):
186
+ block.block_idx = i
187
+ block.max_block_idx = config.n_layers - 1
188
+ pass_on_block_idx(block)
189
  self.norm_f = norm_class(config.d_model, device=config.init_device)
190
  self.rope = config.attn_config['rope']
191
  self.rope_impl = None
 
195
  if config.init_device != 'meta':
196
  log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
197
  self.apply(self.param_init_fn)
198
+ self.is_causal = True
199
  self._attn_bias_initialized = False
200
  self.attn_bias = None
201
+ self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
202
  if config.no_bias:
203
  for module in self.modules():
204
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
 
217
  self.wte = value
218
 
219
  @torch.no_grad()
220
+ def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
221
  if not self._attn_bias_initialized:
222
  if self.attn_bias_shape:
223
  self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
 
228
  if self.attn_bias is not None:
229
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
230
  attn_bias = self.attn_bias
 
 
 
 
231
  if self.attn_uses_sequence_id and sequence_id is not None:
232
  assert isinstance(attn_bias, torch.Tensor)
233
  attn_bias = apply_sequence_id(attn_bias, sequence_id, self.config.max_seq_len)
 
238
  else:
239
  _s_k = max(0, attn_bias.size(-1) - s_k)
240
  attn_bias = attn_bias[:, :, :, _s_k:]
 
 
241
  min_val = torch.finfo(attn_bias.dtype).min
242
  attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
243
  return (attn_bias, attention_mask)
244
 
245
+ def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  return_dict = return_dict if return_dict is not None else self.config.return_dict
247
  use_cache = use_cache if use_cache is not None else self.config.use_cache
248
  if attention_mask is not None:
249
  attention_mask = attention_mask.bool()
 
 
250
  if not return_dict:
251
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
252
  if output_attentions:
253
  if self.attn_impl != 'torch':
254
+ raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash`.')
255
  if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
256
  raise NotImplementedError('MPT does not support training with left padding.')
 
 
257
  if self.training:
258
  if self.attn_uses_sequence_id and sequence_id is None:
259
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
 
301
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
302
  assert isinstance(self.emb_drop, nn.Module)
303
  x = self.emb_drop(x_shrunk)
304
+ attn_bias, attention_mask = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, sequence_id=sequence_id)
305
  attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask)
306
  alibi_slopes = None
307
  if self.alibi and self.attn_impl == 'flash':
 
314
  flash_attn_padding_info = {}
315
  if self.attn_impl == 'flash':
316
  flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
317
+ for b_idx, block in enumerate(self.blocks):
318
  if output_hidden_states:
319
  assert all_hidden_states is not None
320
  all_hidden_states = all_hidden_states + (x,)
321
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
322
+ x, attn_weights, present = block(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
323
  if presents is not None:
324
  presents += (present,)
325
  if output_attentions:
 
387
  self.transformer.set_input_embeddings(new_embeddings)
388
 
389
  def tie_weights(self) -> None:
390
+ if getattr(self.config, 'tie_word_embeddings', True):
391
+ self.lm_head = None
392
 
393
  def set_decoder(self, decoder: MPTModel) -> None:
394
  self.transformer = decoder
 
396
  def get_decoder(self) -> MPTModel:
397
  return self.transformer
398
 
399
+ def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None) -> CausalLMOutputWithPast:
400
  return_dict = return_dict if return_dict is not None else self.config.return_dict
401
  use_cache = use_cache if use_cache is not None else self.config.use_cache
402
+ outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
403
  if self.lm_head is not None:
404
  logits = self.lm_head(outputs.last_hidden_state)
405
  else:
 
425
  return _fsdp_wrap_fn(self, module)
426
 
427
  def activation_checkpointing_fn(self, module: nn.Module) -> bool:
428
+ """The MPT activation checkpointing (act ckpt) function.
429
+
430
+ When `activation_checkpointing` in fsdp_config is set to true, this function will be called on all the modules in the FSDP wrapped model and determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`) which can be specified as below:
431
+ 1. null (or no such field): The whole MPTBlock will be activation checkpointed on all layers
432
+ 2. a list of modules to act ckpt on all layers, e.g.,
433
+ activation_checkpointing_target:
434
+ - grouped_query_attention
435
+ - mptmlp
436
+ 3. a dictionary of module name with target_blocks, e.g.,
437
+ activation_checkpointing_target:
438
+ {
439
+ "mptblock": target_blocks_1,
440
+ "grouped_query_attention": target_blocks_2
441
+ }
442
+ target_blocks (target_blocks_1, target_blocks_2 above) can be:
443
+ - a single integer n: the first n transformer block will be activation checkpointed
444
+ - a string of first-n, middle-m, last-k, range-i-j: the first n, the middle m, the last k, or the range [i, j) layers will be activation checkpointed. E.g, 'first-2, last-2' means the first 2 and last 2 transformer blocks will be activation checkpointed
445
+ middle-m is range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)``
446
+ - a list of integers corresponds to the list of transformer block ids, e.g., [2] means the second transformer block will be activation checkpointed. [2, 3] means the second and third transformer blocks will be activation checkpointed
447
+ - a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j
448
+
449
+ An example in yaml config file:
450
+ fsdp_config:
451
+ activation_checkpointing: true
452
+ model:
453
+ activation_checkpointing_target:
454
+ {
455
+ "mptblock": 'first-5',
456
+ "grouped_query_attention": 'last-35'
457
+ }
458
+ """
459
+ if not hasattr(module, 'block_idx'):
460
+ log.debug(f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer block or its submodules are eligible for activation checkpointing.')
461
+ return False
462
+ act_ckpt_target = getattr(self.config, 'activation_checkpointing_target', None)
463
+ act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks(act_ckpt_target, MPTBlock, module.max_block_idx)
464
+ check_mapping_blocks_overlap(act_ckpt_mod_to_blocks, module.max_block_idx)
465
+ for k in act_ckpt_mod_to_blocks.keys():
466
+ if isinstance(module, k):
467
+ blocks = act_ckpt_mod_to_blocks[k]
468
+ return True if blocks == -1 else module.block_idx in blocks
469
+ return False
470
 
471
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> Dict[str, Any]:
472
  attention_mask = kwargs['attention_mask'].bool()
 
478
  sequence_id = None
479
  if past_key_values is not None:
480
  input_ids = input_ids[:, -1].unsqueeze(-1)
 
 
 
 
 
 
481
  if inputs_embeds is not None and past_key_values is None:
482
  model_inputs = {'inputs_embeds': inputs_embeds}
483
  else:
484
  model_inputs = {'input_ids': input_ids}
485
+ model_inputs.update({'attention_mask': attention_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)})
486
  return model_inputs
487
 
488
  @staticmethod
monolithic_ckpt_callback.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import os
3
+ import tempfile
4
+ from pathlib import Path
5
+ import torch
6
+
7
+ class MonolithicCheckpointSaver(Callback):
8
+ """Save a monolithic checkpoint every N batches.
9
+
10
+ Args:
11
+ save_folder (str): Folder to save checkpoints to (can be a URI)
12
+ batch_interval (int): Number of batches between checkpoints.
13
+ filename (str): Filename to save checkpoints to.
14
+ overwrite (bool): Whether to overwrite previous checkpoints.
15
+ keep_optimizers (bool): Whether to save the optimizer state in the monolithic checkpoint.
16
+ """
17
+
18
+ def __init__(self, save_folder: str, batch_interval: int, filename: str='ep{epoch}-ba{batch}.pt', overwrite: bool=False, keep_optimizers: bool=False):
19
+ self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(save_folder)
20
+ self.filename_format_str = filename
21
+ self.batch_interval = batch_interval
22
+ self.upload_to_object_store = self.backend != ''
23
+ self.overwrite = overwrite
24
+ self.keep_optimizers = keep_optimizers
25
+ if self.upload_to_object_store:
26
+ self.remote_ud = RemoteUploaderDownloader(bucket_uri=f'{self.backend}://{self.bucket_name}')
27
+ else:
28
+ self.remote_ud = None
29
+
30
+ def init(self, state: State, logger: Logger) -> None:
31
+ if self.upload_to_object_store and self.remote_ud is not None:
32
+ self.remote_ud.init(state, logger)
33
+ state.callbacks.append(self.remote_ud)
34
+
35
+ def batch_checkpoint(self, state: State, logger: Logger) -> None:
36
+ if state.timestamp.batch.value % self.batch_interval == 0:
37
+ self._save_checkpoint(state, logger)
38
+
39
+ def fit_end(self, state: State, logger: Logger) -> None:
40
+ if state.timestamp.batch.value % self.batch_interval != 0:
41
+ self._save_checkpoint(state, logger)
42
+
43
+ def _save_checkpoint(self, state: State, logger: Logger) -> None:
44
+ del logger
45
+ filename = format_name_with_dist_and_time(self.filename_format_str, state.run_name, state.timestamp)
46
+ save_dir = format_name_with_dist_and_time(self.save_dir_format_str, state.run_name, state.timestamp)
47
+ dir_context_mgr = tempfile.TemporaryDirectory() if self.upload_to_object_store else contextlib.nullcontext(enter_result=save_dir)
48
+ with dir_context_mgr as temp_save_dir:
49
+ assert isinstance(temp_save_dir, str)
50
+ save_path = str(Path(temp_save_dir) / Path(filename))
51
+ dirname = os.path.dirname(save_path)
52
+ if dirname:
53
+ os.makedirs(dirname, exist_ok=True)
54
+ state_dict = {'state': state.state_dict(), 'rng': reproducibility.get_rng_state()}
55
+ state_dict['state'].pop('optimizers')
56
+ state_dict['state'].pop('model')
57
+ with fsdp_state_dict_type_context(state.model, state_dict_type='full'):
58
+ state_dict['state']['model'] = state.model.state_dict()
59
+ if self.keep_optimizers:
60
+ optimizer = state.optimizers[0]
61
+ state_dict['state']['optimizers'] = {type(optimizer).__qualname__: fsdp_get_optim_state_dict(state.model, optimizer, state_dict_type='full')}
62
+ if dist.get_global_rank() == 0:
63
+ torch.save(state_dict, save_path)
64
+ if self.upload_to_object_store and self.remote_ud is not None and (dist.get_global_rank() == 0):
65
+ remote_file_name = str(Path(save_dir) / Path(filename))
66
+ self.remote_ud.upload_file(state=state, remote_file_name=remote_file_name, file_path=Path(save_path), overwrite=self.overwrite)
mosaicml_logger_utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, List, Optional, Union
4
+ _MODEL_KEYS_TO_LOG = ['pretrained_model_name_or_path', 'pretrained', 'vocab_size', 'd_model', 'n_heads', 'n_layers', 'expansion_ratio', 'max_seq_len']
5
+
6
+ def maybe_create_mosaicml_logger() -> Optional[MosaicMLLogger]:
7
+ """Creates a MosaicMLLogger if the run was sent from the Mosaic platform."""
8
+ if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR):
9
+ return MosaicMLLogger()
10
+
11
+ def find_mosaicml_logger(loggers: List[LoggerDestination]) -> Optional[MosaicMLLogger]:
12
+ """Returns the first MosaicMLLogger from a list, and None otherwise."""
13
+ return next((logger for logger in loggers if isinstance(logger, MosaicMLLogger)), None)
14
+
15
+ def log_eval_analytics(mosaicml_logger: MosaicMLLogger, model_configs: ListConfig, icl_tasks: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]]):
16
+ """Logs analytics for runs using the `eval.py` script."""
17
+ metrics: Dict[str, Any] = {'llmfoundry/script': 'eval'}
18
+ metrics['llmfoundry/gauntlet_configured'] = eval_gauntlet_config is not None
19
+ metrics['llmfoundry/icl_configured'] = isinstance(icl_tasks, str) or len(icl_tasks) > 0
20
+ metrics['llmfoundry/model_configs'] = []
21
+ for model_config in model_configs:
22
+ nested_model_config = model_config.get('model', {})
23
+ model_config_data = {}
24
+ for key in _MODEL_KEYS_TO_LOG:
25
+ if nested_model_config.get(key, None) is not None:
26
+ model_config_data[key] = nested_model_config.get(key)
27
+ if len(model_config_data) > 0:
28
+ metrics['llmfoundry/model_configs'].append(json.dumps(model_config_data, sort_keys=True))
29
+ mosaicml_logger.log_metrics(metrics)
30
+ mosaicml_logger._flush_metadata(force_flush=True)
31
+
32
+ def log_train_analytics(mosaicml_logger: MosaicMLLogger, model_config: DictConfig, train_loader_config: DictConfig, eval_loader_config: Optional[Union[DictConfig, ListConfig]], callback_configs: Optional[DictConfig], tokenizer_name: str, load_path: Optional[str], icl_tasks_config: Optional[Union[ListConfig, str]], eval_gauntlet: Optional[Union[DictConfig, str]]):
33
+ """Logs analytics for runs using the `train.py` script."""
34
+ train_loader_dataset = train_loader_config.get('dataset', {})
35
+ metrics: Dict[str, Any] = {'llmfoundry/tokenizer_name': tokenizer_name, 'llmfoundry/script': 'train', 'llmfoundry/train_loader_name': train_loader_config.get('name')}
36
+ if callback_configs is not None:
37
+ metrics['llmfoundry/callbacks'] = [name for name, _ in callback_configs.items()]
38
+ metrics['llmfoundry/gauntlet_configured'] = eval_gauntlet is not None
39
+ metrics['llmfoundry/icl_configured'] = icl_tasks_config is not None and (isinstance(icl_tasks_config, str) or len(icl_tasks_config) > 0)
40
+ if train_loader_dataset.get('hf_name', None) is not None:
41
+ metrics['llmfoundry/train_dataset_hf_name'] = train_loader_dataset.get('hf_name', None)
42
+ if train_loader_config.get('name') == 'finetuning':
43
+ metrics['llmfoundry/train_task_type'] = 'INSTRUCTION_FINETUNE'
44
+ elif train_loader_config.get('name') == 'text':
45
+ if load_path is not None or model_config.get('pretrained') == True:
46
+ metrics['llmfoundry/train_task_type'] = 'CONTINUED_PRETRAIN'
47
+ else:
48
+ metrics['llmfoundry/train_task_type'] = 'PRETRAIN'
49
+ if eval_loader_config is not None:
50
+ metrics['llmfoundry/eval_loaders'] = []
51
+ if isinstance(eval_loader_config, ListConfig):
52
+ eval_loader_configs: ListConfig = eval_loader_config
53
+ else:
54
+ eval_loader_configs = ListConfig([eval_loader_config])
55
+ for loader_config in eval_loader_configs:
56
+ eval_loader_info = {}
57
+ eval_loader_dataset = loader_config.get('dataset', {})
58
+ eval_loader_info['name'] = loader_config.get('name')
59
+ if eval_loader_dataset.get('hf_name', None) is not None:
60
+ eval_loader_info['dataset_hf_name'] = eval_loader_dataset.get('hf_name')
61
+ metrics['llmfoundry/eval_loaders'].append(json.dumps(eval_loader_info, sort_keys=True))
62
+ model_config_data = {}
63
+ for key in _MODEL_KEYS_TO_LOG:
64
+ if model_config.get(key, None) is not None:
65
+ model_config_data[f'llmfoundry/{key}'] = model_config.get(key)
66
+ if len(model_config_data) > 0:
67
+ metrics.update(model_config_data)
68
+ mosaicml_logger.log_metrics(metrics)
69
+ mosaicml_logger._flush_metadata(force_flush=True)
mpt.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_mpt import MPTConfig
2
+ from .modeling_mpt import ComposerMPTCausalLM, MPTForCausalLM, MPTModel, MPTPreTrainedModel
packing.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import tempfile
3
+ from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
4
+ import numpy as np
5
+ import torch
6
+ from transformers import PreTrainedTokenizerBase
7
+ log = logging.getLogger(__name__)
8
+
9
+ class BinPackCollator:
10
+ """Utility collator for packing to reduce padding."""
11
+
12
+ def __init__(self, collator: Callable, target_batch_size: int, max_seq_len: int, pad_token_id: int, padding_side: Literal['left', 'right'], max_leftover_bins_to_keep: Optional[int]=None):
13
+ self.base_collator = collator
14
+ self.out_size = int(target_batch_size)
15
+ self.max_seq_len = int(max_seq_len)
16
+ self.pad_token_id = int(pad_token_id)
17
+ self.padding_side = padding_side
18
+ if self.out_size <= 0:
19
+ raise ValueError(f'target_batch_size={target_batch_size!r} must be >0.')
20
+ if self.max_seq_len <= 0:
21
+ raise ValueError(f'max_seq_len={max_seq_len!r} must be >0.')
22
+ if self.pad_token_id < 0:
23
+ raise ValueError(f'pad_token_id={pad_token_id!r} must be >=0.')
24
+ if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0:
25
+ raise ValueError(f'max_leftover_bins_to_keep={max_leftover_bins_to_keep!r} must be >=0 or None.')
26
+ self.max_leftover_bins_to_keep = max_leftover_bins_to_keep
27
+ self.n_packed_tokens = 0
28
+ self.n_total_tokens = 0
29
+ self.n_packed_examples = 0
30
+ self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []
31
+
32
+ @property
33
+ def waste(self) -> float:
34
+ return 1 - self.n_packed_tokens / self.n_total_tokens
35
+
36
+ @property
37
+ def efficiency(self) -> float:
38
+ return self.n_packed_tokens / (self.max_seq_len * self.n_packed_examples)
39
+
40
+ def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
41
+ batch = self.base_collator(examples)
42
+ return self.pack(batch)
43
+
44
+ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
45
+ assert 'attention_mask' in batch
46
+ assert 'input_ids' in batch
47
+ for key in batch.keys():
48
+ assert key in ['input_ids', 'labels', 'attention_mask', 'sequence_id']
49
+ sizes, trimmed_examples = _trim_batch(batch)
50
+ return self._pack_trimmed_examples(trimmed_examples, sizes)
51
+
52
+ def _pack_trimmed_examples(self, trimmed_examples: List[Dict[str, torch.Tensor]], sizes: List[int]) -> Dict[str, torch.Tensor]:
53
+ """Packs trimmed examples into fixed-size bins and repads them.
54
+
55
+ Args:
56
+ trimmed_examples (List[Dict[str, torch.Tensor]]): A list of trimmed examples.
57
+ sizes (List[int]): The sizes of the trimmed examples.
58
+
59
+ Returns:
60
+ Dict[str, torch.Tensor]: A batch of repadded examples ready for processing
61
+ """
62
+ packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing(sizes=sizes, examples=trimmed_examples, num_bins=self.out_size, max_bin_size=self.max_seq_len, existing_bins=self._leftover_bins)
63
+ self.n_packed_tokens += n_packed_tokens
64
+ self.n_total_tokens += n_total_tokens
65
+ self.n_packed_examples += self.out_size
66
+ self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
67
+ batch = _repad(packed_examples, max_seq_len=self.max_seq_len, pad_token_id=self.pad_token_id, padding_side=self.padding_side)
68
+ return batch
69
+
70
+ def _trim_batch(batch: Dict[str, torch.Tensor]) -> Tuple[List[int], List[Dict[str, torch.Tensor]]]:
71
+ """Trims padding off all examples in batch.
72
+
73
+ Args:
74
+ batch (Dict[str, torch.Tensor]): Batch of padded data with tensors as values.
75
+
76
+ Returns:
77
+ A tuple with unpadded lengths of examples and a list of each trimmed example from the batch.
78
+ """
79
+ sizes, trimmed_examples = ([], [])
80
+ for idx in range(batch['attention_mask'].shape[0]):
81
+ size, trimmed_example = _extract_trim_batch_idx(batch, idx)
82
+ sizes.append(size)
83
+ trimmed_examples.append(trimmed_example)
84
+ return (sizes, trimmed_examples)
85
+
86
+ def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
87
+ example = {k: v[idx] for k, v in batch.items()}
88
+ keep = example['attention_mask'] == 1
89
+ size = int(keep.sum())
90
+ trim_example = {k: v[keep] for k, v in example.items()}
91
+ trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])
92
+ return (size, trim_example)
93
+
94
+ def _combine_in_place(example: Dict[str, torch.Tensor], add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
95
+ if 'labels' in add_on:
96
+ add_on['labels'][0] = -100
97
+ for k in example.keys():
98
+ if k == 'sequence_id':
99
+ example[k] = torch.cat([example[k], add_on[k] + 1 + torch.max(example[k])])
100
+ else:
101
+ example[k] = torch.cat([example[k], add_on[k]])
102
+ return example
103
+
104
+ def _first_fit_bin_packing(sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int, max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[str, torch.Tensor]]]]:
105
+ bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins
106
+ starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins])
107
+ sizes_and_examples = [(size, example) for size, example in zip(sizes, examples)]
108
+ sorted_sizes_and_examples = sorted(sizes_and_examples, key=lambda x: x[0], reverse=True)
109
+ required_num_examples = max(0, num_bins - len(bins))
110
+ num_examples = len(sizes)
111
+ if num_examples < required_num_examples:
112
+ for size, example in sorted_sizes_and_examples:
113
+ bins.append((size, example))
114
+ total_bin_sizes = sum([bin_size for bin_size, _ in bins])
115
+ total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
116
+ total_example_sizes = sum(sizes)
117
+ if total_new_bin_sizes != total_example_sizes:
118
+ raise AssertionError(f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal total_new_bin_sizes={total_new_bin_sizes!r}.')
119
+ sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
120
+ bin_sizes, packed_examples = ([], [])
121
+ for bin_size, packed_example in sorted_bins:
122
+ bin_sizes.append(bin_size)
123
+ packed_examples.append(packed_example)
124
+ return (packed_examples[:num_bins], sum(bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:])
125
+ for i, (size, example) in enumerate(sorted_sizes_and_examples):
126
+ required_num_examples = max(0, num_bins - len(bins))
127
+ n_remaining = num_examples - i
128
+ assert n_remaining >= required_num_examples
129
+ if n_remaining == required_num_examples:
130
+ bins.append((size, example))
131
+ continue
132
+ added = False
133
+ for bidx in range(len(bins)):
134
+ if bins[bidx][0] + size <= max_bin_size:
135
+ bin_size, packed_example = bins.pop(bidx)
136
+ bin_size = bin_size + size
137
+ packed_example = _combine_in_place(packed_example, example)
138
+ bins.append((bin_size, packed_example))
139
+ added = True
140
+ break
141
+ if not added:
142
+ bins.append((size, example))
143
+ total_bin_sizes = sum([bin_size for bin_size, _ in bins])
144
+ total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
145
+ total_example_sizes = sum(sizes)
146
+ if total_new_bin_sizes != total_example_sizes:
147
+ raise AssertionError(f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal total_new_bin_sizes={total_new_bin_sizes!r}.')
148
+ sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
149
+ bin_sizes, packed_examples = ([], [])
150
+ for bin_size, packed_example in sorted_bins:
151
+ bin_sizes.append(bin_size)
152
+ packed_examples.append(packed_example)
153
+ return (packed_examples[:num_bins], sum(bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:])
154
+
155
+ def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
156
+
157
+ def pad_tensor(tensor: torch.Tensor, pad_value: int):
158
+ if len(tensor) == max_seq_len:
159
+ return tensor
160
+ t = torch.full((max_seq_len,), pad_value, dtype=tensor.dtype, device=tensor.device)
161
+ if padding_side == 'left':
162
+ t[-len(tensor):] = tensor
163
+ elif padding_side == 'right':
164
+ t[:len(tensor)] = tensor
165
+ else:
166
+ raise ValueError(f'Unknown padding_side={padding_side!r}')
167
+ return t
168
+ pad_vals = {'input_ids': pad_token_id, 'labels': -100, 'attention_mask': 0, 'sequence_id': -1}
169
+ keys = packed_examples[0].keys()
170
+ batch = {}
171
+ for key in keys:
172
+ batch[key] = torch.stack([pad_tensor(example[key], pad_vals[key]) for example in packed_examples])
173
+ return batch
174
+
175
+ def auto_packing_ratio(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, num_packing_ratios: int=20) -> float:
176
+ """Find a packing ratio that minimizes padding with zero waste.
177
+
178
+ By packing examples, we can increase training efficiency, training on more data with less batches.
179
+ However, in practice, the selected packing_ratio may produce some waste because profiling is done on only
180
+ a subset of the dataset.
181
+
182
+ We select a min_ratio of 1 and a max_ratio that is the max_seq_len / 100, and profile up to
183
+ num_packing_ratios packing ratios between min_ratio and max_ratio, inclusive.
184
+ When a packing_ratio with non-zero waste is found, we stop and select the previous ratio,
185
+ which has zero waste.
186
+
187
+ Args:
188
+ dataloader_cfg (DictConfig): The dataloader configuration for profiling.
189
+ tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
190
+ device_batch_size (int): The size of the batches (number of examples) per device.
191
+ num_packing_ratio (int): The number of packing ratios to try.
192
+
193
+ Returns:
194
+ A packing ratio that minimizes padding while maintaining zero waste.
195
+ """
196
+ rng_state = reproducibility.get_rng_state()
197
+ reproducibility.seed_all(0)
198
+ max_seq_len = dataloader_cfg.dataset.max_seq_len
199
+ if max_seq_len <= 100:
200
+ return 1
201
+ min_ratio = 1
202
+ max_ratio = max_seq_len / 100
203
+ profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, max_ratio, num_packing_ratios, device_batch_size)
204
+ packing_ratio = 1
205
+ for packing_ratio_candidate, _, waste in profiling_results:
206
+ if waste is None or waste > 0:
207
+ break
208
+ packing_ratio = packing_ratio_candidate
209
+ if dist.is_available() and dist.is_initialized():
210
+ device = get_device(None)
211
+ packing_ratio_tensor = device.tensor_to_device(torch.tensor(packing_ratio))
212
+ dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN')
213
+ packing_ratio = packing_ratio_tensor.item()
214
+ reproducibility.load_rng_state(rng_state)
215
+ return packing_ratio
216
+
217
+ def profile_packing(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, min_ratio: float, max_ratio: float, num_packing_ratios: int, device_batch_size: int) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
218
+ """Generator function that profiles example packing across packing ratios.
219
+
220
+ Args:
221
+ dataloader_cfg (DictConfig): The dataloader configuration for profiling.
222
+ tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
223
+ min_ratio (float): Smallest packing_ratio to test. Must be >=1.
224
+ max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
225
+ num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try.
226
+ device_batch_size (int): The size of the batches (number of examples) per device.
227
+
228
+ Returns:
229
+ An iterable of tuples of packing ratio, padding, and waste, sorted by smallest to largest packing ratio.
230
+ """
231
+ import copy
232
+ from .dataloader import build_dataloader
233
+ max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
234
+ max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', None)
235
+ dataloader_cfg = copy.deepcopy(dataloader_cfg)
236
+ dataloader_cfg.dataset.packing_ratio = 1.0
237
+ dataloader_cfg.drop_last = False
238
+ dataloader_cfg.num_workers = 0
239
+ dataloader_cfg.prefetch_factor = None
240
+ dataloader_cfg.persistent_workers = False
241
+ if dataloader_cfg.dataset.get('remote') is not None:
242
+ dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name
243
+ packing_ratios, raw_batch_sizes = ([], [])
244
+ for packing_ratio in np.linspace(min_ratio, max_ratio, num_packing_ratios, endpoint=True):
245
+ packing_ratio = np.round(10 * packing_ratio) / 10
246
+ raw_batch_size = int(packing_ratio * device_batch_size)
247
+ if raw_batch_size not in raw_batch_sizes:
248
+ packing_ratios.append(packing_ratio)
249
+ raw_batch_sizes.append(raw_batch_size)
250
+ n_profile_examples = max(raw_batch_sizes) * 100
251
+ train_dataspec = build_dataloader(dataloader_cfg, tokenizer, n_profile_examples)
252
+ train_dataloader = train_dataspec.dataloader
253
+ big_batch = next(iter(train_dataloader))
254
+ sizes, trimmed_examples = _trim_batch(big_batch)
255
+
256
+ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
257
+ trimmed_examples_copy = [te.copy() for te in trimmed_examples]
258
+ packer = BinPackCollator(collator=lambda x: x, target_batch_size=device_batch_size, max_seq_len=max_seq_len, pad_token_id=0, padding_side='left', max_leftover_bins_to_keep=max_leftovers_to_keep)
259
+ for idx in range(0, len(trimmed_examples_copy), raw_batch_size):
260
+ batch = trimmed_examples_copy[idx:idx + raw_batch_size]
261
+ if len(batch) < device_batch_size:
262
+ continue
263
+ packer._pack_trimmed_examples(batch, sizes[idx:idx + raw_batch_size])
264
+ if packer.n_packed_examples == 0:
265
+ log.debug('No examples packed during profiling. Dataset is smaller than device batch size.')
266
+ return (None, None)
267
+ padding_percent = 100 * (1 - packer.efficiency)
268
+ waste_percent = 100 * packer.waste
269
+ return (padding_percent, waste_percent)
270
+ for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
271
+ padding, waste = profile(raw_batch_size)
272
+ yield (packing_ratio, padding, waste)
param_init_fns.py CHANGED
@@ -22,9 +22,9 @@ def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
22
  if _fused is None:
23
  raise RuntimeError(f'Internal logic error')
24
  assert isinstance(module.weight, torch.Tensor)
25
- (dim, splits) = _fused
26
  splits = (0, *splits, module.weight.size(dim))
27
- for (s, e) in zip(splits[:-1], splits[1:]):
28
  slice_indices = [slice(None)] * module.weight.ndim
29
  slice_indices[dim] = slice(s, e)
30
  init_fn_(module.weight[slice_indices])
@@ -71,7 +71,7 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
71
  if lim == 0:
72
  warnings.warn(f'Embedding layer initialized to 0.')
73
  lim = [-lim, lim]
74
- (a, b) = lim
75
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
76
  else:
77
  emb_init_fn_ = init_fn_
@@ -88,7 +88,7 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
88
  assert d_model is not None
89
  _d = d_model
90
  splits = (0, _d, 2 * _d, 3 * _d)
91
- for (s, e) in zip(splits[:-1], splits[1:]):
92
  init_fn_(module.in_proj_weight[s:e])
93
  else:
94
  assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
 
22
  if _fused is None:
23
  raise RuntimeError(f'Internal logic error')
24
  assert isinstance(module.weight, torch.Tensor)
25
+ dim, splits = _fused
26
  splits = (0, *splits, module.weight.size(dim))
27
+ for s, e in zip(splits[:-1], splits[1:]):
28
  slice_indices = [slice(None)] * module.weight.ndim
29
  slice_indices[dim] = slice(s, e)
30
  init_fn_(module.weight[slice_indices])
 
71
  if lim == 0:
72
  warnings.warn(f'Embedding layer initialized to 0.')
73
  lim = [-lim, lim]
74
+ a, b = lim
75
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
76
  else:
77
  emb_init_fn_ = init_fn_
 
88
  assert d_model is not None
89
  _d = d_model
90
  splits = (0, _d, 2 * _d, 3 * _d)
91
+ for s, e in zip(splits[:-1], splits[1:]):
92
  init_fn_(module.in_proj_weight[s:e])
93
  else:
94
  assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
prompt_files.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional
3
+ PROMPTFILE_PREFIX = 'file::'
4
+
5
+ def load_prompts(prompts: List[str], prompt_delimiter: Optional[str]=None) -> List[str]:
6
+ """Loads a set of prompts, both free text and from file.
7
+
8
+ Args:
9
+ prompts (List[str]): List of free text prompts and prompt files
10
+ prompt_delimiter (Optional str): Delimiter for text file
11
+ If not provided, assumes the prompt file is a single prompt (non-delimited)
12
+
13
+ Returns:
14
+ List of prompt string(s)
15
+ """
16
+ prompt_strings = []
17
+ for prompt in prompts:
18
+ if prompt.startswith(PROMPTFILE_PREFIX):
19
+ prompts = load_prompts_from_file(prompt, prompt_delimiter)
20
+ prompt_strings.extend(prompts)
21
+ else:
22
+ prompt_strings.append(prompt)
23
+ return prompt_strings
24
+
25
+ def load_prompts_from_file(prompt_path: str, prompt_delimiter: Optional[str]=None) -> List[str]:
26
+ """Load a set of prompts from a text fie.
27
+
28
+ Args:
29
+ prompt_path (str): Path for text file
30
+ prompt_delimiter (Optional str): Delimiter for text file
31
+ If not provided, assumes the prompt file is a single prompt (non-delimited)
32
+
33
+ Returns:
34
+ List of prompt string(s)
35
+ """
36
+ if not prompt_path.startswith(PROMPTFILE_PREFIX):
37
+ raise ValueError(f'prompt_path_str must start with {PROMPTFILE_PREFIX}')
38
+ _, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1)
39
+ prompt_file_path = os.path.expanduser(prompt_file_path)
40
+ if not os.path.isfile(prompt_file_path):
41
+ raise FileNotFoundError(f'prompt_file_path={prompt_file_path!r} does not match any existing files.')
42
+ with open(prompt_file_path, 'r') as f:
43
+ prompt_string = f.read()
44
+ if prompt_delimiter is None:
45
+ return [prompt_string]
46
+ return [i for i in prompt_string.split(prompt_delimiter) if i]
registry.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Type
2
+ from torch.optim import Optimizer
3
+ from torchmetrics import Metric
4
+ from transformers import PreTrainedTokenizerBase
5
+ from .interfaces import CallbackWithConfig
6
+ from .registry_utils import create_registry
7
+ _loggers_description = 'The loggers registry is used to register classes that implement the LoggerDestination interface. ' + 'These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers ' + 'will be constructed by directly passing along the specified kwargs to the constructor.'
8
+ loggers = create_registry('llmfoundry', 'loggers', generic_type=Type[LoggerDestination], entry_points=True, description=_loggers_description)
9
+ _callbacks_description = 'The callbacks registry is used to register classes that implement the Callback interface. ' + 'These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. ' + 'The callbacks will be constructed by directly passing along the specified kwargs to the constructor.'
10
+ callbacks = create_registry('llmfoundry', 'callbacks', generic_type=Type[Callback], entry_points=True, description=_callbacks_description)
11
+ _callbacks_with_config_description = 'The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. ' + 'These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.'
12
+ callbacks_with_config = create_registry('llm_foundry.callbacks_with_config', generic_type=Type[CallbackWithConfig], entry_points=True, description=_callbacks_with_config_description)
13
+ _optimizers_description = 'The optimizers registry is used to register classes that implement the Optimizer interface. ' + 'The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the ' + 'specified kwargs to the constructor, along with the model parameters.'
14
+ optimizers = create_registry('llmfoundry', 'optimizers', generic_type=Type[Optimizer], entry_points=True, description=_optimizers_description)
15
+ _algorithms_description = 'The algorithms registry is used to register classes that implement the Algorithm interface. ' + 'The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the ' + 'specified kwargs to the constructor.'
16
+ algorithms = create_registry('llmfoundry', 'algorithms', generic_type=Type[Algorithm], entry_points=True, description=_algorithms_description)
17
+ _schedulers_description = 'The schedulers registry is used to register classes that implement the ComposerScheduler interface. ' + 'The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the ' + 'specified kwargs to the constructor.'
18
+ schedulers = create_registry('llmfoundry', 'schedulers', generic_type=Type[ComposerScheduler], entry_points=True, description=_schedulers_description)
19
+ _models_description = 'The models registry is used to register classes that implement the ComposerModel interface. The model\nconstructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`.\nNote: This will soon be updated to take in named kwargs instead of a config directly.'
20
+ models = create_registry('llmfoundry', 'models', generic_type=Type[ComposerModel], entry_points=True, description=_models_description)
21
+ _dataloaders_description = 'The dataloaders registry is used to register functions that create a DataSpec. The function should take\na DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.'
22
+ dataloaders = create_registry('llmfoundry', 'dataloaders', generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec], entry_points=True, description=_dataloaders_description)
23
+ _metrics_description = 'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.'
24
+ metrics = create_registry('llmfoundry', 'metrics', generic_type=Type[Metric], entry_points=True, description=_metrics_description)
registry_utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import importlib.util
3
+ import os
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+ from typing import Any, Callable, Dict, Generic, Optional, Sequence, Type, TypeVar, Union
7
+ import catalogue
8
+ T = TypeVar('T')
9
+
10
+ class TypedRegistry(catalogue.Registry, Generic[T]):
11
+ """A thin wrapper around catalogue.Registry to add static typing and.
12
+
13
+ descriptions.
14
+ """
15
+
16
+ def __init__(self, namespace: Sequence[str], entry_points: bool=False, description: str='') -> None:
17
+ super().__init__(namespace, entry_points=entry_points)
18
+ self.description = description
19
+
20
+ def __call__(self, name: str, func: Optional[T]=None) -> Callable[[T], T]:
21
+ return super().__call__(name, func)
22
+
23
+ def register(self, name: str, *, func: Optional[T]=None) -> T:
24
+ return super().register(name, func=func)
25
+
26
+ def get(self, name: str) -> T:
27
+ return super().get(name)
28
+
29
+ def get_all(self) -> Dict[str, T]:
30
+ return super().get_all()
31
+
32
+ def get_entry_point(self, name: str, default: Optional[T]=None) -> T:
33
+ return super().get_entry_point(name, default=default)
34
+
35
+ def get_entry_points(self) -> Dict[str, T]:
36
+ return super().get_entry_points()
37
+ S = TypeVar('S')
38
+
39
+ def create_registry(*namespace: str, generic_type: Type[S], entry_points: bool=False, description: str='') -> 'TypedRegistry[S]':
40
+ """Create a new registry.
41
+
42
+ Args:
43
+ namespace (str): The namespace, e.g. "llmfoundry.loggers"
44
+ generic_type (Type[S]): The type of the registry.
45
+ entry_points (bool): Accept registered functions from entry points.
46
+ description (str): A description of the registry.
47
+
48
+ Returns:
49
+ The TypedRegistry object.
50
+ """
51
+ if catalogue.check_exists(*namespace):
52
+ raise catalogue.RegistryError(f'Namespace already exists: {namespace}')
53
+ return TypedRegistry[generic_type](namespace, entry_points=entry_points, description=description)
54
+
55
+ def construct_from_registry(name: str, registry: TypedRegistry, partial_function: bool=True, pre_validation_function: Optional[Union[Callable[[Any], None], type]]=None, post_validation_function: Optional[Callable[[Any], None]]=None, kwargs: Optional[Dict[str, Any]]=None) -> Any:
56
+ """Helper function to build an item from the registry.
57
+
58
+ Args:
59
+ name (str): The name of the registered item
60
+ registry (catalogue.Registry): The registry to fetch the item from
61
+ partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True.
62
+ pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called
63
+ before constructing the item to return. This should throw an exception if validation fails. Defaults to None.
64
+ post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after
65
+ constructing the item to return. This should throw an exception if validation fails. Defaults to None.
66
+
67
+ Raises:
68
+ ValueError: If the validation functions failed or the registered item is invalid
69
+
70
+ Returns:
71
+ Any: The constructed item from the registry
72
+ """
73
+ if kwargs is None:
74
+ kwargs = {}
75
+ registered_constructor = registry.get(name)
76
+ if pre_validation_function is not None:
77
+ if isinstance(pre_validation_function, type):
78
+ if not issubclass(registered_constructor, pre_validation_function):
79
+ raise ValueError(f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}')
80
+ elif isinstance(pre_validation_function, Callable):
81
+ pre_validation_function(registered_constructor)
82
+ else:
83
+ raise ValueError(f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}')
84
+ if isinstance(registered_constructor, type) or (callable(registered_constructor) and (not partial_function)):
85
+ constructed_item = registered_constructor(**kwargs)
86
+ elif callable(registered_constructor):
87
+ constructed_item = functools.partial(registered_constructor, **kwargs)
88
+ else:
89
+ raise ValueError(f'Expected {name} to be a class or function, but got {type(registered_constructor)}')
90
+ if post_validation_function is not None:
91
+ post_validation_function(registered_constructor)
92
+ return constructed_item
93
+
94
+ def import_file(loc: Union[str, Path]) -> ModuleType:
95
+ """Import module from a file.
96
+
97
+ Used to run arbitrary python code.
98
+ Args:
99
+ name (str): Name of module to load.
100
+ loc (str / Path): Path to the file.
101
+
102
+ Returns:
103
+ ModuleType: The module object.
104
+ """
105
+ if not os.path.exists(loc):
106
+ raise FileNotFoundError(f'File {loc} does not exist.')
107
+ spec = importlib.util.spec_from_file_location('python_code', str(loc))
108
+ assert spec is not None
109
+ assert spec.loader is not None
110
+ module = importlib.util.module_from_spec(spec)
111
+ try:
112
+ spec.loader.exec_module(module)
113
+ except Exception as e:
114
+ raise RuntimeError(f'Error executing {loc}') from e
115
+ return module
resumption_callbacks.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+ log = logging.getLogger(__name__)
4
+
5
+ class GlobalLRScaling(Callback):
6
+ """GlobalLRScaling.
7
+
8
+ This callback can be applied upon resuming a model checkpoint. Upon
9
+ fit_start it will multiply the base LR by `lr_scale` and set the WD to be.
10
+
11
+ `wd_pct` * `lr`.
12
+
13
+ Args:
14
+ lr_scale (float): Multiplicative factor to scale LR by
15
+ wd_pct (float): Percentage of LR to set weight decay to.
16
+ """
17
+
18
+ def __init__(self, lr_scale: float, wd_pct: float=0.0):
19
+ self.lr_scale = lr_scale
20
+ self.wd_pct = wd_pct
21
+
22
+ def fit_start(self, state: State, logger: Logger) -> None:
23
+ del logger
24
+ if hasattr(state, 'optimizer') and state.optimizers is None:
25
+ raise Exception('No optimizers defined')
26
+ for optimizer in state.optimizers:
27
+ for group in optimizer.param_groups:
28
+ group['lr'] *= self.lr_scale
29
+ group['weight_decay'] = group['lr'] * self.wd_pct
30
+ if 'initial_lr' in group:
31
+ group['initial_lr'] *= self.lr_scale
32
+ log.info(f"Set LR and WD to {group['lr']}, {group['weight_decay']}")
33
+ for scheduler in state.schedulers:
34
+ scheduler.base_lrs = [self.lr_scale * lr for lr in scheduler.base_lrs]
35
+
36
+ class LayerFreezing(Callback):
37
+ """LayerFreezing.
38
+
39
+ This callback can be applied upon resuming a model checkpoint. Upon
40
+ fit_start it freeze the layers specified in `layer_names`. If using
41
+ activation checkpointing, please set the
42
+ `activation_checkpointing_reentrant` flag in `fsdp_config` to false.
43
+
44
+ Args:
45
+ layer_names (float): Names of layers to freeze.
46
+ """
47
+
48
+ def __init__(self, layer_names: List[str]):
49
+ self.layer_names = set(layer_names)
50
+
51
+ def fit_start(self, state: State, logger: Logger) -> None:
52
+ del logger
53
+ model_layers = set((name for name, _ in state.model.named_parameters()))
54
+ for layer in self.layer_names:
55
+ if layer not in model_layers:
56
+ raise Exception(f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}')
57
+ successful_freeze = False
58
+ for name, p in state.model.named_parameters():
59
+ if p.requires_grad and name in self.layer_names:
60
+ p.requires_grad = False
61
+ log.debug(f'Froze layer: {name}\nParam: {p}')
62
+ successful_freeze = True
63
+ if not successful_freeze:
64
+ raise Exception(f"Tried to run LayerFreezing but didn't freeze any layers")
scheduled_gc_callback.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from typing import Optional
3
+ import torch
4
+
5
+ def gc_cuda():
6
+ """Garbage collect Torch (CUDA) memory."""
7
+ gc.collect()
8
+ if torch.cuda.is_available():
9
+ torch.cuda.empty_cache()
10
+
11
+ class ScheduledGarbageCollector(Callback):
12
+ """Disable automatic garbage collection and collect garbage at interval.
13
+
14
+ Args:
15
+ batch_interval (int): Number of batches between calls to gc.collect()
16
+ gen_1_batch_interval(int, optional): Number of batches between calls to gc.collect(1)
17
+ eval_keep_disabled (bool): keep gc disabled during eval (default: False)
18
+ """
19
+
20
+ def __init__(self, batch_interval: int, gen_1_batch_interval: Optional[int]=None, eval_keep_disabled: bool=False):
21
+ self.batch_interval = batch_interval
22
+ self.gen_1_batch_interval = gen_1_batch_interval
23
+ self.eval_keep_disabled = eval_keep_disabled
24
+ self.gc_init_state = None
25
+
26
+ def fit_start(self, state: State, logger: Logger) -> None:
27
+ del state, logger
28
+ self.gc_init_state = gc.isenabled()
29
+ gc.disable()
30
+ gc_cuda()
31
+
32
+ def fit_end(self, state: State, logger: Logger) -> None:
33
+ del state, logger
34
+ gc_cuda()
35
+ if self.gc_init_state:
36
+ gc.enable()
37
+ else:
38
+ gc.disable()
39
+
40
+ def before_dataloader(self, state: State, logger: Logger) -> None:
41
+ del logger
42
+ if self.gen_1_batch_interval is not None and state.timestamp.batch.value % self.gen_1_batch_interval == 0:
43
+ gc.collect(1)
44
+ if state.timestamp.batch.value % self.batch_interval == 0:
45
+ gc_cuda()
46
+
47
+ def eval_start(self, state: State, logger: Logger) -> None:
48
+ del state, logger
49
+ gc_cuda()
50
+ if not self.eval_keep_disabled:
51
+ gc.enable()
52
+
53
+ def eval_end(self, state: State, logger: Logger) -> None:
54
+ del state, logger
55
+ if not self.eval_keep_disabled:
56
+ gc.disable()
57
+ gc_cuda()
tasks.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Includes code for task-specific seq-to-seq data formatting.
2
+
3
+ This file provides some templates/examples of preprocessing functions
4
+ that format examples for use in seq-to-seq finetuning tasks.
5
+ These preprocessing functions take individual examples that contain raw
6
+ text and process them into formatted examples.
7
+
8
+ These functions have this basic structure:
9
+
10
+ def preprocessing_fn(example: Dict) -> Dict[str, str]:
11
+ # code to extract prompt/response from `example`
12
+ ...
13
+ return {
14
+ 'prompt': <prompt>,
15
+ 'response': <response>,
16
+ }
17
+
18
+ where `<prompt>` is a placeholder for the prompt text string that you
19
+ extracted from the input example, and '<response>' is a placeholder for
20
+ the response text string.
21
+
22
+ Just to be clear, "prompt" represents the text you would give the model
23
+ at inference time, and "response" represents the text you are training
24
+ it to produce given the prompt.
25
+
26
+ The key requirement of these functions is that they return a dictionary
27
+ with "prompt" and "response" keys, and that the values associated with
28
+ those keys are strings (i.e. text).
29
+ """
30
+ import importlib
31
+ import logging
32
+ import os
33
+ import warnings
34
+ from collections.abc import Mapping
35
+ from functools import partial
36
+ from pathlib import Path
37
+ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
38
+ import datasets as hf_datasets
39
+ import huggingface_hub as hf_hub
40
+ import numpy as np
41
+ from streaming import Stream, StreamingDataset
42
+ from transformers import PreTrainedTokenizerBase
43
+ from .collator import _HF_IGNORE_INDEX, stitch_turns_decoder_only, stitch_turns_encoder_decoder
44
+ from .exceptions import ConsecutiveRepeatedChatRolesError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, InvalidFileExtensionError, InvalidLastChatMessageRoleError, InvalidPromptResponseKeysError, InvalidPromptTypeError, InvalidResponseTypeError, InvalidRoleError, NotEnoughChatDataError, TooManyKeysInExampleError, UnableToProcessPromptResponseError, UnknownExampleTypeError
45
+ from .logging_utils import SpecificWarningFilter
46
+ log = logging.getLogger(__name__)
47
+ _ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
48
+ _ALLOWED_PROMPT_KEYS = {'prompt'}
49
+ _ALLOWED_MESSAGES_KEYS = {'messages'}
50
+ _ALLOWED_ROLE_KEYS = {'role'}
51
+ _ALLOWED_CONTENT_KEYS = {'content'}
52
+ _ALLOWED_ROLES = {'user', 'assistant', 'system'}
53
+ _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
54
+ DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, '.downloaded_finetuning'))
55
+ SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']
56
+ PromptResponseDict = Mapping[str, str]
57
+ ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
58
+ Example = Union[PromptResponseDict, ChatFormattedDict]
59
+ ExampleType = Literal['prompt_response', 'chat']
60
+ TokenizedExample = Dict[str, List[Dict[str, List[int]]]]
61
+
62
+ def _get_example_type(example: Example) -> ExampleType:
63
+ """Determines the type of the input example.
64
+
65
+ Args:
66
+ example (Example): The input example, which can be a multi-way chat formatted conversation or an instruction-response pair.
67
+
68
+ Returns:
69
+ ExampleType: The type of the input example, which can be either 'chat' for multi-way chat formatted conversation or 'prompt_response' for instruction-response pair.
70
+
71
+ Raises:
72
+ KeyError: If the example type is unknown.
73
+ """
74
+ if not isinstance(example, Mapping):
75
+ raise TypeError(f'Expected example to be a Mapping, but found {type(example)}')
76
+ if any((allowed_message_key in example for allowed_message_key in _ALLOWED_MESSAGES_KEYS)):
77
+ return 'chat'
78
+ elif any((p in example for p in _ALLOWED_PROMPT_KEYS)) and any((r in example for r in _ALLOWED_RESPONSE_KEYS)):
79
+ return 'prompt_response'
80
+ else:
81
+ raise UnknownExampleTypeError(example)
82
+
83
+ def _is_empty_or_nonexistent(dirpath: str) -> bool:
84
+ """Check if a directory is empty or non-existent.
85
+
86
+ Args:
87
+ dirpath (str): Directory path to check.
88
+
89
+ Returns
90
+ True if directory is empty or non-existent. False otherwise.
91
+ """
92
+ return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0
93
+
94
+ def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]):
95
+ if not isinstance(dictionary, Mapping):
96
+ raise TypeError(f'Expected dictionary to be a mapping, but found {type(dictionary)}')
97
+ desired_keys = allowed_keys.intersection(dictionary.keys())
98
+ if len(desired_keys) != 1:
99
+ raise TooManyKeysInExampleError(allowed_keys, desired_keys)
100
+ return list(desired_keys)[0]
101
+
102
+ def _validate_chat_formatted_example(example: ChatFormattedDict):
103
+ if not isinstance(example, Mapping):
104
+ raise TypeError(f'Expected example to be a mapping, but found {type(example)}')
105
+ messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
106
+ if not isinstance(messages, List):
107
+ raise TypeError(f'Expected messages to be an iterable, but found {type(messages)}')
108
+ if len(messages) <= 1:
109
+ raise NotEnoughChatDataError()
110
+ last_message = messages[-1]
111
+ role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS)
112
+ last_role = last_message[role_key]
113
+ if last_role not in _ALLOWED_LAST_MESSAGE_ROLES:
114
+ raise InvalidLastChatMessageRoleError(last_role, _ALLOWED_LAST_MESSAGE_ROLES)
115
+ last_message_role = None
116
+ for message in messages:
117
+ role_key, content_key = (_get_key(message, _ALLOWED_ROLE_KEYS), _get_key(message, _ALLOWED_CONTENT_KEYS))
118
+ if len(message.keys()) != 2:
119
+ raise IncorrectMessageKeyQuantityError(list(message.keys()))
120
+ if message[role_key] not in _ALLOWED_ROLES:
121
+ raise InvalidRoleError(message[role_key], _ALLOWED_ROLES)
122
+ if not isinstance(message[content_key], str):
123
+ raise InvalidContentTypeError(type(message[content_key]))
124
+ if last_message_role is not None and last_message_role == message[role_key]:
125
+ raise ConsecutiveRepeatedChatRolesError(last_message_role)
126
+ last_message_role = message[role_key]
127
+
128
+ def _slice_chat_formatted_example(example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, str]]:
129
+ """Slices chat example into a list of templated prompt, response turns.
130
+
131
+ Note: Assistant messages mark the end of chat turns. So there are as many turns as there are
132
+ assistant messages in the chat example.
133
+
134
+ Args:
135
+ example (ChatFormattedDict): The chat example containing the messages.
136
+ tokenizer (PreTrainedTokenizerBase): The tokenizer to apply the chat template.
137
+
138
+ Returns:
139
+ List[Tuple[str, str]]: A list of templated prompt and response string pairs, one pair per chat turn.
140
+
141
+ Raises:
142
+ ValueError: If any chat turn in the example has less than two messages or if the last message is not from the assistant.
143
+ KeyError: If a message does not have a role or content.
144
+ """
145
+ _validate_chat_formatted_example(example)
146
+ messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
147
+ last_message = messages[-1]
148
+ if last_message['role'] != 'assistant':
149
+ raise InvalidLastChatMessageRoleError(last_message['role'], set(['assistant']))
150
+
151
+ def slice_out_last_turn(messages_through_current_turn: List[Dict[str, str]], conversation_through_previous_turn: str) -> Tuple[str, str]:
152
+ full_conversation = tokenizer.apply_chat_template(messages_through_current_turn, tokenize=False)
153
+ prompt_with_history = tokenizer.apply_chat_template(messages_through_current_turn[:-1], tokenize=False, add_generation_prompt=True)
154
+ if conversation_through_previous_turn != full_conversation[:len(conversation_through_previous_turn)]:
155
+ raise ValueError(f'The full conversation must start with the conversation through the previous turn. conversation_through_previous_turn={conversation_through_previous_turn!r}, full_conversation={full_conversation!r}')
156
+ if conversation_through_previous_turn != prompt_with_history[:len(conversation_through_previous_turn)]:
157
+ raise ValueError(f'The prompt_with_histry must start with the conversation through the previous turn. conversation_through_previous_turn={conversation_through_previous_turn!r}, prompt_with_history={prompt_with_history!r}')
158
+ if prompt_with_history != full_conversation[:len(prompt_with_history)]:
159
+ raise ValueError(f'prompt_with_history must be the first part of the full conversation. prompt_with_history={prompt_with_history!r}, full_conversation={full_conversation!r}')
160
+ prompt = prompt_with_history[len(conversation_through_previous_turn):]
161
+ response = full_conversation[len(prompt_with_history):]
162
+ return (prompt, response)
163
+ templated_prompt_response_turns: List[Tuple[str, str]] = []
164
+ conversation_through_previous_turn = ''
165
+ for idx, message in enumerate(messages):
166
+ if message['role'] == 'assistant':
167
+ prompt, response = slice_out_last_turn(messages[:idx + 1], conversation_through_previous_turn)
168
+ templated_prompt_response_turns.append((prompt, response))
169
+ conversation_through_previous_turn += prompt
170
+ conversation_through_previous_turn += response
171
+ return templated_prompt_response_turns
172
+
173
+ def _tokenize_with_bos_removal(tokenizer: PreTrainedTokenizerBase, text: str, text_target: str) -> Dict[str, List[int]]:
174
+ """Tokenizes the prompt and response using the provided tokenizer.
175
+
176
+ Args:
177
+ tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
178
+ text (str): The prompt to tokenize.
179
+ text_target (str): The response to tokenize.
180
+
181
+ Returns:
182
+ Dict[str, List[int]]: The tokenized text and text_target.
183
+ """
184
+ tokenized_sample = tokenizer(text=text, text_target=text_target, padding=False, truncation=False)
185
+ if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
186
+ if tokenizer.bos_token_id is not None and tokenized_sample['labels'][0] == tokenizer.bos_token_id:
187
+ tokenized_sample['labels'] = tokenized_sample['labels'][1:]
188
+ return tokenized_sample
189
+
190
+ def _tokenize_chat_formatted_example(example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
191
+ """Tokenizes a chat-formatted example using the provided tokenizer.
192
+
193
+ Args:
194
+ example (ChatFormattedDict): The chat-formatted example to tokenize.
195
+ tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
196
+
197
+ Returns:
198
+ TokenizedExample: The tokenized example.
199
+ """
200
+ return {'turns': [tokenizer(text=prompt, text_target=response, add_special_tokens=False, padding=False, truncation=False) for prompt, response in _slice_chat_formatted_example(example, tokenizer)]}
201
+
202
+ def _tokenize_prompt_response_formatted_example(example: PromptResponseDict, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
203
+ """Tokenize a formatted example and validate expected keys."""
204
+ example_keys = set(example.keys())
205
+ prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
206
+ response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)
207
+ if len(prompt_keys) != 1:
208
+ raise TooManyKeysInExampleError(_ALLOWED_PROMPT_KEYS, prompt_keys)
209
+ if len(response_keys) != 1:
210
+ raise TooManyKeysInExampleError(_ALLOWED_RESPONSE_KEYS, response_keys)
211
+ prompt_key = prompt_keys.pop()
212
+ response_key = response_keys.pop()
213
+ prompt = example[prompt_key]
214
+ response = example[response_key]
215
+ if not isinstance(prompt, str):
216
+ raise InvalidPromptTypeError(type(prompt))
217
+ if not isinstance(response, str):
218
+ raise InvalidResponseTypeError(type(response))
219
+ return {'turns': [_tokenize_with_bos_removal(tokenizer=tokenizer, text=prompt, text_target=response)]}
220
+
221
+ def tokenize_formatted_example(example: Example, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
222
+ """Tokenizes a formatted example using the provided tokenizer.
223
+
224
+ Args:
225
+ example (Example): The input example to be tokenized.
226
+ tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenization.
227
+
228
+ Returns:
229
+ TokenizedExample: The tokenized example.
230
+
231
+ Raises:
232
+ ValueError: If the example format is unknown.
233
+ """
234
+ example_format = _get_example_type(example)
235
+ if example_format == 'chat':
236
+ chat_example = cast(ChatFormattedDict, example)
237
+ return _tokenize_chat_formatted_example(chat_example, tokenizer)
238
+ elif example_format == 'prompt_response':
239
+ prompt_response_example: PromptResponseDict = cast(PromptResponseDict, example)
240
+ return _tokenize_prompt_response_formatted_example(prompt_response_example, tokenizer)
241
+ else:
242
+ raise UnknownExampleTypeError(example)
243
+
244
+ def is_valid_ift_example(max_seq_len: int, target_prompts: str, target_responses: str, decoder_only_format: bool, example: TokenizedExample) -> bool:
245
+ """Check if the example is a valid ift example.
246
+
247
+ This function confirms that none of the ``input_ids`` and ``labels`` fields
248
+ are empty in any of the turns within the example.
249
+
250
+ This function also prepares the final input_ids and labels
251
+ of the (potentially multi-turn) example, using the target settings
252
+ and format, and checks whether they are suitable for training at max_seq_len.
253
+ The example is not valid if (1) after truncation (if necessary),
254
+ the training targets contain no loss-generating tokens, or (2) either the
255
+ input_ids and labels are empty.
256
+
257
+ The token sequences in ``example`` are assumed to not have had
258
+ any padding or truncation applied already.
259
+
260
+ Args:
261
+ max_seq_len (int): Maximum sequence length.
262
+ target_prompts (str): The prompts that are used as targets.
263
+ target_responses (str): The responses that are used as targets.
264
+ decoder_only_format (bool): Whether the data will be formatted
265
+ for a decoder-only model.
266
+ example (Dict): The input example after tokenization, which has
267
+ a list of dicts, each with ``input_ids`` and ``labels`` fields.
268
+
269
+ Returns:
270
+ bool: Indicator of whether the input example is valid
271
+ """
272
+ for turn in example['turns']:
273
+ if len(turn['input_ids']) == 0:
274
+ return False
275
+ if len(turn['labels']) == 0:
276
+ return False
277
+ if decoder_only_format:
278
+ input_ids, labels = stitch_turns_decoder_only(example_turns=example['turns'], target_prompts=target_prompts, target_responses=target_responses)
279
+ else:
280
+ input_ids, labels = stitch_turns_encoder_decoder(example_turns=example['turns'])
281
+ input_ids = input_ids[:max_seq_len]
282
+ labels = labels[:max_seq_len]
283
+ if len(input_ids) == 0:
284
+ return False
285
+ if len([label for label in labels if label != _HF_IGNORE_INDEX]) == 0:
286
+ return False
287
+ return True
288
+
289
+ def _stream_remote_local_validate(remote: Optional[str], local: Optional[str], split: Optional[str]):
290
+ if remote is None or local == remote:
291
+ if local is not None and os.path.isdir(local):
292
+ contents = set(os.listdir(local))
293
+ if split is not None and split not in contents:
294
+ raise ValueError(f'Local directory {local} does not contain split {split}')
295
+
296
+ class StreamingFinetuningDataset(StreamingDataset):
297
+ """Finetuning dataset with flexible tokenization using StreamingDataset.
298
+
299
+ Args:
300
+ tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
301
+ tokenize samples.
302
+ streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
303
+ which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
304
+ ``remote``/``local``. Defaults to ``None``.
305
+ local (str): Local dataset directory where shards are cached by split.
306
+ remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
307
+ its data must exist locally. StreamingDataset uses either ``streams`` or
308
+ ``remote``/``local``. Defaults to ``None``.
309
+ split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
310
+ the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
311
+ download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
312
+ download_timeout (float): Number of seconds to wait for a shard to download before raising
313
+ an exception. Defaults to ``60``.
314
+ validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
315
+ shards. Defaults to ``None``.
316
+ keep_zip (bool): Whether to keep or delete the compressed form when decompressing
317
+ downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
318
+ `False``.
319
+ epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
320
+ streams. If ``None``, takes its value from the total number of underlying samples.
321
+ Provide this field if you are weighting streams relatively to target a larger or
322
+ smaller epoch size. Defaults to ``None``.
323
+ predownload (int, optional): Target number of samples ahead to download the shards of while
324
+ iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
325
+ cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's
326
+ shard cache. Before downloading a shard, the least recently used resident shard(s) may
327
+ be evicted (deleted from the local cache) in order to stay under the limit. Set to None
328
+ to disable shard eviction. Supports integer bytes as well as string human-readable
329
+ bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
330
+ partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
331
+ num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
332
+ resumption. If ``None``, this is interpreted as 64 times the number of physical
333
+ nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
334
+ number of physical nodes of the initial run otherwise. Defaults to ``None``.
335
+ batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
336
+ partitioned over the workers. Defaults to ``None``.
337
+ shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
338
+ ``False``.
339
+ shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
340
+ shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
341
+ shuffle_block_size (int): Unit of shuffle. If ``None``, its value is calculated as
342
+ ``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to ``None``.
343
+ sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
344
+ Defaults to ``balanced``.
345
+ sampling_granularity (int): When picking samples for a stream's final partial repeat,
346
+ how many samples to pick from the same shard at a time (``1`` for evenly balanced
347
+ across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
348
+ Defaults to ``1``.
349
+ batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
350
+ ``per_stream``. Defaults to ``random``.
351
+ """
352
+
353
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, streams: Optional[Sequence[Stream]]=None, local: Optional[str]=None, remote: Optional[str]=None, split: Optional[str]=None, download_retry: int=2, download_timeout: float=60, validate_hash: Optional[str]=None, keep_zip: bool=False, epoch_size: Optional[Union[int, str]]=None, predownload: Optional[int]=None, cache_limit: Optional[Union[int, str]]=None, partition_algo: str='relaxed', num_canonical_nodes: Optional[int]=None, batch_size: Optional[int]=None, shuffle: bool=False, shuffle_algo: str='py1e', shuffle_seed: int=9176, shuffle_block_size: Optional[int]=None, sampling_method: str='balanced', sampling_granularity: int=1, batching_method: str='random', max_seq_len: int=2048, **kwargs: Any):
354
+ if len(kwargs) > 0:
355
+ raise ValueError(f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}')
356
+ if streams is None:
357
+ _stream_remote_local_validate(remote, local, split)
358
+ else:
359
+ for stream in streams:
360
+ _stream_remote_local_validate(stream.remote, stream.local, split)
361
+ super().__init__(streams=streams, local=local, remote=remote, split=split, download_retry=download_retry, download_timeout=download_timeout, validate_hash=validate_hash, keep_zip=keep_zip, epoch_size=epoch_size, predownload=predownload, cache_limit=cache_limit, partition_algo=partition_algo, num_canonical_nodes=num_canonical_nodes, batch_size=batch_size, shuffle=shuffle, shuffle_algo=shuffle_algo, shuffle_seed=shuffle_seed, shuffle_block_size=shuffle_block_size, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method)
362
+ self.tokenizer = tokenizer
363
+ self.max_seq_len = max_seq_len
364
+
365
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
366
+ sample = super().__getitem__(idx)
367
+ if 'turns' in sample:
368
+ return sample
369
+ if 'input_ids' in sample:
370
+ if isinstance(sample['input_ids'], bytes):
371
+ sample['input_ids'] = np.frombuffer(sample['input_ids'], dtype=np.int64)[:self.max_seq_len].tolist().copy()
372
+ sample['labels'] = np.frombuffer(sample['labels'], dtype=np.int64)[:self.max_seq_len].tolist().copy()
373
+ elif isinstance(sample['input_ids'], np.ndarray):
374
+ sample['input_ids'] = sample['input_ids'][:self.max_seq_len].tolist().copy()
375
+ sample['labels'] = sample['labels'][:self.max_seq_len].tolist().copy()
376
+ else:
377
+ raise ValueError(f"Expect input_ids to be bytes or numpy.ndarray type, but got {type(sample['input_ids'])}")
378
+ return {'turns': [sample]}
379
+ return tokenize_formatted_example(sample, tokenizer=self.tokenizer)
380
+
381
+ class DatasetConstructor:
382
+
383
+ def __init__(self):
384
+ self._task_preprocessing_registry: Dict[str, Callable] = {}
385
+
386
+ def register(self, *names: str) -> Callable[[Callable], Callable]:
387
+ """Decorator for registering preprocessing functions."""
388
+
389
+ def _register_func(name: str, func: Callable) -> None:
390
+ if name in self._task_preprocessing_registry:
391
+ raise ValueError(f'A tokenization function has already been registered with name={name!r}.')
392
+ self._task_preprocessing_registry[name] = func
393
+ return
394
+
395
+ def wrapper(func: Callable) -> Callable:
396
+ for name in names:
397
+ _register_func(name, func)
398
+ return func
399
+ return wrapper
400
+
401
+ def print_registered_tasks(self) -> None:
402
+ tasks = sorted(self._task_preprocessing_registry.keys())
403
+ log.info('\n'.join(tasks))
404
+
405
+ def get_preprocessing_fn_from_dict(self, mapping: Dict[str, str]) -> Callable[[Dict[str, Any]], Dict[str, str]]:
406
+ """Get a preprocessing function from a dictionary.
407
+
408
+ The dictionary maps column names in the dataset to "prompt" and "response".
409
+ For example,
410
+ ```yaml
411
+ preprocessing_fn:
412
+ prompt: text
413
+ response: summary
414
+ ```
415
+ would map the `text` column as to prompt and the `summary` column as the response.
416
+
417
+ Args:
418
+ mapping (dict): A dictionary mapping column names to "prompt" and "response".
419
+
420
+ Returns:
421
+ Callable: The preprocessing function.
422
+
423
+ Raises:
424
+ ValueError: If the mapping does not have keys "prompt" and "response".
425
+ """
426
+
427
+ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
428
+ if list(mapping.keys()) != ['prompt', 'response']:
429
+ raise InvalidPromptResponseKeysError(mapping, example)
430
+ return {'prompt': example[mapping['prompt']], 'response': example[mapping['response']]}
431
+ return _preprocessor
432
+
433
+ def get_preprocessing_fn_from_str(self, preprocessor: Optional[str], dataset_name: Optional[str]=None) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
434
+ """Get a preprocessing function from a string.
435
+
436
+ String can be either a registered function or an import path.
437
+
438
+ Args:
439
+ preprocessor (Optional[str]): The name of the preprocessing function, or an import path.
440
+ dataset_name (Optional[str]): The dataset name to look up in the registry.
441
+
442
+ Returns:
443
+ Callable: The preprocessing function or None if not found.
444
+
445
+ Raises:
446
+ ValueError: If the preprocessing function import from the provided string fails.
447
+ """
448
+ if preprocessor is None:
449
+ if dataset_name is None:
450
+ return None
451
+ if dataset_name in self._task_preprocessing_registry:
452
+ log.info(f'Re-formatting dataset with "{dataset_name}" preprocessing function.')
453
+ return self._task_preprocessing_registry[dataset_name]
454
+ else:
455
+ log.info('No preprocessor was supplied and no preprocessing function ' + f'is registered for dataset name "{dataset_name}". No additional ' + 'preprocessing will be applied. If the dataset is already formatted ' + 'correctly, you can ignore this message.')
456
+ return None
457
+ if preprocessor in self._task_preprocessing_registry:
458
+ log.info(f'Re-formatting dataset with "{preprocessor}" preprocessing function.')
459
+ return self._task_preprocessing_registry[preprocessor]
460
+ try:
461
+ import_path, function_name = preprocessor.split(':', maxsplit=1)
462
+ module = importlib.import_module(import_path)
463
+ preprocessing_fn = getattr(module, function_name)
464
+ except Exception as e:
465
+ raise ValueError(f'Failed to import preprocessing function from string = {preprocessor}.') from e
466
+ return preprocessing_fn
467
+
468
+ def build_from_hf(self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int, preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]], tokenizer: PreTrainedTokenizerBase, target_prompts: str, target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str, Any]) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
469
+ """Load a HuggingFace Datasets, preprocess, and tokenize.
470
+
471
+ Note: This function will drop examples where the prompt is longer than the max_seq_len
472
+
473
+ Args:
474
+ cfg (DictConfig): The dataset configuration.
475
+ max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped.
476
+ tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset.
477
+
478
+ Returns:
479
+ Dataset: The tokenized dataset.
480
+ """
481
+ signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
482
+ if dist.get_local_rank() != 0:
483
+ log.debug('Waiting for local_rank 0 to finish data prep')
484
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
485
+ pass
486
+ hf_tokenization_logger = logging.getLogger('transformers.tokenization_utils_base')
487
+ sequence_length_warning_filter = SpecificWarningFilter('Token indices sequence length is longer than the specified maximum sequence length')
488
+ hf_tokenization_logger.addFilter(sequence_length_warning_filter)
489
+ error: Optional[Exception] = None
490
+ filtered_dataset = None
491
+ try:
492
+ if safe_load:
493
+ if not os.path.isdir(dataset_name):
494
+ local_dataset_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, dataset_name)
495
+ if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
496
+ hf_hub.snapshot_download(dataset_name, repo_type='dataset', allow_patterns=['*' + ext for ext in SUPPORTED_EXTENSIONS], token=hf_kwargs.get('token', None), revision=hf_kwargs.get('revision', None), local_dir_use_symlinks=False, local_dir=local_dataset_dir)
497
+ if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
498
+ raise InvalidFileExtensionError(dataset_name, SUPPORTED_EXTENSIONS)
499
+ dataset_name = local_dataset_dir
500
+ dataset_name = os.path.abspath(dataset_name)
501
+ dataset_files = [f for _, _, files in os.walk(dataset_name) for f in files]
502
+ if not all((Path(f).suffix in SUPPORTED_EXTENSIONS for f in dataset_files)):
503
+ raise InvalidFileExtensionError(dataset_name, SUPPORTED_EXTENSIONS)
504
+ dataset = hf_datasets.load_dataset(dataset_name, split=split, **hf_kwargs)
505
+
506
+ def dataset_mapper(example: Dict):
507
+ if preprocessing_fn is not None:
508
+ example = preprocessing_fn(example)
509
+ return tokenize_formatted_example(example, tokenizer)
510
+ detected_cpu_count = os.cpu_count() or 1
511
+ detected_cpus_with_margin = detected_cpu_count - 8
512
+ num_cpus_to_use = max(1, detected_cpus_with_margin)
513
+ columns_to_remove = list(dataset[0].keys())
514
+ tokenized_dataset = dataset.map(dataset_mapper, batched=False, remove_columns=columns_to_remove, num_proc=num_cpus_to_use, desc='Tokenizing dataset')
515
+ filtered_dataset = tokenized_dataset.filter(partial(is_valid_ift_example, max_seq_len, target_prompts, target_responses, decoder_only_format), num_proc=num_cpus_to_use, desc='Filtering out long prompts')
516
+ examples_removed = len(tokenized_dataset) - len(filtered_dataset)
517
+ if examples_removed > 0:
518
+ warnings.warn(f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, ' + 'the prompt or response was empty, or the response was all padding tokens.')
519
+ except Exception as e:
520
+ error = e
521
+ if dist.get_local_rank() == 0:
522
+ log.debug('Local rank 0 finished data prep')
523
+ with open(signal_file_path, 'wb') as f:
524
+ f.write(b'local_rank0_completed_data_prep')
525
+ dist.barrier()
526
+ if dist.get_local_rank() == 0:
527
+ os.remove(signal_file_path)
528
+ if error is not None:
529
+ log.error('Error during data prep')
530
+ raise error
531
+ log.debug('All ranks finished data prep')
532
+ hf_tokenization_logger.removeFilter(sequence_length_warning_filter)
533
+ assert filtered_dataset is not None
534
+ return filtered_dataset
535
+
536
+ def build_from_streaming(self, *args: Any, **kwargs: Any) -> StreamingFinetuningDataset:
537
+ return StreamingFinetuningDataset(*args, **kwargs)
538
+ dataset_constructor = DatasetConstructor()
539
+
540
+ @dataset_constructor.register('tatsu-lab/alpaca')
541
+ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
542
+ """Split out prompt/response from text."""
543
+ try:
544
+ prompt, response = inp['text'].split('### Response:')
545
+ prompt += '### Response:'
546
+ except Exception as e:
547
+ raise UnableToProcessPromptResponseError(inp) from e
548
+ return {'prompt': prompt, 'response': response}
549
+
550
+ @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
551
+ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
552
+ """Format the text string."""
553
+ PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
554
+ try:
555
+ if inp['input'] != '':
556
+ instruction = inp['instruction'] + '\n' + inp['input']
557
+ else:
558
+ instruction = inp['instruction']
559
+ prompt = PROMPT_FORMAT.format(instruction=instruction)
560
+ response = inp['output']
561
+ except Exception as e:
562
+ raise UnableToProcessPromptResponseError(inp) from e
563
+ return {'prompt': prompt, 'response': response}
564
+
565
+ @dataset_constructor.register('bigscience/P3')
566
+ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
567
+ """Format the already-split example."""
568
+ return {'prompt': inp['inputs'] + ':', 'response': inp['targets']}
569
+
570
+ @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
571
+ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
572
+ """Format the already-split example."""
573
+ try:
574
+ prompt: str = inp['inputs']
575
+ response: str = inp['targets']
576
+ transitions = (' ', '\n', '\t')
577
+ if not (prompt.endswith(transitions) or response.startswith(transitions)):
578
+ response = ' ' + response
579
+ except Exception as e:
580
+ raise UnableToProcessPromptResponseError(inp) from e
581
+ return {'prompt': prompt, 'response': response}
text_data.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build a StreamingTextDataset dataset and dataloader for training."""
2
+ import os
3
+ from itertools import islice
4
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
5
+ import numpy as np
6
+ import torch
7
+ import transformers
8
+ from streaming import Stream, StreamingDataset
9
+ from torch.utils.data import DataLoader
10
+ from transformers import PreTrainedTokenizerBase
11
+
12
+ class StreamingTextDataset(StreamingDataset):
13
+ """Generic text dataset using MosaicML's StreamingDataset.
14
+
15
+ Args:
16
+ tokenizer (Tokenizer): HuggingFace tokenizer to
17
+ tokenize samples.
18
+ max_seq_len (int): The max sequence length of each sample.
19
+ streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
20
+ which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
21
+ ``remote``/``local``. Defaults to ``None``.
22
+ remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
23
+ its data must exist locally. StreamingDataset uses either ``streams`` or
24
+ ``remote``/``local``. Defaults to ``None``.
25
+ local (str, optional): Local working directory to download shards to. This is where shards
26
+ are cached while they are being used. Uses a temp directory if not set.
27
+ StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
28
+ split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
29
+ the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
30
+ download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
31
+ download_timeout (float): Number of seconds to wait for a shard to download before raising
32
+ an exception. Defaults to ``60``.
33
+ validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
34
+ shards. Defaults to ``None``.
35
+ keep_zip (bool): Whether to keep or delete the compressed form when decompressing
36
+ downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
37
+ `False``.
38
+ epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
39
+ streams. If ``None``, takes its value from the total number of underlying samples.
40
+ Provide this field if you are weighting streams relatively to target a larger or
41
+ smaller epoch size. Defaults to ``None``.
42
+ predownload (int, optional): Target number of samples ahead to download the shards of while
43
+ iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
44
+ cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's
45
+ shard cache. Before downloading a shard, the least recently used resident shard(s) may
46
+ be evicted (deleted from the local cache) in order to stay under the limit. Set to None
47
+ to disable shard eviction. Supports integer bytes as well as string human-readable
48
+ bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
49
+ partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
50
+ num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
51
+ resumption. If ``None``, this is interpreted as 64 times the number of physical
52
+ nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
53
+ number of physical nodes of the initial run otherwise. Defaults to ``None``.
54
+ batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
55
+ partitioned over the workers. Defaults to ``None``.
56
+ shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
57
+ ``False``.
58
+ shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
59
+ shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
60
+ shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
61
+ into blocks of this size, and samples within each block are shuffled. If ``None``, its
62
+ value is calculated as ``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to
63
+ ``None``.
64
+ sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
65
+ Defaults to ``balanced``.
66
+ sampling_granularity (int): When picking samples for a stream's final partial repeat,
67
+ how many samples to pick from the same shard at a time (``1`` for evenly balanced
68
+ across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
69
+ Defaults to ``1``.
70
+ batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
71
+ ``per_stream``. Defaults to ``random``.
72
+ """
73
+
74
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, max_seq_len: int, streams: Optional[Sequence[Stream]]=None, remote: Optional[str]=None, local: Optional[str]=None, split: Optional[str]=None, download_retry: int=2, download_timeout: float=60, validate_hash: Optional[str]=None, keep_zip: bool=False, epoch_size: Optional[Union[int, str]]=None, predownload: Optional[int]=None, cache_limit: Optional[Union[int, str]]=None, partition_algo: str='relaxed', num_canonical_nodes: Optional[int]=None, batch_size: Optional[int]=None, shuffle: bool=False, shuffle_algo: str='py1e', shuffle_seed: int=9176, shuffle_block_size: Optional[int]=None, sampling_method: str='balanced', sampling_granularity: int=1, batching_method: str='random', **kwargs: Any):
75
+ if len(kwargs) > 0:
76
+ raise ValueError(f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}')
77
+ if local is not None and (remote is None or local == remote):
78
+ if os.path.isdir(local):
79
+ contents = set(os.listdir(local))
80
+ if split not in contents:
81
+ raise ValueError(f'local directory {local} does not contain split {split}')
82
+ if isinstance(shuffle_block_size, float):
83
+ shuffle_block_size = int(shuffle_block_size)
84
+ super().__init__(streams=streams, remote=remote, local=local, split=split, download_retry=download_retry, download_timeout=download_timeout, validate_hash=validate_hash, keep_zip=keep_zip, epoch_size=epoch_size, predownload=predownload, cache_limit=cache_limit, partition_algo=partition_algo, num_canonical_nodes=num_canonical_nodes, batch_size=batch_size, shuffle=shuffle, shuffle_algo=shuffle_algo, shuffle_seed=shuffle_seed, shuffle_block_size=shuffle_block_size, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method)
85
+ self.tokenizer = tokenizer
86
+ self.max_seq_len = max_seq_len
87
+
88
+ def _tokenize(self, text_sample: Mapping) -> Dict[str, List[int]]:
89
+ if self.tokenizer._pad_token is None:
90
+ raise RuntimeError('If tokenizing on-the-fly, tokenizer must have a pad_token_id')
91
+ return self.tokenizer(text_sample['text'], truncation=True, padding='max_length', max_length=self.max_seq_len)
92
+
93
+ def _read_binary_tokenized_sample(self, sample: Dict[str, Any]) -> torch.Tensor:
94
+ return torch.from_numpy(np.frombuffer(sample['tokens'], dtype=np.int64)[:self.max_seq_len].copy())
95
+
96
+ def __getitem__(self, idx: int) -> Union[Dict[str, List[int]], torch.Tensor]:
97
+ sample = super().__getitem__(idx)
98
+ if 'text' in sample:
99
+ token_sample = self._tokenize(sample)
100
+ elif 'tokens' in sample:
101
+ token_sample = self._read_binary_tokenized_sample(sample)
102
+ else:
103
+ raise RuntimeError('StreamingTextDataset needs samples to have a `text` or `tokens` column')
104
+ return token_sample
105
+
106
+ class ConcatenatedSequenceCollatorWrapper:
107
+ """Collator wrapper to add sequence_id to batch."""
108
+
109
+ def __init__(self, base_collator: Callable, eos_token_id: Optional[int]=None, bos_token_id: Optional[int]=None):
110
+ self.base_collator = base_collator
111
+ if eos_token_id is None and bos_token_id is None:
112
+ raise ValueError('Must supply a value for either eos_token_id or bos_token_id, but got None for both.')
113
+ if eos_token_id is not None and bos_token_id is not None:
114
+ raise ValueError('Cannot use *both* EOS and BOS tokens for detecting sequence boundaries. ' + 'Please supply `eos_token_id` if sequences end with an EOS token, or use ' + '`bos_token_id` if sequences start with a BOS token.')
115
+ if eos_token_id is None:
116
+ self.split_token_id = cast(int, bos_token_id)
117
+ self.bos_mode = True
118
+ else:
119
+ self.split_token_id = eos_token_id
120
+ self.bos_mode = False
121
+
122
+ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
123
+ batch = self.base_collator(examples)
124
+ batch['sequence_id'] = self.get_sequence_id_from_batch(batch)
125
+ return batch
126
+
127
+ def get_sequence_id_from_batch(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
128
+ is_separator = torch.eq(batch['input_ids'], self.split_token_id)
129
+ cumulative_sep = torch.cumsum(is_separator, dim=1).to(batch['input_ids'].dtype)
130
+ if self.bos_mode:
131
+ return cumulative_sep
132
+ left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1))
133
+ return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
134
+
135
+ def build_streams(dataset_cfg: DictConfig):
136
+ streams_dict = dataset_cfg.pop('streams', None)
137
+ streams = None
138
+ if streams_dict is not None:
139
+ streams = []
140
+ for _, stream in streams_dict.items():
141
+ streams.append(Stream(**stream))
142
+ return streams
143
+
144
+ def build_text_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataSpec:
145
+ assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
146
+ mlm_probability = cfg.dataset.pop('mlm_probability', None)
147
+ eos_token_id = cfg.dataset.pop('eos_token_id', None)
148
+ bos_token_id = cfg.dataset.pop('bos_token_id', None)
149
+ streams = build_streams(cfg.dataset)
150
+ dataset = StreamingTextDataset(tokenizer=tokenizer, streams=streams, batch_size=device_batch_size, **cfg.dataset)
151
+ collate_fn = transformers.DataCollatorForLanguageModeling(tokenizer=dataset.tokenizer, mlm=mlm_probability is not None, mlm_probability=mlm_probability)
152
+ if eos_token_id is not None or bos_token_id is not None:
153
+ collate_fn = ConcatenatedSequenceCollatorWrapper(base_collator=collate_fn, eos_token_id=eos_token_id, bos_token_id=bos_token_id)
154
+ dl = DataLoader(dataset, collate_fn=collate_fn, batch_size=device_batch_size, drop_last=cfg.drop_last, num_workers=cfg.num_workers, pin_memory=cfg.get('pin_memory', True), prefetch_factor=cfg.get('prefetch_factor', 2), persistent_workers=cfg.get('persistent_workers', True), timeout=cfg.get('timeout', 0))
155
+ token_counting_func = None
156
+ if tokenizer.pad_token_id is not None:
157
+ token_counting_func = get_tokens_per_batch_func()
158
+ return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
159
+
160
+ def get_tokens_per_batch_func(decoder_only: bool=True) -> Callable[[Batch], int]:
161
+ """Returns a callable that counts the number of tokens in a batch.
162
+
163
+ Args:
164
+ pad_token_id (int): The id of the padding token.
165
+ decoder_only (bool, optional): Whether to expect the batch to just contain ``input_ids`` (decoder only)
166
+ or to also contain ``decoder_input_ids`` (encoder decoder). Defaults to ``True``.
167
+
168
+ Returns:
169
+ Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
170
+ """
171
+
172
+ def get_num_samples_in_batch(batch: Batch) -> int:
173
+ if not isinstance(batch, Mapping) or ('attention_mask' not in batch and 'input_ids' not in batch):
174
+ raise ValueError('get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key')
175
+ if not decoder_only and 'decoder_attention_mask' not in batch:
176
+ raise ValueError('get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key')
177
+ if 'attention_mask' in batch:
178
+ input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
179
+ else:
180
+ input_ids_tokens = batch['input_ids'].numel()
181
+ decoder_input_ids_tokens = 0
182
+ if not decoder_only:
183
+ decoder_input_ids_tokens = int(torch.sum(batch['decoder_attention_mask']).item())
184
+ return input_ids_tokens + decoder_input_ids_tokens
185
+ return get_num_samples_in_batch
186
+ if __name__ == '__main__':
187
+ import argparse
188
+ from .builders import build_tokenizer
189
+ parser = argparse.ArgumentParser()
190
+ parser.add_argument('--tokenizer', type=str, default='EleutherAI/gpt-neox-20b', help='the name of the tokenizer to use')
191
+ parser.add_argument('--local_path', type=str, required=True, help='the path to the local copy of the dataset')
192
+ parser.add_argument('--remote_path', type=str, default=None, help='the path to the remote copy to stream from (optional)')
193
+ parser.add_argument('--split', type=str, default='val', help='which split of the dataset to use')
194
+ parser.add_argument('--max_seq_len', type=int, default=32, help='max sequence length to test')
195
+ args = parser.parse_args()
196
+ if args.remote_path is not None:
197
+ print(f'Reading {args.split} split from {args.local_path} <- streamed from <- {args.remote_path}')
198
+ else:
199
+ print(f'Reading {args.split} split from {args.local_path}')
200
+ cfg = {'name': 'text', 'dataset': {'local': args.local_path, 'remote': args.remote_path, 'split': args.split, 'shuffle': False, 'max_seq_len': args.max_seq_len, 'keep_zip': True}, 'drop_last': False, 'num_workers': 4}
201
+ cfg = om.create(cfg)
202
+ device_batch_size = 2
203
+ tokenizer_name = args.tokenizer
204
+ tokenizer_kwargs = {'model_max_length': args.max_seq_len}
205
+ tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
206
+ loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
207
+ assert isinstance(loader, DataLoader)
208
+ assert isinstance(loader.dataset, StreamingTextDataset)
209
+ tokenizer = loader.dataset.tokenizer
210
+ for batch_ix, batch in enumerate(islice(loader, 5)):
211
+ print('\n')
212
+ print('#' * 20, f'Batch {batch_ix}', '#' * 20)
213
+ for k, v in batch.items():
214
+ print(k, v.shape, v.dtype)
215
+ for sample_ix, token_sample in enumerate(batch['input_ids']):
216
+ print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
217
+ print(tokenizer.decode(token_sample))
tiktoken.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+ from transformers import PreTrainedTokenizer
4
+ DEFAULT_SYSTEM_PROMPT = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.'
5
+
6
+ @lru_cache()
7
+ def bytes_to_unicode():
8
+ """Returns list of utf-8 byte and a mapping to unicode strings.
9
+
10
+ We specifically avoids mapping to whitespace/control characters the bpe code
11
+ barfs on.
12
+
13
+ The reversible bpe codes work on unicode strings. This means you need a
14
+ large # of unicode characters in your vocab if you want to avoid UNKs. When
15
+ you're at something like a 10B token dataset you end up needing around 5K
16
+ for decent coverage. This is a significant percentage of your normal, say,
17
+ 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
18
+ unicode strings.
19
+ """
20
+ bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
21
+ cs = bs[:]
22
+ n = 0
23
+ for b in range(2 ** 8):
24
+ if b not in bs:
25
+ bs.append(b)
26
+ cs.append(2 ** 8 + n)
27
+ n += 1
28
+ cs = [chr(n) for n in cs]
29
+ return dict(zip(bs, cs))
30
+
31
+ class TiktokenTokenizerWrapper(PreTrainedTokenizer):
32
+ """A thin wrapper around tiktoken to make it compatible with Hugging Face.
33
+
34
+ tokenizers.
35
+
36
+ See HuggingFace for further documentation on general tokenizer methods.
37
+ """
38
+ model_input_names = ['input_ids', 'attention_mask']
39
+
40
+ def __init__(self, model_name: Optional[str]=None, encoding_name: Optional[str]=None, add_bos_token: bool=False, add_eos_token: bool=False, use_default_system_prompt: bool=False, unk_token: Optional[str]='<|endoftext|>', eos_token: Optional[str]='<|endoftext|>', bos_token: Optional[str]='<|endoftext|>', pad_token: Optional[str]=None, errors: str='replace', **kwargs: Any):
41
+ """Constructor creates a tiktoken tokenizer to use as the underlying.
42
+
43
+ tokenizer.
44
+
45
+ Args:
46
+ model_name (Optional[str], optional): The name of the model to load from tiktoken. Defaults to None.
47
+ Either model_name or encoding_name must be set, but not both.
48
+ encoding_name (Optional[str], optional): The name of the encoding to load from tiktoken. Defaults to None.
49
+ Either model_name or encoding_name must be set, but not both.
50
+ add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
51
+ add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
52
+ use_default_system_prompt (bool, optional): Use the default system prompt or not. Defaults to False.
53
+ unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
54
+ eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
55
+ bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
56
+ pad_token (Optional[str], optional): The pad token. Defaults to None.
57
+ errors (str, optional): Paradigm to follow when decoding bytes to UTF-8. See
58
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
59
+ Defaults to `"replace"`.
60
+ """
61
+ try:
62
+ import tiktoken
63
+ except:
64
+ raise ImportError('You need to install tiktoken to use TiktokenTokenizerWrapper.')
65
+ import copyreg
66
+ import functools
67
+ from tiktoken import Encoding
68
+
69
+ def pickle_Encoding(enc: Encoding):
70
+ return (functools.partial(Encoding, enc.name, pat_str=enc._pat_str, mergeable_ranks=enc._mergeable_ranks, special_tokens=enc._special_tokens), ())
71
+ copyreg.pickle(Encoding, pickle_Encoding)
72
+ if model_name is not None and encoding_name is not None:
73
+ raise ValueError('You need to specify either model_name or encoding_name, not both.')
74
+ self.model_name = model_name
75
+ self.encoding_name = encoding_name
76
+ if self.model_name is not None:
77
+ self.encoding = tiktoken.encoding_for_model(self.model_name)
78
+ elif self.encoding_name is not None:
79
+ self.encoding = tiktoken.get_encoding(self.encoding_name)
80
+ else:
81
+ raise ValueError('You need to specify either model_name or encoding_name.')
82
+ self.add_bos_token = add_bos_token
83
+ self.add_eos_token = add_eos_token
84
+ self.use_default_system_prompt = use_default_system_prompt
85
+ self.byte_encoder = bytes_to_unicode()
86
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
87
+ self.errors = errors
88
+ self.decoder: Dict[int, str] = {}
89
+ for i in range(self.encoding.n_vocab):
90
+ try:
91
+ self.encoding.decode_single_token_bytes(i)
92
+ except KeyError:
93
+ continue
94
+ decoding = ''.join([bytes_to_unicode()[ord(char)] for char in self.encoding.decode_single_token_bytes(i).decode('latin-1')])
95
+ self.decoder[i] = decoding
96
+ self.encoder: Dict[str, int] = {}
97
+ for i in range(self.encoding.n_vocab):
98
+ if i in self.decoder:
99
+ self.encoder[self.decoder[i]] = i
100
+ super().__init__(model_name=model_name, encoding_name=encoding_name, add_bos_token=add_bos_token, add_eos_token=add_eos_token, use_default_system_prompt=use_default_system_prompt, unk_token=unk_token, eos_token=eos_token, bos_token=bos_token, pad_token=pad_token, errors=errors, **kwargs)
101
+
102
+ @property
103
+ def vocab_size(self) -> int:
104
+ """Returns vocab size."""
105
+ return self.encoding.n_vocab
106
+
107
+ @property
108
+ def is_fast(self) -> bool:
109
+ return False
110
+
111
+ @property
112
+ def default_chat_template(self):
113
+ """Chat ML Template for User/Assistant.
114
+
115
+ Pinning default Chat ML template in case defaults change.
116
+ """
117
+ template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not 'system' in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_PROMPT' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message.strip() + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}"
118
+ template = template.replace('USE_DEFAULT_PROMPT', 'true' if self.use_default_system_prompt else 'false')
119
+ template = template.replace('DEFAULT_SYSTEM_PROMPT', DEFAULT_SYSTEM_PROMPT)
120
+ return template
121
+
122
+ def get_vocab(self) -> Dict[str, int]:
123
+ """Returns vocab as a dict."""
124
+ vocab_clone = self.encoder.copy()
125
+ extra_id_index = 0
126
+ candidate_extra_id = f'<extra_id_{extra_id_index}>'
127
+ indices_to_fill_in = {i for i in range(self.vocab_size)} - set(vocab_clone.values())
128
+ for index_to_add in indices_to_fill_in:
129
+ while candidate_extra_id in vocab_clone:
130
+ extra_id_index += 1
131
+ candidate_extra_id = f'<extra_id_{extra_id_index}>'
132
+ vocab_clone[candidate_extra_id] = index_to_add
133
+ return vocab_clone
134
+
135
+ def _tokenize(self, text: str) -> List[str]:
136
+ """Returns a tokenized string."""
137
+ if not isinstance(text, str):
138
+ raise ValueError(f'Expected a string input to _tokenize but got {type(text)}.')
139
+ tokens = [self.decoder[t] for t in self.encoding.encode(text, allowed_special='all')]
140
+ return tokens
141
+
142
+ def _convert_token_to_id(self, token: str) -> Optional[int]:
143
+ """Converts a token (str) in an id using the vocab."""
144
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
145
+
146
+ def _convert_id_to_token(self, index: int) -> Optional[str]:
147
+ """Converts an index (integer) in a token (str) using the vocab."""
148
+ return self.decoder.get(index, '')
149
+
150
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
151
+ """Converts a sequence of tokens (string) in a single string."""
152
+ text = ''.join(tokens)
153
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
154
+ return text
155
+
156
+ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]]=None) -> List[int]:
157
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
158
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
159
+ output = bos_token_id + token_ids_0 + eos_token_id
160
+ if token_ids_1 is not None:
161
+ output = output + bos_token_id + token_ids_1 + eos_token_id
162
+ return output
163
+
164
+ def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]]=None, already_has_special_tokens: bool=False) -> List[int]:
165
+ """Retrieves sequence ids from a token list that has no special tokens.
166
+
167
+ Function copied from
168
+ https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/models/gpt2/tokenization_gpt2.py#L265-L295
169
+
170
+ added. This method is called when adding special tokens using the
171
+ tokenizer `prepare_for_model` or `encode_plus` methods.
172
+
173
+ Args:
174
+ token_ids_0 (`List[int]`):
175
+ List of IDs.
176
+ token_ids_1 (`List[int]`, *optional*):
177
+ Optional second list of IDs for sequence pairs.
178
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
179
+ Whether or not the token list is already formatted with special tokens for the model.
180
+
181
+ Returns:
182
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
183
+ """
184
+ if already_has_special_tokens:
185
+ return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
186
+ bos_token_id = [1] if self.add_bos_token else []
187
+ eos_token_id = [1] if self.add_eos_token else []
188
+ if token_ids_1 is None:
189
+ return bos_token_id + [0] * len(token_ids_0) + eos_token_id
190
+ return bos_token_id + [0] * len(token_ids_0) + eos_token_id + bos_token_id + [0] * len(token_ids_1) + eos_token_id
191
+
192
+ def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]]=None) -> List[int]:
193
+ sep = [self.sep_token_id]
194
+ if token_ids_1 is None:
195
+ return len(token_ids_0 + sep) * [0]
196
+ return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
197
+
198
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str]=None) -> Tuple[str]:
199
+ return (None, None)
200
+
201
+ def sanitize_special_tokens(self) -> int:
202
+ """Make sure that all the special tokens attributes of the tokenizer.
203
+
204
+ (`tokenizer.mask_token`, `tokenizer.cls_token`, etc.) are in the
205
+ vocabulary.
206
+
207
+ Add the missing ones to the vocabulary if needed.
208
+
209
+ Return:
210
+ `int`: The number of tokens added in the vocabulary during the operation.
211
+ """
212
+ actual_new_tokens = []
213
+ for token in self.all_special_tokens_extended:
214
+ encoded = self.encoding.encode(token, allowed_special='all')
215
+ if len(encoded) > 1:
216
+ actual_new_tokens.append(token)
217
+ return self.add_tokens(actual_new_tokens, special_tokens=True)
218
+ TiktokenTokenizerWrapper.register_for_auto_class()
tokenizers.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .tiktoken import TiktokenTokenizerWrapper
utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .builders import build_algorithm, build_callback, build_logger, build_optimizer, build_scheduler, build_tokenizer
2
+ from .checkpoint_conversion_helpers import convert_and_save_ft_weights, get_hf_tokenizer_from_composer_state_dict, load_tokenizer
3
+ from .config_utils import calculate_batch_size_info, log_config, pop_config, process_init_device, update_batch_size_info
4
+ from .data_prep_utils import DownloadingIterable, merge_shard_groups, with_id
5
+ from .huggingface_hub_utils import edit_files_for_hf_compatibility
6
+ from .logging_utils import SpecificWarningFilter
7
+ from .model_download_utils import download_from_hf_hub, download_from_http_fileserver, download_from_oras
8
+ from .mosaicml_logger_utils import find_mosaicml_logger, log_eval_analytics, log_train_analytics, maybe_create_mosaicml_logger
9
+ from .prompt_files import load_prompts, load_prompts_from_file
10
+ from .registry_utils import TypedRegistry, construct_from_registry, create_registry
11
+ from .warnings import VersionedDeprecationWarning
warnings.py CHANGED
@@ -1,4 +1,8 @@
1
- class VersionedDeprecationWarning(DeprecationWarning):
 
 
 
 
2
  """A custom deprecation warning class that includes version information.
3
 
4
  Attributes:
@@ -10,7 +14,7 @@ class VersionedDeprecationWarning(DeprecationWarning):
10
  ... warnings.warn(
11
  ... VersionedDeprecationWarning(
12
  ... "Function XYZ is deprecated.",
13
- ... after_version="2.0.0"
14
  ... )
15
  ... )
16
  ...
@@ -19,4 +23,49 @@ class VersionedDeprecationWarning(DeprecationWarning):
19
  """
20
 
21
  def __init__(self, message: str, remove_version: str) -> None:
22
- super().__init__(message + f' It will be removed in version {remove_version}.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import warnings
3
+ from typing import Any, Callable, Type, TypeVar, cast
4
+
5
+ class VersionedDeprecationWarning(UserWarning):
6
  """A custom deprecation warning class that includes version information.
7
 
8
  Attributes:
 
14
  ... warnings.warn(
15
  ... VersionedDeprecationWarning(
16
  ... "Function XYZ is deprecated.",
17
+ ... remove_version="2.0.0"
18
  ... )
19
  ... )
20
  ...
 
23
  """
24
 
25
  def __init__(self, message: str, remove_version: str) -> None:
26
+ super().__init__(message + f' It will be removed in version {remove_version}.')
27
+
28
+ class ExperimentalWarning(Warning):
29
+ """A warning for experimental features.
30
+
31
+ Attributes:
32
+ feature_name (str): The name of the experimental feature.
33
+ """
34
+
35
+ def __init__(self, feature_name: str) -> None:
36
+ super().__init__(f'{feature_name} is experimental and may change with future versions.')
37
+ F = TypeVar('F', bound=Callable[..., Any])
38
+
39
+ def experimental_function(feature_name: str) -> Callable[[F], F]:
40
+ """Decorator to mark a function as experimental.
41
+
42
+ The message displayed will be {feature_name} is experimental and may change with future versions.
43
+
44
+ Args:
45
+ feature_name (str): The name of the experimental feature.
46
+
47
+ Returns:
48
+ The decorated function.
49
+ """
50
+
51
+ def decorator(func: Callable):
52
+
53
+ @functools.wraps(func)
54
+ def wrapper(*args: Any, **kwargs: Any):
55
+ warnings.warn(ExperimentalWarning(feature_name))
56
+ return func(*args, **kwargs)
57
+ return cast(F, wrapper)
58
+ return decorator
59
+
60
+ def experimental_class(feature_name: str) -> Callable[[Type], Type]:
61
+ """Class decorator to mark a class as experimental."""
62
+
63
+ def class_decorator(cls: Type):
64
+ original_init = cls.__init__
65
+
66
+ def new_init(self: Any, *args: Any, **kwargs: Any):
67
+ warnings.warn(ExperimentalWarning(feature_name))
68
+ original_init(self, *args, **kwargs)
69
+ cls.__init__ = new_init
70
+ return cls
71
+ return class_decorator