LLM-foundry update March 26, 2024 23:50:31
#73
by
irenedea
- opened
- act_ckpt.py +119 -0
- async_eval_callback.py +322 -0
- attention.py +31 -82
- blocks.py +7 -7
- builders.py +323 -0
- callback_with_config.py +12 -0
- callbacks.py +26 -0
- checkpoint_conversion_helpers.py +206 -0
- collator.py +256 -0
- config_utils.py +105 -0
- configuration_mpt.py +7 -17
- curriculum_learning_callback.py +62 -0
- data.py +76 -0
- data_prep_utils.py +84 -0
- dataloader.py +313 -0
- eval_gauntlet_callback.py +141 -0
- exceptions.py +162 -0
- fdiff_callback.py +44 -0
- ffn.py +2 -2
- finetuning.py +2 -0
- hf.py +3 -0
- hf_causal_lm.py +14 -0
- hf_checkpointer.py +221 -0
- hf_fsdp.py +165 -0
- hf_t5.py +8 -0
- huggingface_hub_utils.py +102 -0
- interfaces.py +1 -0
- llmfoundry.py +16 -0
- logging_utils.py +24 -0
- meta_init_context.py +1 -1
- model_download_utils.py +186 -0
- model_wrapper.py +36 -0
- modeling_mpt.py +66 -87
- monolithic_ckpt_callback.py +66 -0
- mosaicml_logger_utils.py +69 -0
- mpt.py +2 -0
- packing.py +272 -0
- param_init_fns.py +4 -4
- prompt_files.py +46 -0
- registry.py +24 -0
- registry_utils.py +115 -0
- resumption_callbacks.py +64 -0
- scheduled_gc_callback.py +57 -0
- tasks.py +581 -0
- text_data.py +217 -0
- tiktoken.py +218 -0
- tokenizers.py +1 -0
- utils.py +11 -0
- warnings.py +52 -3
act_ckpt.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
import torch
|
3 |
+
from .attention import ATTN_CLASS_REGISTRY
|
4 |
+
from .blocks import MPTBlock
|
5 |
+
from .ffn import FFN_CLASS_REGISTRY
|
6 |
+
from .norm import NORM_CLASS_REGISTRY
|
7 |
+
|
8 |
+
def pass_on_block_idx(parent: torch.nn.Module):
|
9 |
+
if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
|
10 |
+
return
|
11 |
+
for child in parent.children():
|
12 |
+
child.block_idx = parent.block_idx
|
13 |
+
child.max_block_idx = parent.max_block_idx
|
14 |
+
if child.children():
|
15 |
+
pass_on_block_idx(child)
|
16 |
+
|
17 |
+
def get_act_ckpt_module(mod_name: str) -> Any:
|
18 |
+
"""Get the module type from the module name."""
|
19 |
+
if mod_name.lower() == 'mptblock':
|
20 |
+
mod_type = MPTBlock
|
21 |
+
elif mod_name in ATTN_CLASS_REGISTRY:
|
22 |
+
mod_type = ATTN_CLASS_REGISTRY[mod_name]
|
23 |
+
elif mod_name in FFN_CLASS_REGISTRY:
|
24 |
+
mod_type = FFN_CLASS_REGISTRY[mod_name]
|
25 |
+
elif mod_name in NORM_CLASS_REGISTRY:
|
26 |
+
mod_type = NORM_CLASS_REGISTRY[mod_name]
|
27 |
+
else:
|
28 |
+
msg = ', '.join(list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
|
29 |
+
raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
|
30 |
+
return mod_type
|
31 |
+
|
32 |
+
def parse_ele_str(ele: str, max_block_idx: int) -> list:
|
33 |
+
"""Parse a string in target_blocks and return a list of block ids to add.
|
34 |
+
|
35 |
+
Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
|
36 |
+
to the first n, the middle m, the last k, and the range [i, j).
|
37 |
+
"""
|
38 |
+
to_add = None
|
39 |
+
if ele.startswith('first-'):
|
40 |
+
assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
|
41 |
+
to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
|
42 |
+
elif ele.startswith('last-'):
|
43 |
+
assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
|
44 |
+
to_add = list(range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
|
45 |
+
elif ele.startswith('middle-'):
|
46 |
+
assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
|
47 |
+
num = int(ele[7:])
|
48 |
+
start = max(max_block_idx // 2 - num // 2, 0)
|
49 |
+
end = min(start + num, max_block_idx + 1)
|
50 |
+
to_add = list(range(start, end))
|
51 |
+
elif ele.startswith('range-'):
|
52 |
+
r = ele[6:].split('-')
|
53 |
+
assert len(r) == 2, f'Invalid target_blocks element {ele}'
|
54 |
+
start, end = (int(r[0]), int(r[1]))
|
55 |
+
start = max(start, 0)
|
56 |
+
end = min(end, max_block_idx + 1)
|
57 |
+
to_add = list(range(start, end))
|
58 |
+
else:
|
59 |
+
raise ValueError(f'Invalid target_blocks element {ele}')
|
60 |
+
return to_add
|
61 |
+
|
62 |
+
def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
|
63 |
+
"""Parse the user input and return a list of block ids."""
|
64 |
+
candidate_block_ids = []
|
65 |
+
if isinstance(target_blocks, int):
|
66 |
+
candidate_block_ids = list(range(target_blocks))
|
67 |
+
elif isinstance(target_blocks, list):
|
68 |
+
for ele in target_blocks:
|
69 |
+
if isinstance(ele, int):
|
70 |
+
candidate_block_ids.append(ele)
|
71 |
+
elif isinstance(ele, str):
|
72 |
+
to_add = parse_ele_str(ele, max_block_idx)
|
73 |
+
candidate_block_ids.extend(to_add)
|
74 |
+
else:
|
75 |
+
raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
|
76 |
+
elif isinstance(target_blocks, str):
|
77 |
+
target_blocks = target_blocks.replace(' ', '')
|
78 |
+
for ele in target_blocks.split(','):
|
79 |
+
to_add = parse_ele_str(ele, max_block_idx)
|
80 |
+
candidate_block_ids.extend(to_add)
|
81 |
+
else:
|
82 |
+
raise ValueError(f'target_blocks must be either a single intege, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}')
|
83 |
+
candidate_block_ids = list(set(candidate_block_ids))
|
84 |
+
return candidate_block_ids
|
85 |
+
|
86 |
+
def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
|
87 |
+
"""Check if the block ids in the mapping overlap with each other."""
|
88 |
+
all_blocks = [None] * (max_block_idx + 1)
|
89 |
+
for k, v in mapping.items():
|
90 |
+
if v == -1:
|
91 |
+
v = list(range(max_block_idx + 1))
|
92 |
+
for vv in v:
|
93 |
+
if vv < 0 or vv > max_block_idx:
|
94 |
+
continue
|
95 |
+
elif all_blocks[vv] is not None:
|
96 |
+
raise ValueError(f'Block {vv} is assigned to both {k} and {all_blocks[vv]}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')
|
97 |
+
else:
|
98 |
+
all_blocks[vv] = k
|
99 |
+
|
100 |
+
def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
|
101 |
+
act_ckpt_mod_to_blocks = {}
|
102 |
+
if act_ckpt_target is None or act_ckpt_target == []:
|
103 |
+
mod = top_module
|
104 |
+
act_ckpt_mod_to_blocks[mod] = -1
|
105 |
+
elif isinstance(act_ckpt_target, str):
|
106 |
+
mod = get_act_ckpt_module(act_ckpt_target)
|
107 |
+
act_ckpt_mod_to_blocks[mod] = -1
|
108 |
+
elif isinstance(act_ckpt_target, list):
|
109 |
+
for target in act_ckpt_target:
|
110 |
+
mod = get_act_ckpt_module(target)
|
111 |
+
act_ckpt_mod_to_blocks[mod] = -1
|
112 |
+
elif isinstance(act_ckpt_target, dict):
|
113 |
+
for k, v in act_ckpt_target.items():
|
114 |
+
mod = get_act_ckpt_module(k)
|
115 |
+
block_ids = get_target_block_list(v, max_block_idx)
|
116 |
+
act_ckpt_mod_to_blocks[mod] = block_ids
|
117 |
+
else:
|
118 |
+
raise ValueError(f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}')
|
119 |
+
return act_ckpt_mod_to_blocks
|
async_eval_callback.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Run the eval loop asynchronously as part of a MosaicML platform run.
|
2 |
+
|
3 |
+
This callback is currently experimental. The API may change in the future.
|
4 |
+
"""
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import warnings
|
8 |
+
from collections import Counter
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
11 |
+
from .interfaces import CallbackWithConfig
|
12 |
+
from mcli import Run, RunConfig, create_run, get_run
|
13 |
+
log = logging.getLogger(__name__)
|
14 |
+
REQUIRED_PARAMS_FOR_EVAL = {'device_eval_batch_size', 'icl_tasks', 'max_seq_len', 'model', 'tokenizer'}
|
15 |
+
OPTIONAL_PARAMS_FOR_EVAL = {'dist_timeout', 'eval_gauntlet', 'eval_loader', 'fsdp_config', 'eval_subset_num_batches', 'icl_subset_num_batches', 'loggers', 'precision', 'python_log_level', 'seed'}
|
16 |
+
RUN_NAME_PREFIX = 'eval'
|
17 |
+
MAX_RUN_NAME_BASE_LENGTH = 55
|
18 |
+
|
19 |
+
def get_run_name(training_run_name: str, current_interval: str) -> str:
|
20 |
+
"""Get the new eval run name.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
training_run_name: The name of the current training run
|
24 |
+
current_interval: The current interval string of the training run
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
The new run name
|
28 |
+
"""
|
29 |
+
name_without_uuid_suffix = training_run_name.rsplit('-', 1)[0]
|
30 |
+
max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len(current_interval) - 2
|
31 |
+
if len(name_without_uuid_suffix) > max_length:
|
32 |
+
new_name = name_without_uuid_suffix[:max_length]
|
33 |
+
log.warning(f'Training run name {name_without_uuid_suffix} may be too long,' + f' truncating to {new_name}')
|
34 |
+
name_without_uuid_suffix = new_name
|
35 |
+
return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'
|
36 |
+
|
37 |
+
def get_eval_parameters(parameters: Dict[str, Any], checkpoint: str, training_run_name: str) -> Dict[str, Any]:
|
38 |
+
"""Get the parameters needed for the eval run.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
parameters: The parameters from the training run
|
42 |
+
checkpoint: The path to the latest checkpoint
|
43 |
+
training_run_name: The name of the training run
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
The parameters needed for the eval run as a dict
|
47 |
+
"""
|
48 |
+
looking_for = REQUIRED_PARAMS_FOR_EVAL.copy()
|
49 |
+
subset_keys = {}
|
50 |
+
for key in parameters:
|
51 |
+
if key in OPTIONAL_PARAMS_FOR_EVAL:
|
52 |
+
subset_keys[key] = parameters[key]
|
53 |
+
elif key in REQUIRED_PARAMS_FOR_EVAL:
|
54 |
+
subset_keys[key] = parameters[key]
|
55 |
+
looking_for.remove(key)
|
56 |
+
if looking_for:
|
57 |
+
raise Exception(f'Missing the following required parameters for async eval: {looking_for}')
|
58 |
+
for logger, config in subset_keys.get('loggers', {}).items():
|
59 |
+
if logger == 'wandb':
|
60 |
+
config['group'] = config.pop('name', training_run_name)
|
61 |
+
model = subset_keys.pop('model')
|
62 |
+
model_name = model.get('name', None)
|
63 |
+
if not model_name:
|
64 |
+
raise Exception(f'Async evaluation requires "name" keys for models')
|
65 |
+
new_models = {'model_name': model_name, 'model': model, 'load_path': checkpoint}
|
66 |
+
tokenizer = subset_keys.pop('tokenizer', None)
|
67 |
+
if tokenizer is not None:
|
68 |
+
new_models['tokenizer'] = tokenizer
|
69 |
+
subset_keys['models'] = [new_models]
|
70 |
+
return subset_keys
|
71 |
+
|
72 |
+
def validate_interval(interval: Union[str, int, Time], save_interval: Union[str, int, Time]) -> Time:
|
73 |
+
new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH)
|
74 |
+
async_interval = Time.from_input(interval, TimeUnit.EPOCH)
|
75 |
+
if new_save_interval.unit != async_interval.unit:
|
76 |
+
raise ValueError('Save interval and async eval interval must be in the same unit')
|
77 |
+
if async_interval < new_save_interval:
|
78 |
+
raise ValueError('Async eval interval must be equal or greater (less frequent) than save interval')
|
79 |
+
if async_interval.value % new_save_interval.value != 0:
|
80 |
+
raise ValueError('Async eval interval must be a multiple of save interval')
|
81 |
+
return async_interval
|
82 |
+
|
83 |
+
def validate_eval_run_config(eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
84 |
+
if not eval_run_config:
|
85 |
+
return {}
|
86 |
+
run_config = eval_run_config.copy()
|
87 |
+
supported_keys = {'image', 'command', 'compute', 'scheduling'}
|
88 |
+
found_unsupported = set()
|
89 |
+
for key in run_config:
|
90 |
+
if key not in supported_keys:
|
91 |
+
found_unsupported.add(key)
|
92 |
+
if found_unsupported:
|
93 |
+
raise ValueError(f"Unsupported eval run config keys found: {', '.join(found_unsupported)}" + f'. Supported keys: {supported_keys}')
|
94 |
+
return run_config
|
95 |
+
CHECKS_PER_INTERVAL = 4
|
96 |
+
|
97 |
+
class AsyncEval(CallbackWithConfig):
|
98 |
+
"""Run the eval loop asynchronously as part of a MosaicML platform run.
|
99 |
+
|
100 |
+
This callback is currently experimental. The API may change in the future.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
training_params: Dict[str, Any]: The parameter config from the training run
|
104 |
+
interval: Union[str, int, Time]: The interval describing how often eval runs should be
|
105 |
+
launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
|
106 |
+
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
|
107 |
+
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
|
108 |
+
eval_run_config: Optional[Dict[str, Any]]: A subset of mcli run config values to use
|
109 |
+
for the eval run. If not specified, any fields from run config will be created
|
110 |
+
dynamically from the training run config and the interval. The following fields
|
111 |
+
are supported:
|
112 |
+
- ``image``: Image of the eval run. Default: same as training run
|
113 |
+
- ``command``: Command to run for the eval run. Default: calls
|
114 |
+
`composer scripts/eval/eval.py $PARAMETERS`. If custom setup is needed,
|
115 |
+
the command should include calling the eval script with $PARAMETERS
|
116 |
+
- ``compute``: Compute to use for the eval run. Default: same cluster as
|
117 |
+
the training run and a single node (8 GPUs)
|
118 |
+
- ``scheduling``: Scheduling to use for the eval run. Default: same as training run
|
119 |
+
|
120 |
+
All fields are optional, but if specified, must be valid for a mcli run config. We
|
121 |
+
provide this optional config to give you the most flexibility in customizing the eval
|
122 |
+
run, but it is recommended to use the default values unless you have a specific use case
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(self, training_params: Dict[str, Any], interval: Union[str, int, Time], eval_run_config: Optional[Dict[str, Any]]=None):
|
126 |
+
for required in ('save_interval', 'save_folder'):
|
127 |
+
if required not in training_params:
|
128 |
+
raise ValueError(f'{required} required for async eval')
|
129 |
+
if '/' in training_params.get('save_filename', ''):
|
130 |
+
raise ValueError('AsyncEval not supported for save_filename that includes a path')
|
131 |
+
self.checkpoint_save_folder = training_params['save_folder']
|
132 |
+
self.training_params = training_params
|
133 |
+
self.eval_run_config = validate_eval_run_config(eval_run_config)
|
134 |
+
self.current_run = self._get_current_run()
|
135 |
+
get_eval_parameters(parameters=training_params, checkpoint='test', training_run_name=self.current_run.name)
|
136 |
+
self.interval = validate_interval(interval, self.training_params['save_interval'])
|
137 |
+
check_interval_value = max(self.interval.value // CHECKS_PER_INTERVAL, 1)
|
138 |
+
self.check_interval = Time(check_interval_value, self.interval.unit)
|
139 |
+
self.checkpoints_evaled: Dict[Time, Tuple[str, str]] = {}
|
140 |
+
self.is_at_check_interval = create_interval_scheduler(self.check_interval, include_end_of_training=False)
|
141 |
+
log.info('Initialized AsyncEval callback. Will generate runs at ' + f'interval {interval}, checking at {self.check_interval}')
|
142 |
+
|
143 |
+
def state_dict(self) -> Dict[str, Any]:
|
144 |
+
checkpoints_evaled = []
|
145 |
+
for eval_ts, (checkpoint, run_name) in self.checkpoints_evaled.items():
|
146 |
+
eval_ts_dict = {'value': eval_ts.value, 'unit': eval_ts.unit.value}
|
147 |
+
checkpoints_evaled.append((eval_ts_dict, checkpoint, run_name))
|
148 |
+
return {'checkpoints_evaled': checkpoints_evaled}
|
149 |
+
|
150 |
+
def load_state_dict(self, state_dict: Dict[str, Any]):
|
151 |
+
previous_checkpoints_evaled = state_dict.get('checkpoints_evaled', [])
|
152 |
+
if previous_checkpoints_evaled:
|
153 |
+
for eval_ts, checkpoint, run_name in previous_checkpoints_evaled:
|
154 |
+
eval_ts = Time(eval_ts['value'], TimeUnit(eval_ts['unit']))
|
155 |
+
self.checkpoints_evaled[eval_ts] = (checkpoint, run_name)
|
156 |
+
log.info(f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}')
|
157 |
+
|
158 |
+
@staticmethod
|
159 |
+
def _get_ready_sharded_checkpoints(checkpointer_checkpoints: Dict[str, Timestamp], remote_files: List[str]) -> Dict[str, Timestamp]:
|
160 |
+
"""Identify checkpoints ready to be evaled based on remote files.
|
161 |
+
|
162 |
+
This has special logic for sharded checkpoints to consider checkpoints composed
|
163 |
+
of multiple shards (one per gpu) and metadata
|
164 |
+
|
165 |
+
Args:
|
166 |
+
checkpointer_checkpoints: All checkpoints from the checkpointer state
|
167 |
+
remote_files: List of remote files in the save folder
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
Dict of checkpoints that are complete and ready to be evaled
|
171 |
+
"""
|
172 |
+
remote_file_group_counts = Counter()
|
173 |
+
for f in remote_files:
|
174 |
+
checkpoint_ts_path = Path(f).parts[-2]
|
175 |
+
remote_file_group_counts[checkpoint_ts_path] += 1
|
176 |
+
checkpoints_to_eval = {}
|
177 |
+
for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
|
178 |
+
checkpoint_ts_path = Path(checkpoint).parts[-2]
|
179 |
+
expected_shard_count = dist.get_world_size() + 1
|
180 |
+
if remote_file_group_counts[checkpoint_ts_path] != expected_shard_count:
|
181 |
+
log.debug(f'Checkpoint {checkpoint} not fully uploaded (missing shards ' + f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping')
|
182 |
+
continue
|
183 |
+
checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
|
184 |
+
return checkpoints_to_eval
|
185 |
+
|
186 |
+
@staticmethod
|
187 |
+
def _get_ready_single_checkpoints(checkpointer_checkpoints: Dict[str, Timestamp], remote_checkpoints: List[str]) -> Dict[str, Timestamp]:
|
188 |
+
"""Identify checkpoints ready to be evaled based on remote checkpoints.
|
189 |
+
|
190 |
+
This is much simpler than the sharded case, because there is only one file
|
191 |
+
|
192 |
+
Args:
|
193 |
+
checkpointer_checkpoints: All checkpoints from the checkpointer state
|
194 |
+
remote_checkpoints: List of remote checkpoints in the save folder
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
Dict of checkpoints that are complete and ready to be evaled
|
198 |
+
"""
|
199 |
+
unique_remote_checkpoints = set(remote_checkpoints)
|
200 |
+
checkpoints_to_eval = {}
|
201 |
+
for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
|
202 |
+
checkpoint_ts_path = Path(checkpoint).parts[-1]
|
203 |
+
if checkpoint not in unique_remote_checkpoints:
|
204 |
+
log.debug(f'Checkpoint {checkpoint} not fully uploaded, skipping')
|
205 |
+
continue
|
206 |
+
checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
|
207 |
+
return checkpoints_to_eval
|
208 |
+
|
209 |
+
def _get_checkpoints_and_launch_runs(self, state: State):
|
210 |
+
"""Get the latest checkpoint from the training run.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
state: The current state of the training run
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
Returns checkpoints that have not been evaled
|
217 |
+
"""
|
218 |
+
checkpointer = None
|
219 |
+
for callback in state.callbacks:
|
220 |
+
if isinstance(callback, CheckpointSaver):
|
221 |
+
if checkpointer is None:
|
222 |
+
checkpointer = callback
|
223 |
+
else:
|
224 |
+
log.warning('Multiple checkpoint savers found. Using the first one')
|
225 |
+
if not checkpointer:
|
226 |
+
warnings.warn('No checkpoint saver callback found. Skipping eval')
|
227 |
+
return
|
228 |
+
if not checkpointer.all_saved_checkpoints_to_timestamp:
|
229 |
+
log.debug('No saved checkpoints found on the checkpointer. Skipping eval')
|
230 |
+
return
|
231 |
+
log.debug(f'Found {len(checkpointer.all_saved_checkpoints_to_timestamp)} ' + f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}')
|
232 |
+
remote_checkpoints = list_remote_objects(self.checkpoint_save_folder)
|
233 |
+
if not remote_checkpoints:
|
234 |
+
log.debug('No saved checkpoints found yet on remote. Skipping eval')
|
235 |
+
return
|
236 |
+
if state.fsdp_sharded_state_dict_enabled:
|
237 |
+
checkpoints_to_eval = self._get_ready_sharded_checkpoints(checkpointer.all_saved_checkpoints_to_timestamp, remote_checkpoints)
|
238 |
+
else:
|
239 |
+
checkpoints_to_eval = self._get_ready_single_checkpoints(checkpointer.all_saved_checkpoints_to_timestamp, remote_checkpoints)
|
240 |
+
for checkpoint_interval_path, checkpoint_timestamp in checkpoints_to_eval.items():
|
241 |
+
checkpoint_ts = checkpoint_timestamp.get(self.interval.unit)
|
242 |
+
if checkpoint_ts.value % self.interval.value != 0:
|
243 |
+
log.debug(f'Checkpoint {checkpoint_interval_path} ({checkpoint_ts}) is ' + f'not at an eval interval ({self.interval}), skipping')
|
244 |
+
continue
|
245 |
+
if checkpoint_ts in self.checkpoints_evaled:
|
246 |
+
continue
|
247 |
+
full_checkpoint_path = f'{self.checkpoint_save_folder}/{checkpoint_interval_path}'
|
248 |
+
eval_run = self.launch_run(full_checkpoint_path, checkpoint_ts)
|
249 |
+
self.checkpoints_evaled[checkpoint_ts] = (full_checkpoint_path, eval_run.name)
|
250 |
+
|
251 |
+
def run_event(self, event: Event, state: State, logger: Logger) -> None:
|
252 |
+
del logger
|
253 |
+
should_launch_run = all([state.get_elapsed_duration() is not None, self.is_at_check_interval(state, event), dist.get_global_rank() == 0])
|
254 |
+
if should_launch_run:
|
255 |
+
self._get_checkpoints_and_launch_runs(state)
|
256 |
+
|
257 |
+
def close(self, state: State, logger: Logger) -> None:
|
258 |
+
del logger
|
259 |
+
if dist.get_global_rank() != 0:
|
260 |
+
return
|
261 |
+
self._get_checkpoints_and_launch_runs(state)
|
262 |
+
latest_timestamp = state.timestamp.get(self.interval.unit)
|
263 |
+
if latest_timestamp not in self.checkpoints_evaled:
|
264 |
+
save_latest_filename = self.training_params.get('save_latest_filename', None)
|
265 |
+
if not save_latest_filename:
|
266 |
+
rank = dist.get_global_rank()
|
267 |
+
save_latest_filename = f'latest-rank{rank}.pt'
|
268 |
+
checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
|
269 |
+
eval_run = self.launch_run(checkpoint, latest_timestamp)
|
270 |
+
self.checkpoints_evaled[latest_timestamp] = (checkpoint, eval_run.name)
|
271 |
+
log.info(f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:')
|
272 |
+
for checkpoint_ts, (checkpoint, run_name) in self.checkpoints_evaled.items():
|
273 |
+
log.info(f' {checkpoint_ts}: {checkpoint}, {run_name}')
|
274 |
+
|
275 |
+
def _get_current_run(self) -> Run:
|
276 |
+
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'false':
|
277 |
+
raise RuntimeError('AsyncEval callback is only supported when running on the MosaicML platform')
|
278 |
+
run_name = os.environ.get(RUN_NAME_ENV_VAR, None)
|
279 |
+
if not run_name:
|
280 |
+
raise RuntimeError('RUN_NAME environment variable must be set to use the AsyncEval callback')
|
281 |
+
return get_run(run_name, include_details=True)
|
282 |
+
|
283 |
+
def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
|
284 |
+
"""Launch a new eval run.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
checkpoint: The checkpoint to eval
|
288 |
+
current_interval: The interval of the checkpoint
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
The launched run (mcli.Run type)
|
292 |
+
"""
|
293 |
+
log.info(f'Launching eval run for {checkpoint} at {current_interval}')
|
294 |
+
cfg = self.current_run.submitted_config
|
295 |
+
default_compute = {'gpus': 8, 'cluster': self.current_run.cluster}
|
296 |
+
run_name = get_run_name(self.current_run.name, str(current_interval))
|
297 |
+
params = get_eval_parameters(parameters=self.training_params, checkpoint=checkpoint, training_run_name=self.current_run.name)
|
298 |
+
params['run_name'] = run_name
|
299 |
+
integrations = cfg.integrations
|
300 |
+
found_llm_foundry, installation_path = (False, 'llm-foundry')
|
301 |
+
for i in integrations:
|
302 |
+
if i['integration_type'] != 'git_repo':
|
303 |
+
continue
|
304 |
+
if not i['git_repo'].endswith('llm-foundry'):
|
305 |
+
continue
|
306 |
+
found_llm_foundry = True
|
307 |
+
if i.get('path'):
|
308 |
+
installation_path = i['path']
|
309 |
+
if not found_llm_foundry:
|
310 |
+
from .llmfoundry import __version__ as latest_foundry_version
|
311 |
+
version = f'v{latest_foundry_version}'
|
312 |
+
log.warning('No github integration found for llm-foundry. Adding installation ' + f'to eval run for latest foundry release ({version}). ' + 'To use a fork, custom branch, or custom version, configure ' + 'llm-foundry installation through a github integration')
|
313 |
+
integrations.append({'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', 'git_branch': version, 'pip_install': '-e .[gpu]', 'ssh_clone': False})
|
314 |
+
metadata = cfg.metadata
|
315 |
+
metadata['eval_timestamp'] = current_interval.value
|
316 |
+
metadata['eval_timestamp_unit'] = current_interval.unit.value
|
317 |
+
default_command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS'
|
318 |
+
run_config = RunConfig(name=run_name, image=self.eval_run_config.get('image', self.current_run.image), command=self.eval_run_config.get('command', default_command), compute=self.eval_run_config.get('compute', default_compute), scheduling=self.eval_run_config.get('scheduling', self.current_run.submitted_config.scheduling), integrations=integrations, env_variables=cfg.env_variables, metadata=cfg.metadata, parameters=params)
|
319 |
+
log.info(f'Creating new run with config: \n{run_config}')
|
320 |
+
new_run = create_run(run_config)
|
321 |
+
log.info(f'Launched new run {new_run.name} inside eval loop')
|
322 |
+
return new_run
|
attention.py
CHANGED
@@ -31,9 +31,6 @@ def is_transformers_version_gte(hf_version: str) -> bool:
|
|
31 |
|
32 |
def check_alibi_support(attention_impl: str) -> bool:
|
33 |
return attention_impl != 'flash' or is_flash_v2_installed(v2_version='v2.4.2')
|
34 |
-
if is_flash_v1_installed():
|
35 |
-
import transformers
|
36 |
-
transformers.utils.is_flash_attn_available = lambda : False
|
37 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
38 |
|
39 |
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
|
@@ -53,7 +50,7 @@ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
53 |
"""
|
54 |
if n_rep == 1:
|
55 |
return hidden
|
56 |
-
|
57 |
hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
|
58 |
return hidden.reshape(b, s, kv_n_heads * n_rep, d)
|
59 |
|
@@ -66,7 +63,7 @@ def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tenso
|
|
66 |
k = torch.cat([past_key_value[0], k], dim=3)
|
67 |
v = torch.cat([past_key_value[1], v], dim=2)
|
68 |
past_key_value = (k, v)
|
69 |
-
|
70 |
s_k = k.size(-1)
|
71 |
if kv_n_heads > 1 and kv_n_heads < n_heads:
|
72 |
k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
@@ -130,7 +127,7 @@ def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n
|
|
130 |
past_key_value = (key, value)
|
131 |
if attn_bias is not None:
|
132 |
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
133 |
-
|
134 |
indices_q = flash_attn_padding_info['indices_q']
|
135 |
indices_k = flash_attn_padding_info['indices_k']
|
136 |
indices_v = flash_attn_padding_info['indices_v']
|
@@ -169,65 +166,17 @@ def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n
|
|
169 |
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
170 |
return (output, None, past_key_value)
|
171 |
|
172 |
-
def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
173 |
-
try:
|
174 |
-
from .flash_attn_triton import flash_attn_func
|
175 |
-
except:
|
176 |
-
_installed = False
|
177 |
-
if version.parse(torch.__version__) < version.parse('2.0.0'):
|
178 |
-
_installed = True
|
179 |
-
try:
|
180 |
-
from flash_attn.flash_attn_triton import flash_attn_func
|
181 |
-
except:
|
182 |
-
_installed = False
|
183 |
-
if not _installed:
|
184 |
-
raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU ' + 'and `pip install .[gpu]` if installing from llm-foundry source or ' + '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` ' + 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). ' + 'Note: (1) requires you have CMake and PyTorch already installed.')
|
185 |
-
check_valid_inputs(query, key, value)
|
186 |
-
if past_key_value is not None:
|
187 |
-
if len(past_key_value) != 0:
|
188 |
-
key = torch.cat([past_key_value[0], key], dim=1)
|
189 |
-
value = torch.cat([past_key_value[1], value], dim=1)
|
190 |
-
past_key_value = (key, value)
|
191 |
-
if attn_bias is not None:
|
192 |
-
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
193 |
-
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
194 |
-
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
195 |
-
if dropout_p:
|
196 |
-
raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
|
197 |
-
dropout_p = dropout_p if training else 0.0
|
198 |
-
if needs_weights:
|
199 |
-
raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
|
200 |
-
if key_padding_mask is not None:
|
201 |
-
warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
202 |
-
(b_size, s_k) = key_padding_mask.shape[:2]
|
203 |
-
if attn_bias is None:
|
204 |
-
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
205 |
-
attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
|
206 |
-
query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
|
207 |
-
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
|
208 |
-
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
|
209 |
-
if kv_n_heads == 1:
|
210 |
-
key = key.repeat(1, 1, n_heads, 1)
|
211 |
-
value = value.repeat(1, 1, n_heads, 1)
|
212 |
-
elif kv_n_heads < n_heads:
|
213 |
-
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
|
214 |
-
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
|
215 |
-
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
216 |
-
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
217 |
-
output = attn_output.view(*attn_output.shape[:2], -1)
|
218 |
-
return (output, None, past_key_value)
|
219 |
-
|
220 |
class GroupedQueryAttention(nn.Module):
|
221 |
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
|
222 |
|
223 |
and Multi-query attention (MQA).
|
224 |
|
225 |
This allows the user to set a variable of number of kv_n_heads, rather than
|
226 |
-
just n_heads or 1, as in MHA and MQA. Using torch
|
227 |
implementation enables user to also use additive bias.
|
228 |
"""
|
229 |
|
230 |
-
def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='
|
231 |
super().__init__()
|
232 |
self.attn_impl = attn_impl
|
233 |
self.clip_qkv = clip_qkv
|
@@ -251,8 +200,7 @@ class GroupedQueryAttention(nn.Module):
|
|
251 |
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
252 |
self.attn_dropout_p = attn_pdrop
|
253 |
fc_kwargs: dict[str, Any] = {'bias': bias}
|
254 |
-
|
255 |
-
fc_kwargs['device'] = device
|
256 |
self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
|
257 |
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
|
258 |
self.Wqkv._fused = (0, fuse_splits)
|
@@ -265,8 +213,6 @@ class GroupedQueryAttention(nn.Module):
|
|
265 |
self.k_ln = norm_class(norm_size, device=device)
|
266 |
if self.attn_impl == 'flash':
|
267 |
self.attn_fn = flash_attn_fn
|
268 |
-
elif self.attn_impl == 'triton':
|
269 |
-
self.attn_fn = triton_flash_attn_fn
|
270 |
elif self.attn_impl == 'torch':
|
271 |
self.attn_fn = scaled_multihead_dot_product_attention
|
272 |
else:
|
@@ -278,12 +224,12 @@ class GroupedQueryAttention(nn.Module):
|
|
278 |
qkv = self.Wqkv(x)
|
279 |
if self.clip_qkv:
|
280 |
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
281 |
-
|
282 |
key_padding_mask = attention_mask
|
283 |
if self.qk_ln or self.qk_gn:
|
284 |
-
|
285 |
if self.qk_gn:
|
286 |
-
|
287 |
query = query.view(b, s, self.n_heads, -1)
|
288 |
key = key.view(b, s, self.kv_n_heads, -1)
|
289 |
dtype = query.dtype
|
@@ -293,23 +239,28 @@ class GroupedQueryAttention(nn.Module):
|
|
293 |
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
|
294 |
seq_len = rotary_emb_w_meta_info['seq_len']
|
295 |
offset_info = rotary_emb_w_meta_info['offset_info']
|
296 |
-
|
297 |
query = query.view(bsz, seqlen, -1, self.head_dim)
|
298 |
key = key.view(bsz, seqlen, -1, self.head_dim)
|
299 |
if rotary_emb_w_meta_info['impl'] == 'dail':
|
300 |
value = value.view(bsz, seqlen, -1, self.head_dim)
|
301 |
kv = torch.stack([key, value], dim=2)
|
302 |
-
|
303 |
[key, value] = torch.unbind(kv, dim=2)
|
304 |
value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
|
305 |
elif rotary_emb_w_meta_info['impl'] == 'hf':
|
306 |
-
|
307 |
-
|
308 |
-
|
|
|
|
|
|
|
|
|
|
|
309 |
else:
|
310 |
query = query.transpose(1, 2)
|
311 |
key = key.transpose(1, 2)
|
312 |
-
|
313 |
query = query.transpose(1, 2)
|
314 |
key = key.transpose(1, 2)
|
315 |
query = query.view(bsz, seqlen, self.d_model)
|
@@ -318,38 +269,36 @@ class GroupedQueryAttention(nn.Module):
|
|
318 |
if self.attn_impl == 'flash':
|
319 |
key_padding_mask = None
|
320 |
extra_attn_kwargs = {'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info}
|
321 |
-
|
322 |
return (self.out_proj(context), attn_weights, past_key_value)
|
323 |
|
324 |
class MultiheadAttention(GroupedQueryAttention):
|
325 |
"""Multi-head self attention.
|
326 |
|
327 |
-
Using torch
|
328 |
-
additive bias.
|
329 |
"""
|
330 |
|
331 |
-
def __init__(self, d_model: int, n_heads: int, attn_impl: str='
|
332 |
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
|
333 |
|
334 |
class MultiQueryAttention(GroupedQueryAttention):
|
335 |
"""Multi-Query self attention.
|
336 |
|
337 |
-
Using torch
|
338 |
-
additive bias.
|
339 |
"""
|
340 |
|
341 |
-
def __init__(self, d_model: int, n_heads: int, attn_impl: str='
|
342 |
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
|
343 |
|
344 |
-
def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
|
345 |
if attn_impl == 'flash':
|
346 |
return None
|
347 |
-
elif attn_impl
|
348 |
if alibi:
|
349 |
-
if
|
350 |
return (1, n_heads, seq_len, seq_len)
|
351 |
return (1, n_heads, 1, seq_len)
|
352 |
-
elif
|
353 |
return (1, 1, seq_len, seq_len)
|
354 |
return None
|
355 |
else:
|
@@ -358,9 +307,9 @@ def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, pre
|
|
358 |
def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
|
359 |
if attn_impl == 'flash':
|
360 |
return None
|
361 |
-
elif attn_impl
|
362 |
if alibi:
|
363 |
-
|
364 |
attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
|
365 |
return attn_bias
|
366 |
else:
|
|
|
31 |
|
32 |
def check_alibi_support(attention_impl: str) -> bool:
|
33 |
return attention_impl != 'flash' or is_flash_v2_installed(v2_version='v2.4.2')
|
|
|
|
|
|
|
34 |
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
35 |
|
36 |
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
|
|
|
50 |
"""
|
51 |
if n_rep == 1:
|
52 |
return hidden
|
53 |
+
b, s, kv_n_heads, d = hidden.shape
|
54 |
hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
|
55 |
return hidden.reshape(b, s, kv_n_heads * n_rep, d)
|
56 |
|
|
|
63 |
k = torch.cat([past_key_value[0], k], dim=3)
|
64 |
v = torch.cat([past_key_value[1], v], dim=2)
|
65 |
past_key_value = (k, v)
|
66 |
+
b, _, s_q, d = q.shape
|
67 |
s_k = k.size(-1)
|
68 |
if kv_n_heads > 1 and kv_n_heads < n_heads:
|
69 |
k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
|
|
127 |
past_key_value = (key, value)
|
128 |
if attn_bias is not None:
|
129 |
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
130 |
+
batch_size, seqlen = query.shape[:2]
|
131 |
indices_q = flash_attn_padding_info['indices_q']
|
132 |
indices_k = flash_attn_padding_info['indices_k']
|
133 |
indices_v = flash_attn_padding_info['indices_v']
|
|
|
166 |
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
167 |
return (output, None, past_key_value)
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
class GroupedQueryAttention(nn.Module):
|
170 |
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
|
171 |
|
172 |
and Multi-query attention (MQA).
|
173 |
|
174 |
This allows the user to set a variable of number of kv_n_heads, rather than
|
175 |
+
just n_heads or 1, as in MHA and MQA. Using torch attention
|
176 |
implementation enables user to also use additive bias.
|
177 |
"""
|
178 |
|
179 |
+
def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
|
180 |
super().__init__()
|
181 |
self.attn_impl = attn_impl
|
182 |
self.clip_qkv = clip_qkv
|
|
|
200 |
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
201 |
self.attn_dropout_p = attn_pdrop
|
202 |
fc_kwargs: dict[str, Any] = {'bias': bias}
|
203 |
+
fc_kwargs['device'] = device
|
|
|
204 |
self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
|
205 |
fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
|
206 |
self.Wqkv._fused = (0, fuse_splits)
|
|
|
213 |
self.k_ln = norm_class(norm_size, device=device)
|
214 |
if self.attn_impl == 'flash':
|
215 |
self.attn_fn = flash_attn_fn
|
|
|
|
|
216 |
elif self.attn_impl == 'torch':
|
217 |
self.attn_fn = scaled_multihead_dot_product_attention
|
218 |
else:
|
|
|
224 |
qkv = self.Wqkv(x)
|
225 |
if self.clip_qkv:
|
226 |
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
227 |
+
query, key, value = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
|
228 |
key_padding_mask = attention_mask
|
229 |
if self.qk_ln or self.qk_gn:
|
230 |
+
q_shape, k_shape = (query.shape, key.shape)
|
231 |
if self.qk_gn:
|
232 |
+
b, s = query.shape[:2]
|
233 |
query = query.view(b, s, self.n_heads, -1)
|
234 |
key = key.view(b, s, self.kv_n_heads, -1)
|
235 |
dtype = query.dtype
|
|
|
239 |
rotary_emb = rotary_emb_w_meta_info['rotary_emb']
|
240 |
seq_len = rotary_emb_w_meta_info['seq_len']
|
241 |
offset_info = rotary_emb_w_meta_info['offset_info']
|
242 |
+
bsz, seqlen = query.shape[:2]
|
243 |
query = query.view(bsz, seqlen, -1, self.head_dim)
|
244 |
key = key.view(bsz, seqlen, -1, self.head_dim)
|
245 |
if rotary_emb_w_meta_info['impl'] == 'dail':
|
246 |
value = value.view(bsz, seqlen, -1, self.head_dim)
|
247 |
kv = torch.stack([key, value], dim=2)
|
248 |
+
query, kv = rotary_emb(query, kv, seqlen_offset=offset_info, max_seqlen=seq_len)
|
249 |
[key, value] = torch.unbind(kv, dim=2)
|
250 |
value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
|
251 |
elif rotary_emb_w_meta_info['impl'] == 'hf':
|
252 |
+
if is_transformers_version_gte('4.38'):
|
253 |
+
cos, sin = rotary_emb(x=value, position_ids=offset_info, seq_len=None)
|
254 |
+
else:
|
255 |
+
cos, sin = rotary_emb(x=value, seq_len=seq_len)
|
256 |
+
if is_transformers_version_gte('4.38'):
|
257 |
+
query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=None, unsqueeze_dim=2)
|
258 |
+
elif is_transformers_version_gte('4.36'):
|
259 |
+
query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=offset_info, unsqueeze_dim=2)
|
260 |
else:
|
261 |
query = query.transpose(1, 2)
|
262 |
key = key.transpose(1, 2)
|
263 |
+
query, key = apply_rotary_pos_emb(q=query, k=key, cos=cos, sin=sin, position_ids=offset_info)
|
264 |
query = query.transpose(1, 2)
|
265 |
key = key.transpose(1, 2)
|
266 |
query = query.view(bsz, seqlen, self.d_model)
|
|
|
269 |
if self.attn_impl == 'flash':
|
270 |
key_padding_mask = None
|
271 |
extra_attn_kwargs = {'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info}
|
272 |
+
context, attn_weights, past_key_value = self.attn_fn(query, key, value, self.n_heads, self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, **extra_attn_kwargs)
|
273 |
return (self.out_proj(context), attn_weights, past_key_value)
|
274 |
|
275 |
class MultiheadAttention(GroupedQueryAttention):
|
276 |
"""Multi-head self attention.
|
277 |
|
278 |
+
Using torch attention implementation enables user to also use additive bias.
|
|
|
279 |
"""
|
280 |
|
281 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
|
282 |
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
|
283 |
|
284 |
class MultiQueryAttention(GroupedQueryAttention):
|
285 |
"""Multi-Query self attention.
|
286 |
|
287 |
+
Using torch attention implementation enables user to also use additive bias.
|
|
|
288 |
"""
|
289 |
|
290 |
+
def __init__(self, d_model: int, n_heads: int, attn_impl: str='flash', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
|
291 |
super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
|
292 |
|
293 |
+
def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, causal: bool, use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
|
294 |
if attn_impl == 'flash':
|
295 |
return None
|
296 |
+
elif attn_impl == 'torch':
|
297 |
if alibi:
|
298 |
+
if not causal or use_sequence_id:
|
299 |
return (1, n_heads, seq_len, seq_len)
|
300 |
return (1, n_heads, 1, seq_len)
|
301 |
+
elif use_sequence_id:
|
302 |
return (1, 1, seq_len, seq_len)
|
303 |
return None
|
304 |
else:
|
|
|
307 |
def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
|
308 |
if attn_impl == 'flash':
|
309 |
return None
|
310 |
+
elif attn_impl == 'torch':
|
311 |
if alibi:
|
312 |
+
device, dtype = (attn_bias.device, attn_bias.dtype)
|
313 |
attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
|
314 |
return attn_bias
|
315 |
else:
|
blocks.py
CHANGED
@@ -8,8 +8,8 @@ from .norm import NORM_CLASS_REGISTRY
|
|
8 |
try:
|
9 |
from flash_attn.bert_padding import unpad_input, pad_input
|
10 |
except:
|
11 |
-
|
12 |
-
attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': '
|
13 |
|
14 |
class MPTBlock(nn.Module):
|
15 |
|
@@ -23,8 +23,8 @@ class MPTBlock(nn.Module):
|
|
23 |
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
24 |
assert isinstance(attn_config['attn_type'], str)
|
25 |
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
|
26 |
-
args_to_exclude_in_attn_class = {'attn_type', '
|
27 |
-
attn_config_subset_for_attn_class = {k: v for
|
28 |
self.norm_1 = norm_class(d_model, device=device)
|
29 |
self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
|
30 |
self.norm_2 = None
|
@@ -37,16 +37,16 @@ class MPTBlock(nn.Module):
|
|
37 |
|
38 |
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[Dict]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
39 |
a = self.norm_1(x)
|
40 |
-
|
41 |
x = x + self.resid_attn_dropout(b)
|
42 |
m = x
|
43 |
if self.norm_2 is not None:
|
44 |
m = self.norm_2(x)
|
45 |
-
|
46 |
indices = None
|
47 |
if not self.use_pad_tok_in_ffn:
|
48 |
assert unpad_input is not None
|
49 |
-
|
50 |
n = self.ffn(m)
|
51 |
if not self.use_pad_tok_in_ffn:
|
52 |
assert pad_input is not None
|
|
|
8 |
try:
|
9 |
from flash_attn.bert_padding import unpad_input, pad_input
|
10 |
except:
|
11 |
+
unpad_input, pad_input = (None, None)
|
12 |
+
attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'flash', 'qk_ln': False, 'qk_gn': False, 'clip_qkv': None, 'softmax_scale': None, 'attn_uses_sequence_id': False, 'sliding_window_size': -1, 'alibi': False, 'alibi_bias_max': 8, 'rope': False, 'rope_theta': 10000, 'rope_impl': 'dail', 'rope_dail_config': {'type': 'original', 'pos_idx_in_fp32': True, 'xpos_scale_base': 512}, 'rope_hf_config': {'type': 'no_scaling', 'factor': 1.0}}
|
13 |
|
14 |
class MPTBlock(nn.Module):
|
15 |
|
|
|
23 |
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
24 |
assert isinstance(attn_config['attn_type'], str)
|
25 |
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
|
26 |
+
args_to_exclude_in_attn_class = {'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', 'rope_hf_config'}
|
27 |
+
attn_config_subset_for_attn_class = {k: v for k, v in attn_config.items() if k not in args_to_exclude_in_attn_class}
|
28 |
self.norm_1 = norm_class(d_model, device=device)
|
29 |
self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
|
30 |
self.norm_2 = None
|
|
|
37 |
|
38 |
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[Dict]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
39 |
a = self.norm_1(x)
|
40 |
+
b, attn_weights, past_key_value = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
|
41 |
x = x + self.resid_attn_dropout(b)
|
42 |
m = x
|
43 |
if self.norm_2 is not None:
|
44 |
m = self.norm_2(x)
|
45 |
+
batch_size, seq_len = m.size()[:2]
|
46 |
indices = None
|
47 |
if not self.use_pad_tok_in_ffn:
|
48 |
assert unpad_input is not None
|
49 |
+
m, indices, _, _ = unpad_input(m, attention_mask)
|
50 |
n = self.ffn(m)
|
51 |
if not self.use_pad_tok_in_ffn:
|
52 |
assert pad_input is not None
|
builders.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import functools
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
from collections import OrderedDict
|
7 |
+
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union
|
8 |
+
import torch
|
9 |
+
from torch.optim.optimizer import Optimizer
|
10 |
+
from torchmetrics import Metric
|
11 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
12 |
+
from .llmfoundry import registry
|
13 |
+
from .callbacks import EvalGauntlet
|
14 |
+
from .dataloader import build_dataloader
|
15 |
+
from .tiktoken import TiktokenTokenizerWrapper
|
16 |
+
from .registry_utils import construct_from_registry
|
17 |
+
log = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
def build_evaluators(eval_loader_config: Optional[Union[DictConfig, ListConfig]], icl_tasks_config: Optional[Union[str, ListConfig]], eval_gauntlet_config: Optional[Union[str, DictConfig]], *, tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
|
20 |
+
evaluators = []
|
21 |
+
if eval_loader_config is not None:
|
22 |
+
evaluators = build_eval_loaders(eval_loader_config, tokenizer, device_eval_batch_size)
|
23 |
+
logger_keys = []
|
24 |
+
eval_gauntlet_callback = None
|
25 |
+
if icl_tasks_config is not None:
|
26 |
+
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(icl_tasks_config, eval_gauntlet_config, tokenizer, device_eval_batch_size, icl_seq_len, icl_subset_num_batches)
|
27 |
+
evaluators.extend(icl_evaluators)
|
28 |
+
return (evaluators, logger_keys, eval_gauntlet_callback)
|
29 |
+
|
30 |
+
def build_eval_loaders(eval_loader_config: Union[DictConfig, ListConfig], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int) -> List[Evaluator]:
|
31 |
+
evaluators: List[Evaluator] = []
|
32 |
+
if isinstance(eval_loader_config, ListConfig):
|
33 |
+
eval_configs: ListConfig = eval_loader_config
|
34 |
+
is_multi_eval = True
|
35 |
+
else:
|
36 |
+
eval_configs = ListConfig([eval_loader_config])
|
37 |
+
is_multi_eval = False
|
38 |
+
for eval_config in eval_configs:
|
39 |
+
eval_dataloader = build_dataloader(eval_config, tokenizer, device_eval_batch_size)
|
40 |
+
eval_loader: Evaluator = Evaluator(label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, metric_names=[])
|
41 |
+
evaluators.append(eval_loader)
|
42 |
+
return evaluators
|
43 |
+
|
44 |
+
def add_metrics_to_eval_loaders(evaluators: List[Evaluator], metric_names: List[str]) -> List[Evaluator]:
|
45 |
+
eval_loaders, other_evaluators = ([], [])
|
46 |
+
for evaluator in evaluators:
|
47 |
+
if evaluator.metric_names == []:
|
48 |
+
evaluator.metric_names = metric_names
|
49 |
+
eval_loaders.append(evaluator)
|
50 |
+
else:
|
51 |
+
other_evaluators.append(evaluator)
|
52 |
+
return eval_loaders + other_evaluators
|
53 |
+
|
54 |
+
def build_icl_data_and_gauntlet(icl_tasks_config: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
|
55 |
+
icl_evaluators, logger_keys = build_icl_evaluators(icl_tasks_config, tokenizer, icl_seq_len, device_eval_batch_size, icl_subset_num_batches=icl_subset_num_batches)
|
56 |
+
eval_gauntlet_cb = None
|
57 |
+
if eval_gauntlet_config is not None:
|
58 |
+
if isinstance(eval_gauntlet_config, str):
|
59 |
+
with open(eval_gauntlet_config, 'r') as icl_f:
|
60 |
+
eval_gauntlet_cfg = om.load(icl_f)
|
61 |
+
eval_gauntlet = eval_gauntlet_cfg.eval_gauntlet
|
62 |
+
elif isinstance(eval_gauntlet_config, DictConfig):
|
63 |
+
eval_gauntlet = eval_gauntlet_config
|
64 |
+
else:
|
65 |
+
raise ValueError(f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}')
|
66 |
+
eval_gauntlet.logger_keys = logger_keys
|
67 |
+
eval_gauntlet.benchmark_sizes = {e.label: e.dataloader.num_samples for e in icl_evaluators}
|
68 |
+
eval_gauntlet_cb = EvalGauntlet(**eval_gauntlet)
|
69 |
+
return (icl_evaluators, logger_keys, eval_gauntlet_cb)
|
70 |
+
|
71 |
+
def build_composer_model(name: str, cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, init_context: Optional[ContextManager]=None, master_weights_dtype: Optional[str]=None) -> ComposerModel:
|
72 |
+
"""Builds a ComposerModel from the registry.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
name (str): Name of the model to build.
|
76 |
+
cfg (DictConfig): Configuration for the model.
|
77 |
+
tokenizer (PreTrainedTokenizerBase): Tokenizer to use.
|
78 |
+
init_context (Optional[ContextManager], optional): Context manager to use for initialization. Defaults to None.
|
79 |
+
master_weights_dtype (Optional[str], optional): Master weights dtype. Defaults to None.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
ComposerModel: _description_
|
83 |
+
"""
|
84 |
+
if init_context is None:
|
85 |
+
init_context = contextlib.nullcontext()
|
86 |
+
with init_context:
|
87 |
+
model = construct_from_registry(name=name, registry=registry.models, pre_validation_function=ComposerModel, post_validation_function=None, kwargs={'om_model_config': cfg, 'tokenizer': tokenizer})
|
88 |
+
str_dtype_to_torch_dtype = {'f16': torch.float16, 'float16': torch.float16, 'bf16': torch.bfloat16, 'bfloat16': torch.bfloat16}
|
89 |
+
if master_weights_dtype is not None:
|
90 |
+
if master_weights_dtype not in str_dtype_to_torch_dtype:
|
91 |
+
raise ValueError(f'Invalid master_weights_dtype: {master_weights_dtype}. ' + f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.')
|
92 |
+
dtype = str_dtype_to_torch_dtype[master_weights_dtype]
|
93 |
+
model = model.to(dtype=dtype)
|
94 |
+
return model
|
95 |
+
|
96 |
+
def build_callback(name: str, kwargs: Optional[Dict[str, Any]]=None, config: Any=None) -> Callback:
|
97 |
+
"""Builds a callback from the registry."""
|
98 |
+
registry_to_use = registry.callbacks
|
99 |
+
if name in registry.callbacks_with_config:
|
100 |
+
if kwargs is None:
|
101 |
+
kwargs = {}
|
102 |
+
if 'config' in kwargs:
|
103 |
+
raise ValueError(f'`config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.')
|
104 |
+
kwargs['config'] = config
|
105 |
+
registry_to_use = registry.callbacks_with_config
|
106 |
+
return construct_from_registry(name=name, registry=registry_to_use, partial_function=True, pre_validation_function=Callback, post_validation_function=None, kwargs=kwargs)
|
107 |
+
|
108 |
+
def build_logger(name: str, kwargs: Optional[Dict[str, Any]]=None) -> LoggerDestination:
|
109 |
+
"""Builds a logger from the registry."""
|
110 |
+
return construct_from_registry(name=name, registry=registry.loggers, partial_function=True, pre_validation_function=LoggerDestination, post_validation_function=None, kwargs=kwargs)
|
111 |
+
|
112 |
+
def build_algorithm(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Algorithm:
|
113 |
+
"""Builds an algorithm from the registry."""
|
114 |
+
return construct_from_registry(name=name, registry=registry.algorithms, partial_function=True, pre_validation_function=Algorithm, post_validation_function=None, kwargs=kwargs)
|
115 |
+
|
116 |
+
def build_metric(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Metric:
|
117 |
+
"""Builds a metric from the registry."""
|
118 |
+
return construct_from_registry(name=name, registry=registry.metrics, partial_function=True, pre_validation_function=Metric, post_validation_function=None, kwargs=kwargs)
|
119 |
+
|
120 |
+
def _extract_param_groups(model: torch.nn.Module, optimizer_config: Optional[Dict[str, Any]]=None) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
|
121 |
+
"""Extracts parameter groups defined in the optimizer config.
|
122 |
+
|
123 |
+
The optimizer_config defines the optimizer args. It can additionally have key
|
124 |
+
`disable_grad` which is a string or list of strings. If a string matches a
|
125 |
+
parameter name, then that parameter will have `requires_grad=False`. This is
|
126 |
+
useful for freezing parameters. It can additionally have a key
|
127 |
+
`param_groups` which is a list of dicts. In this dict, key `param_str_match`
|
128 |
+
defines a string; if a parameter name contains this string, then it will be
|
129 |
+
in this parameter group. This is useful for grouping parameters together.
|
130 |
+
The dict can also contain any other key that is a valid optimizer arg.
|
131 |
+
Note: to handle name overlap conflicts, params are assigned to parameter
|
132 |
+
groups and added to `param_groups` in the order that `param_str_match` appear
|
133 |
+
in `param_groups`.
|
134 |
+
|
135 |
+
Usage
|
136 |
+
To disable gradient for all parameters that contain the string "norm" or "bias":
|
137 |
+
```
|
138 |
+
optimizer_config: {
|
139 |
+
"name": "decoupled_lionw",
|
140 |
+
"lr": 1e-3,
|
141 |
+
"weight_decay": 1e-2,
|
142 |
+
"betas": [0.9, 0.999],
|
143 |
+
"eps": 1e-8,
|
144 |
+
"disable_grad": ["norm", "bias"]
|
145 |
+
}
|
146 |
+
```
|
147 |
+
|
148 |
+
To create and modify the optimizer parameters for all parameters that contain
|
149 |
+
the string "norm" and "bias" separately:
|
150 |
+
```
|
151 |
+
optimizer_config: {
|
152 |
+
"name": "decoupled_lionw",
|
153 |
+
"lr": 1e-3,
|
154 |
+
"weight_decay": 1e-2,
|
155 |
+
"betas": [0.9, 0.999],
|
156 |
+
"eps": 1e-8,
|
157 |
+
"param_groups": [
|
158 |
+
{
|
159 |
+
"param_str_match": "norm",
|
160 |
+
"lr": 1e-4,
|
161 |
+
"weight_decay": 0.0,
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"param_str_match": "bias",
|
165 |
+
"lr": 5e-4,
|
166 |
+
"weight_decay": 0.0,
|
167 |
+
},
|
168 |
+
],
|
169 |
+
}
|
170 |
+
```
|
171 |
+
|
172 |
+
Args:
|
173 |
+
model (torch.nn.Module): model to extract parameters from
|
174 |
+
optimizer_config (Dict[str, Any]): optimizer config
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
|
178 |
+
torch.Tensor's or dict's. Specifies what Tensors should be optimized
|
179 |
+
and their param groupings.
|
180 |
+
"""
|
181 |
+
if optimizer_config is None:
|
182 |
+
return model.parameters()
|
183 |
+
if 'disable_grad' in optimizer_config.keys():
|
184 |
+
str_matches = optimizer_config.pop('disable_grad')
|
185 |
+
if isinstance(str_matches, str):
|
186 |
+
str_matches = [str_matches]
|
187 |
+
for str_match in str_matches:
|
188 |
+
for n, p in model.named_parameters():
|
189 |
+
if re.search(str_match, n):
|
190 |
+
p.requires_grad = False
|
191 |
+
log.debug(f'Setting `{n}.requires_grad = False`.')
|
192 |
+
param_groups_config = optimizer_config.pop('param_groups', None)
|
193 |
+
if param_groups_config is not None:
|
194 |
+
params = []
|
195 |
+
param_dict = OrderedDict(((n, p) for n, p in model.named_parameters()))
|
196 |
+
log.debug(f'Default optimizer settings: {optimizer_config}.')
|
197 |
+
for param_group_config in param_groups_config:
|
198 |
+
str_match = param_group_config.pop('param_str_match')
|
199 |
+
filter_fn = functools.partial(re.search, str_match)
|
200 |
+
param_names = [n for n in param_dict.keys() if filter_fn(n)]
|
201 |
+
group_params = {'params': [param_dict.pop(n) for n in param_names]}
|
202 |
+
group_params.update(param_group_config)
|
203 |
+
log.debug(f'Creating optimizer param_group with parameters: {param_names} ' + f'(extracted using str_match={str_match!r}). The param_group optimizer ' + f'setting overrides are: {param_group_config}.')
|
204 |
+
params.append(group_params)
|
205 |
+
params.insert(0, {'params': param_dict.values()})
|
206 |
+
return params
|
207 |
+
return model.parameters()
|
208 |
+
|
209 |
+
def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Optional[Dict[str, Any]]=None) -> Optimizer:
|
210 |
+
params = _extract_param_groups(model, optimizer_config)
|
211 |
+
kwargs = optimizer_config
|
212 |
+
if kwargs is None:
|
213 |
+
kwargs = {}
|
214 |
+
if 'params' in kwargs:
|
215 |
+
raise ValueError('The `params` will be automatically extracted from the model and ' + 'optimizer config. Please remove it from the optimizer config kwargs.')
|
216 |
+
kwargs['params'] = params
|
217 |
+
return construct_from_registry(name=name, registry=registry.optimizers, partial_function=True, pre_validation_function=Optimizer, post_validation_function=None, kwargs=kwargs)
|
218 |
+
|
219 |
+
def build_scheduler(name: str, scheduler_config: Optional[Dict[str, Any]]=None) -> ComposerScheduler:
|
220 |
+
return construct_from_registry(name=name, registry=registry.schedulers, partial_function=True, pre_validation_function=ComposerScheduler, post_validation_function=None, kwargs=scheduler_config)
|
221 |
+
|
222 |
+
def build_tokenizer(tokenizer_name: str, tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase:
|
223 |
+
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
|
224 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
225 |
+
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
|
226 |
+
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
|
227 |
+
with dist.local_rank_zero_download_and_wait(signal_file_path):
|
228 |
+
pass
|
229 |
+
if tokenizer_name.startswith('tiktoken'):
|
230 |
+
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
|
231 |
+
else:
|
232 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
|
233 |
+
tokenizer.model_max_length = tokenizer_kwargs.get('model_max_length', int(1e+30))
|
234 |
+
if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
|
235 |
+
raise ValueError(f'The tokenizer {tokenizer_name} must have an eos_token.')
|
236 |
+
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
|
237 |
+
if dist.get_local_rank() == 0:
|
238 |
+
with open(signal_file_path, 'wb') as f:
|
239 |
+
f.write(b'local_rank0_completed_tokenizer_setup')
|
240 |
+
dist.barrier()
|
241 |
+
if dist.get_local_rank() == 0:
|
242 |
+
os.remove(signal_file_path)
|
243 |
+
return tokenizer
|
244 |
+
|
245 |
+
def build_icl_evaluators(icl_tasks: Union[str, ListConfig], tokenizer: PreTrainedTokenizerBase, default_max_seq_len: int, default_batch_size: int, destination_dir: Optional[str]=None, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str]]:
|
246 |
+
if destination_dir is None:
|
247 |
+
destination_dir = os.getcwd()
|
248 |
+
evaluators = []
|
249 |
+
logger_keys = []
|
250 |
+
icl_tasks_list = None
|
251 |
+
if isinstance(icl_tasks, str):
|
252 |
+
log.info(f'Extracting ICL task config from path: {icl_tasks}')
|
253 |
+
with open(icl_tasks, 'r') as icl_f:
|
254 |
+
icl_task_cfg = om.load(icl_f)
|
255 |
+
icl_tasks_list = icl_task_cfg.icl_tasks
|
256 |
+
else:
|
257 |
+
icl_tasks_list = icl_tasks
|
258 |
+
|
259 |
+
def _validate_cfg(icl_cfg: DictConfig):
|
260 |
+
assert 'label' in icl_cfg
|
261 |
+
assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None
|
262 |
+
assert 'icl_task_type' in icl_cfg
|
263 |
+
assert 'num_fewshot' in icl_cfg
|
264 |
+
if 'metric_names' not in icl_cfg:
|
265 |
+
if icl_cfg.icl_task_type == 'language_modeling':
|
266 |
+
icl_cfg.metric_names = ['InContextLearningLMAccuracy']
|
267 |
+
elif icl_cfg.icl_task_type == 'multiple_choice':
|
268 |
+
icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
|
269 |
+
elif icl_cfg.icl_task_type == 'schema':
|
270 |
+
icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
|
271 |
+
elif icl_cfg.icl_task_type == 'question_answering':
|
272 |
+
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
|
273 |
+
elif icl_cfg.icl_task_type == 'code_evaluation':
|
274 |
+
icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
|
275 |
+
else:
|
276 |
+
raise ValueError(f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.')
|
277 |
+
if 'prompt_string' not in icl_cfg:
|
278 |
+
icl_cfg.prompt_string = ''
|
279 |
+
if 'example_delimiter' not in icl_cfg:
|
280 |
+
icl_cfg.example_delimiter = '\n'
|
281 |
+
if 'continuation_delimiter' not in icl_cfg:
|
282 |
+
icl_cfg.continuation_delimiter = ' '
|
283 |
+
if 'max_seq_len' not in icl_cfg:
|
284 |
+
icl_cfg.max_seq_len = default_max_seq_len
|
285 |
+
if 'batch_size' not in icl_cfg:
|
286 |
+
icl_cfg.batch_size = default_batch_size
|
287 |
+
if 'pass_at_k' not in icl_cfg:
|
288 |
+
icl_cfg.pass_at_k = 1
|
289 |
+
if 'fewshot_random_seed' not in icl_cfg:
|
290 |
+
icl_cfg.fewshot_random_seed = 1234
|
291 |
+
if 'generations_per_sample' not in icl_cfg:
|
292 |
+
icl_cfg.generations_per_sample = 1
|
293 |
+
if 'num_beams' in icl_cfg:
|
294 |
+
raise ValueError('num_beams is no longer supported as a top level icl_task parameter.' + 'Please use generation_kwargs.num_beams instead.')
|
295 |
+
for icl_cfg in icl_tasks_list:
|
296 |
+
assert isinstance(icl_cfg, DictConfig)
|
297 |
+
_validate_cfg(icl_cfg)
|
298 |
+
for num_fewshot in list(icl_cfg.num_fewshot):
|
299 |
+
if tokenizer.pad_token_id is None:
|
300 |
+
pad_tok_id = tokenizer.eos_token_id
|
301 |
+
else:
|
302 |
+
pad_tok_id = tokenizer.pad_token_id
|
303 |
+
label = f'{icl_cfg.label}/{num_fewshot}-shot'
|
304 |
+
metric_names = list(icl_cfg.metric_names)
|
305 |
+
destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
|
306 |
+
if dist.get_local_rank() == 0 and os.path.exists(destination_path):
|
307 |
+
os.remove(destination_path)
|
308 |
+
dist.barrier()
|
309 |
+
hf_parsing_map = icl_cfg.get('hf_parsing_map', {})
|
310 |
+
hf_loading_vars = icl_cfg.get('hf_loading_vars', {})
|
311 |
+
early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None)
|
312 |
+
if isinstance(early_stopping_criteria, ListConfig):
|
313 |
+
early_stopping_criteria = om.to_container(early_stopping_criteria)
|
314 |
+
assert early_stopping_criteria is None or isinstance(early_stopping_criteria, list)
|
315 |
+
dataloaders = get_icl_task_dataloader(icl_cfg.icl_task_type, icl_cfg.dataset_uri, tokenizer, batch_size=icl_cfg.batch_size, max_seq_len=icl_cfg.max_seq_len, pad_tok_id=pad_tok_id, num_fewshot=num_fewshot, prompt_string=icl_cfg.prompt_string, example_delimiter=icl_cfg.example_delimiter, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, continuation_delimiter=icl_cfg.continuation_delimiter, question_prelimiter=icl_cfg.get('question_prelimiter', ''), destination_path=destination_path, fewshot_random_seed=icl_cfg.fewshot_random_seed, pass_at_k=icl_cfg.pass_at_k, generations_per_sample=icl_cfg.generations_per_sample, has_categories=icl_cfg.get('has_categories', False), cot_delimiter=icl_cfg.get('cot_delimiter', ''), generation_kwargs=icl_cfg.get('generation_kwargs', {}), early_stopping_criteria=early_stopping_criteria, do_normalization=icl_cfg.get('do_normalization', True))
|
316 |
+
if hasattr(icl_cfg, 'has_categories') and icl_cfg.has_categories and isinstance(dataloaders, dict):
|
317 |
+
for category in dataloaders.keys():
|
318 |
+
logger_keys.extend([f'metrics/{label}/{category}/{m}' for m in metric_names])
|
319 |
+
evaluators.append(Evaluator(label=f'{label}/{category}', dataloader=dataloaders[category], metric_names=metric_names))
|
320 |
+
else:
|
321 |
+
logger_keys.extend([f'metrics/{label}/{m}' for m in metric_names])
|
322 |
+
evaluators.append(Evaluator(label=label, dataloader=dataloaders, metric_names=metric_names, subset_num_batches=icl_subset_num_batches))
|
323 |
+
return (evaluators, logger_keys)
|
callback_with_config.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
class CallbackWithConfig(Callback, abc.ABC):
|
5 |
+
"""A callback that takes a config dictionary as an argument, in addition to.
|
6 |
+
|
7 |
+
its other kwargs.
|
8 |
+
"""
|
9 |
+
|
10 |
+
def __init__(self, config: dict[str, Any], *args: Any, **kwargs: Any) -> None:
|
11 |
+
del config, args, kwargs
|
12 |
+
pass
|
callbacks.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .async_eval_callback import AsyncEval
|
2 |
+
from .curriculum_learning_callback import CurriculumLearning
|
3 |
+
from .eval_gauntlet_callback import EvalGauntlet
|
4 |
+
from .fdiff_callback import FDiffMetrics
|
5 |
+
from .hf_checkpointer import HuggingFaceCheckpointer
|
6 |
+
from .monolithic_ckpt_callback import MonolithicCheckpointSaver
|
7 |
+
from .resumption_callbacks import GlobalLRScaling, LayerFreezing
|
8 |
+
from .scheduled_gc_callback import ScheduledGarbageCollector
|
9 |
+
from .registry import callbacks, callbacks_with_config
|
10 |
+
callbacks.register('lr_monitor', func=LRMonitor)
|
11 |
+
callbacks.register('memory_monitor', func=MemoryMonitor)
|
12 |
+
callbacks.register('memory_snapshot', func=MemorySnapshot)
|
13 |
+
callbacks.register('speed_monitor', func=SpeedMonitor)
|
14 |
+
callbacks.register('runtime_estimator', func=RuntimeEstimator)
|
15 |
+
callbacks.register('optimizer_monitor', func=OptimizerMonitor)
|
16 |
+
callbacks.register('generate_callback', func=Generate)
|
17 |
+
callbacks.register('early_stopper', func=EarlyStopper)
|
18 |
+
callbacks.register('fdiff_metrics', func=FDiffMetrics)
|
19 |
+
callbacks.register('hf_checkpointer', func=HuggingFaceCheckpointer)
|
20 |
+
callbacks.register('global_lr_scaling', func=GlobalLRScaling)
|
21 |
+
callbacks.register('layer_freezing', func=LayerFreezing)
|
22 |
+
callbacks.register('mono_checkpoint_saver', func=MonolithicCheckpointSaver)
|
23 |
+
callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
|
24 |
+
callbacks.register('oom_observer', func=OOMObserver)
|
25 |
+
callbacks_with_config.register('async_eval', func=AsyncEval)
|
26 |
+
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
|
checkpoint_conversion_helpers.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Helper methods for the checkpoint conversion scripts.
|
2 |
+
|
3 |
+
The checkpoint conversion scripts are located in the
|
4 |
+
llmfoundry/scripts/inference/benchmarking/ folder. Users should run those
|
5 |
+
scripts directly to convert between checkpoints; this file contains only common
|
6 |
+
utility functions that are present in multiple scripts.
|
7 |
+
"""
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
import string
|
13 |
+
from pathlib import Path
|
14 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
15 |
+
import numpy as np
|
16 |
+
import sentencepiece as spm
|
17 |
+
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
18 |
+
log = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
def _get_weight_data_type(data_type: str):
|
21 |
+
if data_type == 'fp32':
|
22 |
+
return np.float32
|
23 |
+
elif data_type == 'fp16':
|
24 |
+
return np.float16
|
25 |
+
else:
|
26 |
+
raise RuntimeError('Unsupported data type: {data_type} for conversion.')
|
27 |
+
|
28 |
+
def get_hf_tokenizer_from_composer_state_dict(state_dict: Dict[str, Any], trust_remote_code: bool, tokenizer_save_dir: Optional[str]=None) -> Optional[PreTrainedTokenizer]:
|
29 |
+
if 'state' not in state_dict:
|
30 |
+
raise RuntimeError('Unexpected composer state dictionary. Did you pass in a full composer checkpoint?')
|
31 |
+
if 'integrations' not in state_dict['state'] or 'huggingface' not in state_dict['state']['integrations']:
|
32 |
+
raise RuntimeError('Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!')
|
33 |
+
hf_tokenizer_state = state_dict['state']['integrations']['huggingface']['tokenizer']
|
34 |
+
hf_tokenizer = None
|
35 |
+
if hf_tokenizer_state != {}:
|
36 |
+
if tokenizer_save_dir is None:
|
37 |
+
unique_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=6))
|
38 |
+
tokenizer_save_dir = os.path.join(os.getcwd(), f'tokenizer-save-dir-{unique_suffix}')
|
39 |
+
os.makedirs(tokenizer_save_dir, exist_ok=True)
|
40 |
+
for filename, saved_content in hf_tokenizer_state.items():
|
41 |
+
if filename.endswith(saved_content['file_extension']):
|
42 |
+
tokenizer_file_name = filename
|
43 |
+
else:
|
44 |
+
tokenizer_file_name = filename + saved_content['file_extension']
|
45 |
+
tokenizer_file_path = Path(tokenizer_save_dir) / tokenizer_file_name
|
46 |
+
if saved_content['file_extension'] == '.json':
|
47 |
+
with open(tokenizer_file_path, 'w') as _tmp_file:
|
48 |
+
json.dump(saved_content['content'], _tmp_file)
|
49 |
+
elif saved_content['file_extension'] == '.txt':
|
50 |
+
with open(tokenizer_file_path, 'w') as _tmp_file:
|
51 |
+
for line in saved_content['content']:
|
52 |
+
_tmp_file.write(line)
|
53 |
+
_tmp_file.write('\n')
|
54 |
+
elif saved_content['file_extension'] == '.py':
|
55 |
+
with open(tokenizer_file_path, 'w') as _tmp_file:
|
56 |
+
_tmp_file.write(saved_content['content'])
|
57 |
+
elif saved_content['file_extension'] == '.model':
|
58 |
+
s = spm.SentencePieceProcessor()
|
59 |
+
s.load_from_serialized_proto(saved_content['content'])
|
60 |
+
with open(tokenizer_file_path, 'wb') as _tmp_file:
|
61 |
+
_tmp_file.write(s.serialized_model_proto())
|
62 |
+
hf_tokenizer = load_tokenizer(tokenizer_save_dir, trust_remote_code=trust_remote_code)
|
63 |
+
hf_tokenizer.name_or_path = ''
|
64 |
+
hf_tokenizer.init_kwargs['name_or_path'] = ''
|
65 |
+
return hf_tokenizer
|
66 |
+
|
67 |
+
def load_tokenizer(tokenizer_save_dir: str, trust_remote_code: bool) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
68 |
+
try:
|
69 |
+
return AutoTokenizer.from_pretrained(tokenizer_save_dir, trust_remote_code=trust_remote_code)
|
70 |
+
except ValueError as e:
|
71 |
+
raise ValueError(f'Got error while loading tokenizer with trust_remote_code={trust_remote_code}: {e}. ' + 'If accessing a tokenizer defined outside of the transformers module,' + ' please use --trust_remote_code.')
|
72 |
+
|
73 |
+
def _write_zero_bias(weight_name: str, weight_file_path: str, bias_shape: Union[Tuple[int, ...], int], np_data_type: np.dtype) -> None:
|
74 |
+
"""Write zeros for bias when converting MPT to FasterTransformer weights.
|
75 |
+
|
76 |
+
MPT model might not have bias while FT expects bias.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
weight_name (str): Name of the weight tensor.
|
80 |
+
weight_file_path (str): Output path for storing the weight (NOT zero bias).
|
81 |
+
bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array.
|
82 |
+
np_data_type (np.dtype): The data type for bias.
|
83 |
+
"""
|
84 |
+
if 'weight' not in weight_file_path:
|
85 |
+
raise RuntimeError(f'Cannot write zero bias for {weight_name}. Input is not a weight tensor')
|
86 |
+
log.debug(f'zero bias for weight: {weight_name}')
|
87 |
+
bias_file_path = weight_file_path.replace('.weight', '.bias')
|
88 |
+
bias = np.zeros(bias_shape, dtype=np_data_type)
|
89 |
+
bias.tofile(bias_file_path)
|
90 |
+
|
91 |
+
def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, tensor_name: str, config: Dict[str, Any], data: np.ndarray, np_weight_data_type: np.dtype) -> None:
|
92 |
+
"""Convert each MPT weight to a FasterTransformer compatible format.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
save_dir (str): Path of the directory to save the weight in FT format. The directory must already exist.
|
96 |
+
infer_gpu_num (int): The number of gpus you are planning to use for inference.
|
97 |
+
tensor_name (str): Name of the weight tensor. Used in naming the output file.
|
98 |
+
config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters.
|
99 |
+
data (np.ndarray): Tensor data in np.ndarray format.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
None: Writes to a file in `save_dir`. File name is based on the `tensor_name`
|
103 |
+
"""
|
104 |
+
if tensor_name.find('input_layernorm.weight') != -1 or tensor_name.find('input_layernorm.bias') != -1 or tensor_name.find('attention.dense.bias') != -1 or (tensor_name.find('post_attention_layernorm.weight') != -1) or (tensor_name.find('post_attention_layernorm.bias') != -1) or (tensor_name.find('mlp.dense_4h_to_h.bias') != -1) or (tensor_name.find('final_layernorm.weight') != -1) or (tensor_name.find('final_layernorm.bias') != -1):
|
105 |
+
save_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
|
106 |
+
data.tofile(save_path)
|
107 |
+
if 'weight' in tensor_name and config['no_bias']:
|
108 |
+
_write_zero_bias(tensor_name, save_path, data.shape[-1], np_weight_data_type)
|
109 |
+
elif tensor_name.find('attention.dense.weight') != -1:
|
110 |
+
assert data.shape == (config['d_model'], config['d_model']), f'unexpected dim for {tensor_name}'
|
111 |
+
data = data.T
|
112 |
+
split_vals = np.split(data, infer_gpu_num, axis=0)
|
113 |
+
for j in range(infer_gpu_num):
|
114 |
+
save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
|
115 |
+
split_vals[j].tofile(save_path)
|
116 |
+
if config['no_bias']:
|
117 |
+
fake_weight_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
|
118 |
+
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], np_weight_data_type)
|
119 |
+
elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1:
|
120 |
+
assert data.shape == (config['d_model'], config['expansion_ratio'] * config['d_model']), f'unexpected dim for {tensor_name}'
|
121 |
+
data = data.T
|
122 |
+
split_vals = np.split(data, infer_gpu_num, axis=0)
|
123 |
+
for j in range(infer_gpu_num):
|
124 |
+
save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
|
125 |
+
split_vals[j].tofile(save_path)
|
126 |
+
if config['no_bias']:
|
127 |
+
fake_weight_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
|
128 |
+
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1], np_weight_data_type)
|
129 |
+
elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1:
|
130 |
+
assert data.shape == (config['expansion_ratio'] * config['d_model'], config['d_model']), f'unexpected dim for {tensor_name}'
|
131 |
+
data = data.T
|
132 |
+
split_vals = np.split(data, infer_gpu_num, axis=-1)
|
133 |
+
for j in range(infer_gpu_num):
|
134 |
+
save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
|
135 |
+
split_vals[j].tofile(save_path)
|
136 |
+
if config['no_bias']:
|
137 |
+
_write_zero_bias(tensor_name, save_path, split_vals[j].shape[-1], np_weight_data_type)
|
138 |
+
elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1:
|
139 |
+
assert data.shape == (config['expansion_ratio'] * config['d_model'],), f'unexpected dim for {tensor_name}'
|
140 |
+
split_vals = np.split(data, infer_gpu_num, axis=-1)
|
141 |
+
for j in range(infer_gpu_num):
|
142 |
+
save_path = os.path.join(save_dir + f'model.{tensor_name}.{j}.bin')
|
143 |
+
split_vals[j].tofile(save_path)
|
144 |
+
elif tensor_name.find('attention.query_key_value.bias') != -1:
|
145 |
+
assert data.shape == (3 * config['d_model'],), f'unexpected dim for {tensor_name}'
|
146 |
+
data = data.reshape(3, config['d_model'])
|
147 |
+
split_vals = np.split(data, infer_gpu_num, axis=-1)
|
148 |
+
for j in range(infer_gpu_num):
|
149 |
+
save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
|
150 |
+
split_vals[j].tofile(save_path)
|
151 |
+
elif tensor_name.find('attention.query_key_value.weight') != -1:
|
152 |
+
assert data.shape == (3 * config['d_model'], config['d_model']), f'unexpected dim for {tensor_name}'
|
153 |
+
data = data.T
|
154 |
+
data = data.reshape(config['d_model'], 3, config['d_model'])
|
155 |
+
split_vals = np.split(data, infer_gpu_num, axis=-1)
|
156 |
+
for j in range(infer_gpu_num):
|
157 |
+
save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
|
158 |
+
split_vals[j].tofile(save_path)
|
159 |
+
if config['no_bias']:
|
160 |
+
_write_zero_bias(tensor_name, save_path, (3, split_vals[j].shape[-1]), np_weight_data_type)
|
161 |
+
else:
|
162 |
+
raise RuntimeError(f'Tensor with name {tensor_name} is not handled')
|
163 |
+
|
164 |
+
def convert_and_save_ft_weights(named_params: dict, config: dict, infer_gpu_num: int=1, weight_data_type: str='fp32', save_dir: str='') -> None:
|
165 |
+
"""Convert a Composer MPT checkpoint to a FasterTransformer format.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
named_params (Dict[str, Parameter]): A dictionary containing the Composer MPT model's parameter names and data.
|
169 |
+
config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters.
|
170 |
+
infer_gpu_num (int): The number of gpus you are planning to use for inference.
|
171 |
+
weight_data_type (str): The dtype of the converted FasterTransformer model.
|
172 |
+
save_dir (str): Path of the directory to save the weight in FT format. The directory must already exist.
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
None: Writes to the `save_dir` folder. File names within this folder are based on the model parameter names.
|
176 |
+
"""
|
177 |
+
np_weight_data_type = _get_weight_data_type(weight_data_type)
|
178 |
+
param_remapping = {'norm_1.bias': 'input_layernorm.bias', 'norm_1.weight': 'input_layernorm.weight', 'attn.Wqkv.bias': 'attention.query_key_value.bias', 'attn.Wqkv.weight': 'attention.query_key_value.weight', 'attn.out_proj.bias': 'attention.dense.bias', 'attn.out_proj.weight': 'attention.dense.weight', 'norm_2.bias': 'post_attention_layernorm.bias', 'norm_2.weight': 'post_attention_layernorm.weight', 'ffn.up_proj.bias': 'mlp.dense_h_to_4h.bias', 'ffn.up_proj.weight': 'mlp.dense_h_to_4h.weight', 'ffn.down_proj.bias': 'mlp.dense_4h_to_h.bias', 'ffn.down_proj.weight': 'mlp.dense_4h_to_h.weight'}
|
179 |
+
for name, param in named_params.items():
|
180 |
+
log.debug(f'Working on parameter {name} ...')
|
181 |
+
data = param.detach().cpu().numpy().astype(np_weight_data_type)
|
182 |
+
if name.find('weight') == -1 and name.find('bias') == -1:
|
183 |
+
log.debug(f'found a parameter name that is not handled: {name}')
|
184 |
+
continue
|
185 |
+
if name == 'transformer.wpe.weight':
|
186 |
+
assert data.shape == (config['max_seq_len'], config['d_model']), f'unexpected dim for {name}'
|
187 |
+
data.tofile(os.path.join(save_dir, 'model.wpe.bin'))
|
188 |
+
elif name == 'transformer.wte.weight':
|
189 |
+
assert data.shape == (config['vocab_size'], config['d_model']), f'unexpected dim for {name}'
|
190 |
+
data.tofile(os.path.join(save_dir, 'model.wte.bin'))
|
191 |
+
elif name == 'transformer.norm_f.bias':
|
192 |
+
assert data.shape == (config['d_model'],), f'unexpected dim for {name}'
|
193 |
+
data.tofile(os.path.join(save_dir, 'model.final_layernorm.bias.bin'))
|
194 |
+
elif name == 'transformer.norm_f.weight':
|
195 |
+
assert data.shape == (config['d_model'],), f'unexpected dim for {name}'
|
196 |
+
save_path = os.path.join(save_dir, 'model.final_layernorm.weight.bin')
|
197 |
+
data.tofile(save_path)
|
198 |
+
if config['no_bias']:
|
199 |
+
_write_zero_bias(name, save_path, data.shape[-1], np_weight_data_type)
|
200 |
+
elif name == 'transformer.lm_head.weight':
|
201 |
+
data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin'))
|
202 |
+
else:
|
203 |
+
for mpt_pattern, ft_pattern in param_remapping.items():
|
204 |
+
if name.find(mpt_pattern) != -1:
|
205 |
+
new_name = name.replace('transformer.blocks.', 'layers.').replace(mpt_pattern, ft_pattern)
|
206 |
+
_convert_weight_to_ft_each(save_dir, infer_gpu_num, new_name, config, data, np_weight_data_type)
|
collator.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import warnings
|
3 |
+
from typing import Any, Dict, List, Optional, Union
|
4 |
+
import torch
|
5 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
6 |
+
log = logging.getLogger(__name__)
|
7 |
+
_HF_IGNORE_INDEX = -100
|
8 |
+
TokenizedExample = Dict[str, List[Dict[str, List[int]]]]
|
9 |
+
|
10 |
+
def ensure_list(x: Union[List, torch.Tensor]) -> List:
|
11 |
+
if isinstance(x, torch.Tensor):
|
12 |
+
x = list(x.flatten())
|
13 |
+
assert isinstance(x, list)
|
14 |
+
return x
|
15 |
+
|
16 |
+
def validate_target_settings(target_prompts: str, target_responses: str, decoder_only_format: bool):
|
17 |
+
"""Raises an error if target settings are invalid."""
|
18 |
+
if not decoder_only_format and (target_prompts != 'none' or target_responses != 'last'):
|
19 |
+
raise ValueError(f'When using encoder_decoder format, you must use target_prompts="none" and target_responses="last".')
|
20 |
+
if target_responses not in {'all', 'last'}:
|
21 |
+
raise ValueError(f'target_responses must be either "last" or "all" but target_responses={target_responses!r}')
|
22 |
+
if target_prompts.startswith('length>='):
|
23 |
+
cutoff = target_prompts[8:]
|
24 |
+
if not cutoff.isdigit():
|
25 |
+
raise ValueError(f'target_prompts starts with "length>=" but the rest of the string is not digits (target_prompts={target_prompts!r}). ' + 'To use this configuration option, set target_prompts "length>=XX" where "XX" is a positive integer indicating ' + 'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.')
|
26 |
+
cutoff = int(cutoff)
|
27 |
+
if cutoff <= 0:
|
28 |
+
raise ValueError(f'You are trying to set the target_prompts length cutoff to a negative number cutoff={cutoff!r}. This is not allowed.')
|
29 |
+
elif target_prompts not in {'all', 'none'}:
|
30 |
+
raise ValueError(f'target_prompts must either be "all", "none" or "length>=XX" where "XX" is a positive integer, but target_prompts={target_prompts!r}')
|
31 |
+
|
32 |
+
def _sequence_to_labels_all(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
|
33 |
+
del is_last_turn, cutoff
|
34 |
+
return sequence
|
35 |
+
|
36 |
+
def _sequence_to_labels_none(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
|
37 |
+
del is_last_turn, cutoff
|
38 |
+
return [_HF_IGNORE_INDEX] * len(sequence)
|
39 |
+
|
40 |
+
def _sequence_to_labels_last(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
|
41 |
+
del cutoff
|
42 |
+
if is_last_turn:
|
43 |
+
return sequence
|
44 |
+
else:
|
45 |
+
return [_HF_IGNORE_INDEX] * len(sequence)
|
46 |
+
|
47 |
+
def _sequence_to_labels_cutoff(sequence: list[int], is_last_turn: bool, cutoff: Optional[int]=None) -> list[int]:
|
48 |
+
del is_last_turn
|
49 |
+
if cutoff is None:
|
50 |
+
raise ValueError('input ``cutoff`` must be provided')
|
51 |
+
if len(sequence) >= cutoff:
|
52 |
+
return sequence
|
53 |
+
else:
|
54 |
+
return [_HF_IGNORE_INDEX] * len(sequence)
|
55 |
+
_TARGET_POLICY_LOOKUP = {'all': _sequence_to_labels_all, 'none': _sequence_to_labels_none, 'last': _sequence_to_labels_last, 'length': _sequence_to_labels_cutoff}
|
56 |
+
|
57 |
+
def stitch_turns_decoder_only(example_turns: list[dict[str, list[int]]], target_prompts: str, target_responses: str, eos_token_id: Optional[int]=None, validate: bool=False) -> tuple[list[int], list[int]]:
|
58 |
+
target_prompts = target_prompts.lower()
|
59 |
+
target_responses = target_responses.lower()
|
60 |
+
if validate:
|
61 |
+
validate_target_settings(target_prompts, target_responses, decoder_only_format=True)
|
62 |
+
if target_prompts.startswith('length'):
|
63 |
+
prompt_cutoff = int(target_prompts.split('>=')[-1])
|
64 |
+
prompt_to_target = _TARGET_POLICY_LOOKUP['length']
|
65 |
+
else:
|
66 |
+
prompt_cutoff = None
|
67 |
+
prompt_to_target = _TARGET_POLICY_LOOKUP[target_prompts]
|
68 |
+
response_to_target = _TARGET_POLICY_LOOKUP[target_responses]
|
69 |
+
input_ids = []
|
70 |
+
labels = []
|
71 |
+
for idx, turn in enumerate(example_turns):
|
72 |
+
is_last_turn = idx + 1 == len(example_turns)
|
73 |
+
context = ensure_list(turn['input_ids'])
|
74 |
+
target = ensure_list(turn['labels'])
|
75 |
+
if is_last_turn and eos_token_id is not None:
|
76 |
+
if target[-1] != eos_token_id:
|
77 |
+
target = target + [eos_token_id]
|
78 |
+
input_ids += context
|
79 |
+
input_ids += target
|
80 |
+
labels += prompt_to_target(context, is_last_turn, prompt_cutoff)
|
81 |
+
labels += response_to_target(target, is_last_turn)
|
82 |
+
if len(input_ids) != len(labels):
|
83 |
+
raise ValueError(f'input_ids and labels should be the same length, len(input_ids)={len(input_ids)!r}, len(labels)={len(labels)!r}')
|
84 |
+
return (input_ids, labels)
|
85 |
+
|
86 |
+
def stitch_turns_encoder_decoder(example_turns: list[dict[str, list[int]]], eos_token_id: Optional[int]=None) -> tuple[list[int], list[int]]:
|
87 |
+
context = []
|
88 |
+
target = None
|
89 |
+
for idx, turn in enumerate(example_turns):
|
90 |
+
is_last_turn = idx + 1 == len(example_turns)
|
91 |
+
turn_context = ensure_list(turn['input_ids'])
|
92 |
+
turn_target = ensure_list(turn['labels'])
|
93 |
+
context += turn_context
|
94 |
+
if is_last_turn:
|
95 |
+
if eos_token_id is not None and turn_target[-1] != eos_token_id:
|
96 |
+
turn_target = turn_target + [eos_token_id]
|
97 |
+
target = turn_target
|
98 |
+
else:
|
99 |
+
context += turn_target
|
100 |
+
if target is None:
|
101 |
+
raise ValueError('target is still None but should be list[int]')
|
102 |
+
return (context, target)
|
103 |
+
|
104 |
+
class Seq2SeqFinetuningCollator:
|
105 |
+
"""A general-purpose collator for sequence-to-sequence training/evaluation.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
|
109 |
+
max_seq_len (int): The maximum sequence length of the combined
|
110 |
+
context/target sequence (decoder-only format) or of each the
|
111 |
+
context sequence and target sequence (encoder-decoder format).
|
112 |
+
decoder_only_format (bool): Whether to format the batches for a
|
113 |
+
decoder-only model (if True) or an encoder-decoder model (if False).
|
114 |
+
target_responses (str): For multi-turn examples, this controls which
|
115 |
+
responses are treated as training targets (i.e. generate loss).
|
116 |
+
Options are:
|
117 |
+
"last": (Default) Only the final response is used as the training
|
118 |
+
target; non-terminal responses are only part of the context.
|
119 |
+
"all": All of the responses are used as training targets.
|
120 |
+
target_prompts (str): This controls which prompts are treated as
|
121 |
+
training targets (i.e. generate loss).
|
122 |
+
Options are:
|
123 |
+
"none": (Default) Prompts are never used as training targets.
|
124 |
+
"all": Prompts are always used as training targets.
|
125 |
+
"length>=XX": Prompt sequences are used as training targets when
|
126 |
+
they have length of at least XX tokens. For instance,
|
127 |
+
setting "length>=512" instructs the collator to use a prompt
|
128 |
+
sequence as a training target when it is at least 512 tokens long.
|
129 |
+
allow_pad_trimming (bool, optional): Whether to allow the collator
|
130 |
+
to trim padding, which may result in smaller but inconsistent batch
|
131 |
+
sizes. Default: ``False`` ensures that all sequences are max_seq_len.
|
132 |
+
batch_metadata (dict, optional): A dictionary of metadata which will be added
|
133 |
+
to the batch.
|
134 |
+
"""
|
135 |
+
|
136 |
+
def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_seq_len: int, decoder_only_format: bool, target_responses: str='last', target_prompts: str='none', allow_pad_trimming: bool=False, batch_metadata: Optional[Dict[str, Any]]=None):
|
137 |
+
self.tokenizer = tokenizer
|
138 |
+
self.max_seq_len = max_seq_len
|
139 |
+
self.decoder_only_format = decoder_only_format
|
140 |
+
self.target_responses = target_responses.lower()
|
141 |
+
self.target_prompts = target_prompts.lower()
|
142 |
+
self.batch_metadata = batch_metadata or {}
|
143 |
+
self._allow_pad_trimming = allow_pad_trimming
|
144 |
+
self._seen_first_batch = False
|
145 |
+
illegal_keys = ['input_ids', 'labels', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask']
|
146 |
+
found_keys = []
|
147 |
+
for illegal_key in illegal_keys:
|
148 |
+
if illegal_key in self.batch_metadata:
|
149 |
+
found_keys.append(illegal_key)
|
150 |
+
if found_keys:
|
151 |
+
raise ValueError(f"The following keys are in batch_metadata but are not allowed: {', '.join(found_keys)}.\n" + f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' + f"{', '.join(illegal_keys)}")
|
152 |
+
if max_seq_len % 8 != 0:
|
153 |
+
log.warning('For performance, a max_seq_len as a multiple of 8 is recommended.')
|
154 |
+
if self.tokenizer.pad_token_id is None:
|
155 |
+
raise ValueError(f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None')
|
156 |
+
validate_target_settings(self.target_prompts, self.target_responses, self.decoder_only_format)
|
157 |
+
if self.target_prompts.startswith('length'):
|
158 |
+
self.prompt_cutoff = int(self.target_prompts.split('>=')[-1])
|
159 |
+
self.prompt_to_target = _TARGET_POLICY_LOOKUP['length']
|
160 |
+
else:
|
161 |
+
self.prompt_cutoff = None
|
162 |
+
self.prompt_to_target = _TARGET_POLICY_LOOKUP[self.target_prompts]
|
163 |
+
self.response_to_target = _TARGET_POLICY_LOOKUP[self.target_responses]
|
164 |
+
self._warned_truncated = False
|
165 |
+
self._warned_context = False
|
166 |
+
self._warned_target = False
|
167 |
+
|
168 |
+
def __call__(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
|
169 |
+
for check_key in ['input_ids', 'labels']:
|
170 |
+
if check_key not in examples[0]['turns'][0]:
|
171 |
+
raise KeyError(f'Examples returned by dataset do not include required key: {check_key}')
|
172 |
+
if self.decoder_only_format:
|
173 |
+
batch = self._process_and_batch_decoder_only(examples)
|
174 |
+
else:
|
175 |
+
batch = self._process_and_batch_encoder_decoder(examples)
|
176 |
+
batch_size = batch['input_ids'].shape[0]
|
177 |
+
batch.update({k: torch.tensor([v] * batch_size) for k, v in self.batch_metadata.items()})
|
178 |
+
return batch
|
179 |
+
|
180 |
+
def _process_and_batch_decoder_only(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
|
181 |
+
processed_examples = []
|
182 |
+
for example in examples:
|
183 |
+
input_ids, labels = stitch_turns_decoder_only(example_turns=example['turns'], target_prompts=self.target_prompts, target_responses=self.target_responses, eos_token_id=self.tokenizer.eos_token_id)
|
184 |
+
orig_size = len(input_ids)
|
185 |
+
if orig_size > self.max_seq_len:
|
186 |
+
input_ids = input_ids[:self.max_seq_len]
|
187 |
+
labels = labels[:self.max_seq_len]
|
188 |
+
if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0:
|
189 |
+
raise ValueError(f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. ' + f'Pre-truncation sequence length was {orig_size}. ' + 'This sample should have been filtered out before reaching the collator. If using ' + 'pre-tokenized streaming data, this may have resulted from using different ' + '``target_prompts``, ``target_responses``, or ``max_seq_len`` ' + 'settings when preparing the streaming dataset than what are currently being used.')
|
190 |
+
if not self._warned_truncated:
|
191 |
+
warnings.warn(f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. ' + f'If truncation is a problem, consider increasing max_seq_len.')
|
192 |
+
self._warned_truncated = True
|
193 |
+
attention_mask = [1] * len(input_ids)
|
194 |
+
n_total = len(input_ids)
|
195 |
+
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
|
196 |
+
if self.tokenizer.padding_side == 'left':
|
197 |
+
labels = i_pad + labels
|
198 |
+
else:
|
199 |
+
labels = labels + i_pad
|
200 |
+
processed_example = {'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask}
|
201 |
+
processed_examples.append(processed_example)
|
202 |
+
batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
|
203 |
+
batch['sequence_id'] = batch['attention_mask'] - 1
|
204 |
+
if not (self._allow_pad_trimming and self._seen_first_batch):
|
205 |
+
self._seen_first_batch = True
|
206 |
+
return batch
|
207 |
+
self._seen_first_batch = True
|
208 |
+
multiple_of = 8
|
209 |
+
n_non_padding = batch['attention_mask'].sum(dim=1).max()
|
210 |
+
keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
|
211 |
+
for k, v in batch.items():
|
212 |
+
if len(v.shape) < 2:
|
213 |
+
continue
|
214 |
+
if self.tokenizer.padding_side == 'left':
|
215 |
+
batch[k] = v[:, -keep_tokens:].contiguous()
|
216 |
+
else:
|
217 |
+
batch[k] = v[:, :keep_tokens].contiguous()
|
218 |
+
return batch
|
219 |
+
|
220 |
+
def _process_and_batch_encoder_decoder(self, examples: List[TokenizedExample]) -> Dict[str, torch.Tensor]:
|
221 |
+
processed_examples = []
|
222 |
+
for example in examples:
|
223 |
+
context, target = stitch_turns_encoder_decoder(example_turns=example['turns'], eos_token_id=self.tokenizer.eos_token_id)
|
224 |
+
if len(target) < self.max_seq_len:
|
225 |
+
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
|
226 |
+
target = target + i_pad
|
227 |
+
else:
|
228 |
+
if not self._warned_target:
|
229 |
+
warnings.warn(f'Truncating TARGET sequence of length={len(target)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
|
230 |
+
self._warned_target = True
|
231 |
+
target = target[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]
|
232 |
+
if len(context) > self.max_seq_len:
|
233 |
+
if not self._warned_context:
|
234 |
+
warnings.warn(f'Truncating CONTEXT sequence of length={len(context)} ' + f'to max_seq_len={self.max_seq_len}. If truncation is ' + f'a problem, consider increasing max_seq_len.')
|
235 |
+
self._warned_context = True
|
236 |
+
context = context[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]
|
237 |
+
processed_example = {'input_ids': context, 'labels': target, 'attention_mask': [1] * len(context)}
|
238 |
+
processed_examples.append(processed_example)
|
239 |
+
batch = self.tokenizer.pad(processed_examples, padding='max_length', max_length=self.max_seq_len, return_tensors='pt')
|
240 |
+
batch['decoder_input_ids'] = torch.cat([torch.full((len(processed_examples), 1), self.tokenizer.pad_token_id), batch['labels'][:, :-1]], dim=1)
|
241 |
+
batch['decoder_input_ids'].masked_fill_(batch['decoder_input_ids'] == _HF_IGNORE_INDEX, self.tokenizer.pad_token_id)
|
242 |
+
batch['decoder_attention_mask'] = torch.not_equal(batch['labels'], _HF_IGNORE_INDEX)
|
243 |
+
if not (self._allow_pad_trimming and self._seen_first_batch):
|
244 |
+
self._seen_first_batch = True
|
245 |
+
return batch
|
246 |
+
self._seen_first_batch = True
|
247 |
+
multiple_of = 8
|
248 |
+
n_non_padding = batch['attention_mask'].sum(dim=1).max()
|
249 |
+
keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
|
250 |
+
for k in ['input_ids', 'attention_mask']:
|
251 |
+
batch[k] = batch[k][:, :keep_tokens].contiguous()
|
252 |
+
n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
|
253 |
+
keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
|
254 |
+
for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
|
255 |
+
batch[k] = batch[k][:, :keep_tokens].contiguous()
|
256 |
+
return batch
|
config_utils.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import warnings
|
5 |
+
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union
|
6 |
+
from .utils import init_empty_weights
|
7 |
+
log = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
def pop_config(cfg: DictConfig, key: str, must_exist: bool=True, default_value: Any=None, convert: bool=False) -> Any:
|
10 |
+
"""Pop a value from the main config file and return it.
|
11 |
+
|
12 |
+
If the key does not exist, return the default_value or raise a RuntimeError
|
13 |
+
depending on the must_exist flag. If the convert flag is set to True, then
|
14 |
+
we will convert the value to a python object using OmegaConf.to_container.
|
15 |
+
"""
|
16 |
+
value = cfg.pop(key, None)
|
17 |
+
if value is not None and convert:
|
18 |
+
if not isinstance(value, DictConfig) and (not isinstance(value, ListConfig)):
|
19 |
+
raise ValueError(f'The key {key} has a value of type {type(value)} that cannot be converted to a dict or list. Please check your yaml.')
|
20 |
+
return om.to_container(value)
|
21 |
+
elif value is not None:
|
22 |
+
return value
|
23 |
+
elif must_exist:
|
24 |
+
raise NameError(f'The {key} parameter is missing and must exist for execution. Please check your yaml.')
|
25 |
+
else:
|
26 |
+
return default_value
|
27 |
+
|
28 |
+
def calculate_batch_size_info(global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']]) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]:
|
29 |
+
if global_batch_size % dist.get_world_size() != 0:
|
30 |
+
raise ValueError(f'Global batch size {global_batch_size} is not divisible by {dist.get_world_size()} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {dist.get_world_size()}.')
|
31 |
+
device_batch_size = global_batch_size // dist.get_world_size()
|
32 |
+
if device_microbatch_size == 'auto':
|
33 |
+
device_grad_accum = 'auto'
|
34 |
+
elif isinstance(device_microbatch_size, int):
|
35 |
+
if device_microbatch_size > device_batch_size:
|
36 |
+
log.warn(f'device_microbatch_size > device_batch_size, ' + f'will be reduced from {device_microbatch_size} -> {device_batch_size}.')
|
37 |
+
device_microbatch_size = device_batch_size
|
38 |
+
device_grad_accum = math.ceil(device_batch_size / device_microbatch_size)
|
39 |
+
else:
|
40 |
+
raise ValueError(f'Not sure how to parse device_microbatch_size={device_microbatch_size!r}')
|
41 |
+
return (device_batch_size, device_microbatch_size, device_grad_accum)
|
42 |
+
|
43 |
+
def update_batch_size_info(cfg: DictConfig) -> DictConfig:
|
44 |
+
device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(cfg.global_train_batch_size, cfg.device_train_microbatch_size)
|
45 |
+
cfg.n_gpus = dist.get_world_size()
|
46 |
+
cfg.device_train_batch_size = device_train_batch_size
|
47 |
+
cfg.device_train_microbatch_size = device_train_microbatch_size
|
48 |
+
cfg.device_train_grad_accum = device_train_grad_accum
|
49 |
+
if 'device_eval_batch_size' not in cfg:
|
50 |
+
if cfg.device_train_microbatch_size == 'auto':
|
51 |
+
cfg.device_eval_batch_size = 1
|
52 |
+
else:
|
53 |
+
cfg.device_eval_batch_size = cfg.device_train_microbatch_size
|
54 |
+
return cfg
|
55 |
+
|
56 |
+
def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
|
57 |
+
init_context = contextlib.nullcontext()
|
58 |
+
if 'init_device' in model_cfg:
|
59 |
+
assert model_cfg.init_device in ['meta', 'cpu', 'mixed']
|
60 |
+
if fsdp_config is None and model_cfg.init_device == 'meta':
|
61 |
+
warnings.warn("Using `cfg.model.init_device='meta'` is only valid when using FSDP! " + "Reverting to `cfg.model.init_device='cpu'`.")
|
62 |
+
model_cfg.init_device = 'cpu'
|
63 |
+
if model_cfg.init_device == 'meta':
|
64 |
+
init_context = init_empty_weights()
|
65 |
+
if model_cfg.init_device == 'mixed':
|
66 |
+
if fsdp_config is None:
|
67 |
+
raise NotImplementedError('Using init_device `mixed` is only supported with FSDP. ' + 'Please add a FSDP config.')
|
68 |
+
if not fsdp_config.get('sync_module_states', False):
|
69 |
+
warnings.warn('Setting `sync_module_states = True` for FSDP. This is required when using mixed initialization.')
|
70 |
+
fsdp_config['sync_module_states'] = True
|
71 |
+
fsdp_config.setdefault('use_orig_params', False)
|
72 |
+
fsdp_config.setdefault('load_monolith_rank0_only', True)
|
73 |
+
master_dtype = model_cfg.get('master_weights_dtype')
|
74 |
+
small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16')
|
75 |
+
if fsdp_config and master_dtype in small_dtypes:
|
76 |
+
reduce_dtype = None
|
77 |
+
buffer_dtype = None
|
78 |
+
mixed_precision = fsdp_config.get('mixed_precision')
|
79 |
+
if isinstance(mixed_precision, Mapping):
|
80 |
+
reduce_dtype = mixed_precision.get('reduce_dtype')
|
81 |
+
buffer_dtype = mixed_precision.get('buffer_dtype')
|
82 |
+
fsdp_config['mixed_precision'] = {'param_dtype': None, 'reduce_dtype': reduce_dtype, 'buffer_dtype': buffer_dtype, 'keep_low_precision_grads': True}
|
83 |
+
return init_context
|
84 |
+
|
85 |
+
def log_config(cfg: DictConfig) -> None:
|
86 |
+
"""Logs the current config and updates the wandb and mlflow configs.
|
87 |
+
|
88 |
+
This function can be called multiple times to update the wandb and MLflow
|
89 |
+
config with different variables.
|
90 |
+
"""
|
91 |
+
print(om.to_yaml(cfg))
|
92 |
+
if 'wandb' in cfg.get('loggers', {}):
|
93 |
+
try:
|
94 |
+
import wandb
|
95 |
+
except ImportError as e:
|
96 |
+
raise e
|
97 |
+
if wandb.run:
|
98 |
+
wandb.config.update(om.to_container(cfg, resolve=True))
|
99 |
+
if 'mlflow' in cfg.get('loggers', {}):
|
100 |
+
try:
|
101 |
+
import mlflow
|
102 |
+
except ImportError as e:
|
103 |
+
raise e
|
104 |
+
if mlflow.active_run():
|
105 |
+
mlflow.log_params(params=om.to_container(cfg, resolve=True))
|
configuration_mpt.py
CHANGED
@@ -2,12 +2,11 @@
|
|
2 |
import warnings
|
3 |
from typing import Any, Dict, Optional, Union
|
4 |
from transformers import PretrainedConfig
|
5 |
-
from .attention import check_alibi_support,
|
6 |
from .blocks import attn_config_defaults
|
7 |
from .fc import FC_CLASS_REGISTRY
|
8 |
from .norm import LPLayerNorm
|
9 |
from .ffn import FFN_CLASS_REGISTRY
|
10 |
-
from .warnings import VersionedDeprecationWarning
|
11 |
ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
|
12 |
init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
|
13 |
|
@@ -30,16 +29,13 @@ class MPTConfig(PretrainedConfig):
|
|
30 |
attn_config (Dict): A dictionary used to configure the model's attention module:
|
31 |
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
|
32 |
attn_pdrop (float): The dropout probability for the attention layers.
|
33 |
-
attn_impl (str): The attention implementation to use. One of 'torch'
|
34 |
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
|
35 |
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
|
36 |
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
|
37 |
this value.
|
38 |
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
|
39 |
use the default scale of ``1/sqrt(d_keys)``.
|
40 |
-
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
|
41 |
-
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
|
42 |
-
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
|
43 |
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
|
44 |
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
|
45 |
which sub-sequence each token belongs to.
|
@@ -116,7 +112,7 @@ class MPTConfig(PretrainedConfig):
|
|
116 |
self._validate_config()
|
117 |
|
118 |
def _set_config_defaults(self, config: Dict[str, Any], config_defaults: Dict[str, Any]) -> Dict[str, Any]:
|
119 |
-
for
|
120 |
if k not in config:
|
121 |
config[k] = v
|
122 |
elif isinstance(v, dict):
|
@@ -131,18 +127,12 @@ class MPTConfig(PretrainedConfig):
|
|
131 |
raise ValueError('d_model must be divisible by n_heads')
|
132 |
if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
|
133 |
raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
|
134 |
-
if self.attn_config['attn_impl'] not in ['torch', 'flash'
|
135 |
raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
|
136 |
-
if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
|
137 |
-
raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
|
138 |
-
if self.attn_config['attn_impl'] == 'flash' and is_flash_v1_installed():
|
139 |
-
warnings.warn(VersionedDeprecationWarning('Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.', remove_version='0.6.0'))
|
140 |
-
if self.attn_config['attn_impl'] == 'triton' and (not self.attn_config['prefix_lm']):
|
141 |
-
warnings.warn(UserWarning('If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton".'))
|
142 |
if self.attn_config['alibi'] and (not check_alibi_support(self.attn_config['attn_impl'])):
|
143 |
-
raise NotImplementedError('alibi only implemented with torch
|
144 |
-
if self.attn_config['attn_uses_sequence_id'] and (not (self.attn_config['attn_impl']
|
145 |
-
raise NotImplementedError('attn_uses_sequence_id only implemented with torch
|
146 |
if self.attn_config['rope'] and self.attn_config['rope_impl'] not in ['dail', 'hf']:
|
147 |
raise ValueError('If rope is being used then rope_impl should be either "dail", or "hf".')
|
148 |
if self.attn_config['rope'] and self.attn_config['rope_impl'] == 'hf' and (self.attn_config['rope_hf_config']['type'] not in ['no_scaling', 'linear', 'dynamic']):
|
|
|
2 |
import warnings
|
3 |
from typing import Any, Dict, Optional, Union
|
4 |
from transformers import PretrainedConfig
|
5 |
+
from .attention import check_alibi_support, is_flash_v2_installed
|
6 |
from .blocks import attn_config_defaults
|
7 |
from .fc import FC_CLASS_REGISTRY
|
8 |
from .norm import LPLayerNorm
|
9 |
from .ffn import FFN_CLASS_REGISTRY
|
|
|
10 |
ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
|
11 |
init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
|
12 |
|
|
|
29 |
attn_config (Dict): A dictionary used to configure the model's attention module:
|
30 |
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
|
31 |
attn_pdrop (float): The dropout probability for the attention layers.
|
32 |
+
attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'.
|
33 |
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
|
34 |
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
|
35 |
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
|
36 |
this value.
|
37 |
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
|
38 |
use the default scale of ``1/sqrt(d_keys)``.
|
|
|
|
|
|
|
39 |
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
|
40 |
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
|
41 |
which sub-sequence each token belongs to.
|
|
|
112 |
self._validate_config()
|
113 |
|
114 |
def _set_config_defaults(self, config: Dict[str, Any], config_defaults: Dict[str, Any]) -> Dict[str, Any]:
|
115 |
+
for k, v in config_defaults.items():
|
116 |
if k not in config:
|
117 |
config[k] = v
|
118 |
elif isinstance(v, dict):
|
|
|
127 |
raise ValueError('d_model must be divisible by n_heads')
|
128 |
if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
|
129 |
raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
|
130 |
+
if self.attn_config['attn_impl'] not in ['torch', 'flash']:
|
131 |
raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
if self.attn_config['alibi'] and (not check_alibi_support(self.attn_config['attn_impl'])):
|
133 |
+
raise NotImplementedError('alibi only implemented with torch and flash (v2.4.2 or higher) attention.')
|
134 |
+
if self.attn_config['attn_uses_sequence_id'] and (not (self.attn_config['attn_impl'] == 'torch' or (self.attn_config['attn_impl'] == 'flash' and is_flash_v2_installed(v2_version='v2.1.2')))):
|
135 |
+
raise NotImplementedError('attn_uses_sequence_id only implemented with torch and flash (v2.1.2 or higher) attention.')
|
136 |
if self.attn_config['rope'] and self.attn_config['rope_impl'] not in ['dail', 'hf']:
|
137 |
raise ValueError('If rope is being used then rope_impl should be either "dail", or "hf".')
|
138 |
if self.attn_config['rope'] and self.attn_config['rope_impl'] == 'hf' and (self.attn_config['rope_hf_config']['type'] not in ['no_scaling', 'linear', 'dynamic']):
|
curriculum_learning_callback.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Enable curriculum learning by resuming with a different dataset.
|
2 |
+
|
3 |
+
This callback is currently experimental. The API may change without warning in
|
4 |
+
the future.
|
5 |
+
"""
|
6 |
+
import logging
|
7 |
+
from typing import Any, Dict
|
8 |
+
from streaming import StreamingDataset
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from .interfaces import CallbackWithConfig
|
11 |
+
from .warnings import experimental_class
|
12 |
+
log = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
@experimental_class('CurriculumLearning callback')
|
15 |
+
class CurriculumLearning(CallbackWithConfig):
|
16 |
+
"""Starts an epoch with a different dataset when resuming from a checkpoint.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
dataset_index (int): The index of the dataset currently being used.
|
20 |
+
current_dataset_config (Dict): The configuration of the dataset currently
|
21 |
+
being used.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, dataset_index: int, train_config: Dict):
|
25 |
+
self.dataset_index = dataset_index
|
26 |
+
self.saved_dataset_index = 0
|
27 |
+
self.all_dataset_configs = []
|
28 |
+
self.current_dataset_state = {}
|
29 |
+
self.current_dataset_config = train_config['dataloader']
|
30 |
+
|
31 |
+
def before_load(self, state: State, logger: Logger):
|
32 |
+
del logger
|
33 |
+
train_loader = state.train_dataloader
|
34 |
+
if not isinstance(train_loader, DataLoader):
|
35 |
+
raise ValueError(f'CurriculumLearning callback can only be used with a train ', f'dataloader of type DataLoader, but got {type(train_loader)}.')
|
36 |
+
dataset = train_loader.dataset
|
37 |
+
if not isinstance(dataset, StreamingDataset):
|
38 |
+
raise ValueError(f'CurriculumLearning callback only supports StreamingDataset ', f'because it requires loading and saving dataset state. ', f'Instead, got a dataset of type {type(dataset)}')
|
39 |
+
assert isinstance(dataset, StreamingDataset)
|
40 |
+
self.current_dataset_state = dataset.state_dict(num_samples=0, from_beginning=False)
|
41 |
+
|
42 |
+
def after_load(self, state: State, logger: Logger):
|
43 |
+
del logger
|
44 |
+
train_loader = state._train_dataloader
|
45 |
+
assert isinstance(train_loader, DataLoader), 'CurriculumLearning callback requires a DataLoader.'
|
46 |
+
dataset = train_loader.dataset
|
47 |
+
assert isinstance(dataset, StreamingDataset), 'CurriculumLearning callback requires a StreamingDataset.'
|
48 |
+
if self.saved_dataset_index < self.dataset_index:
|
49 |
+
if self.current_dataset_state['epoch'] < 0:
|
50 |
+
self.current_dataset_state['epoch'] = 0
|
51 |
+
dataset.load_state_dict(self.current_dataset_state)
|
52 |
+
state.timestamp = state.timestamp.to_next_epoch()
|
53 |
+
self.all_dataset_configs.append(self.current_dataset_config)
|
54 |
+
elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0:
|
55 |
+
self.all_dataset_configs.append(self.current_dataset_config)
|
56 |
+
|
57 |
+
def state_dict(self):
|
58 |
+
return {'dataset_index': self.dataset_index, 'all_dataset_configs': self.all_dataset_configs}
|
59 |
+
|
60 |
+
def load_state_dict(self, state: Dict[str, Any]):
|
61 |
+
self.saved_dataset_index = state.get('dataset_index', 0)
|
62 |
+
self.all_dataset_configs = state.get('all_dataset_configs', [])
|
data.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Datasets for converting to MDS Shards."""
|
2 |
+
import os
|
3 |
+
import warnings
|
4 |
+
from typing import Dict, Iterable, Union
|
5 |
+
import datasets as hf_datasets
|
6 |
+
import numpy as np
|
7 |
+
from torch.utils.data import IterableDataset
|
8 |
+
from transformers import PreTrainedTokenizerBase
|
9 |
+
|
10 |
+
class NoConcatDataset(IterableDataset):
|
11 |
+
"""An IterableDataset that returns text samples for MDSWriter.
|
12 |
+
|
13 |
+
Returns dicts of {'text': bytes}
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset]):
|
17 |
+
self.hf_dataset = hf_dataset
|
18 |
+
|
19 |
+
def __iter__(self) -> Iterable[Dict[str, bytes]]:
|
20 |
+
for sample in self.hf_dataset:
|
21 |
+
yield {'text': sample['text'].encode('utf-8')}
|
22 |
+
|
23 |
+
class ConcatTokensDataset(IterableDataset):
|
24 |
+
"""An IterableDataset that returns token samples for MDSWriter.
|
25 |
+
|
26 |
+
Returns dicts of {'tokens': bytes}
|
27 |
+
|
28 |
+
To use data created by this class and written to MDS format:
|
29 |
+
|
30 |
+
```python
|
31 |
+
import torch
|
32 |
+
from streaming.base import StreamingDataset
|
33 |
+
from transformers import AutoTokenizer
|
34 |
+
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
|
36 |
+
ds = StreamingDataset(local='mds-data-folder', split='val')
|
37 |
+
|
38 |
+
# note, you need to copy the numpy array because the original is non-writeable
|
39 |
+
# and torch does not support non-writeable tensors, so you get a scary warning and
|
40 |
+
# if you do try to write to the tensor you get undefined behavior
|
41 |
+
tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
|
42 |
+
print(tokenizer.decode(tokens))
|
43 |
+
```
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], tokenizer: PreTrainedTokenizerBase, max_length: int, bos_text: str, eos_text: str, no_wrap: bool):
|
47 |
+
self.hf_dataset = hf_dataset
|
48 |
+
self.tokenizer = tokenizer
|
49 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
50 |
+
self.max_length = max_length
|
51 |
+
self.bos_text = bos_text
|
52 |
+
self.eos_text = eos_text
|
53 |
+
self.should_wrap = not no_wrap
|
54 |
+
self.bos_tokens = self.tokenizer(self.bos_text, truncation=False, padding=False, add_special_tokens=False)['input_ids']
|
55 |
+
if len(self.bos_tokens) > 1:
|
56 |
+
warnings.warn(f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token , instead we got {self.bos_tokens}. Quit if this was in error.')
|
57 |
+
self.eos_tokens = self.tokenizer(self.eos_text, truncation=False, padding=False, add_special_tokens=False)['input_ids']
|
58 |
+
if len(self.eos_tokens) > 1:
|
59 |
+
warnings.warn(f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token , instead we got {self.eos_tokens}. Quit if this was in error.')
|
60 |
+
eos_text_provided = self.eos_text != ''
|
61 |
+
bos_text_provided = self.bos_text != ''
|
62 |
+
test_text = self.tokenizer('')
|
63 |
+
if len(test_text['input_ids']) > 0 and (eos_text_provided or bos_text_provided):
|
64 |
+
message = 'both eos and bos' if eos_text_provided and bos_text_provided else 'eos_text' if eos_text_provided else 'bos_text'
|
65 |
+
warnings.warn(f'The provided tokenizer adds special tokens, but you also specified {message}. This may result ' + 'in duplicated special tokens. Please be sure this is what you intend.')
|
66 |
+
|
67 |
+
def __iter__(self) -> Iterable[Dict[str, bytes]]:
|
68 |
+
buffer = []
|
69 |
+
for sample in self.hf_dataset:
|
70 |
+
encoded = self.tokenizer(sample['text'], truncation=False, padding=False)
|
71 |
+
iids = encoded['input_ids']
|
72 |
+
buffer = buffer + self.bos_tokens + iids + self.eos_tokens
|
73 |
+
while len(buffer) >= self.max_length:
|
74 |
+
concat_sample = buffer[:self.max_length]
|
75 |
+
buffer = buffer[self.max_length:] if self.should_wrap else []
|
76 |
+
yield {'tokens': np.asarray(concat_sample).tobytes()}
|
data_prep_utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from glob import glob
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
def with_id(basename: str, shard_id: int) -> str:
|
7 |
+
"""Get a new basename with the given shard_id.
|
8 |
+
|
9 |
+
From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset_conversion.ipynb.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
basename (str): Old basename of file.
|
13 |
+
shard_id (int): New shard ID.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
str: New basename of file.
|
17 |
+
"""
|
18 |
+
parts = basename.split('.')
|
19 |
+
parts[1] = f'{shard_id:05}'
|
20 |
+
return '.'.join(parts)
|
21 |
+
|
22 |
+
def merge_shard_groups(root: str) -> None:
|
23 |
+
"""Merge ephemeral sub-datasets created in parallel into one dataset.
|
24 |
+
|
25 |
+
From https://github.com/mosaicml/streaming/blob/main/examples/multiprocess_dataset
|
26 |
+
_conversion.ipynb.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
root (str): Root directory.
|
30 |
+
"""
|
31 |
+
pattern = os.path.join(root, '*')
|
32 |
+
subdirs = sorted(glob(pattern))
|
33 |
+
shard_id = 0
|
34 |
+
infos = []
|
35 |
+
for subdir in subdirs:
|
36 |
+
index_filename = os.path.join(subdir, 'index.json')
|
37 |
+
with open(index_filename) as index_file:
|
38 |
+
obj = json.load(index_file)
|
39 |
+
for info in obj['shards']:
|
40 |
+
old_basename = info['raw_data']['basename']
|
41 |
+
new_basename = with_id(old_basename, shard_id)
|
42 |
+
info['raw_data']['basename'] = new_basename
|
43 |
+
if info['zip_data'] is not None:
|
44 |
+
old_basename = info['zip_data']['basename']
|
45 |
+
new_basename = with_id(old_basename, shard_id)
|
46 |
+
info['zip_data']['basename'] = new_basename
|
47 |
+
old_filename = os.path.join(subdir, old_basename)
|
48 |
+
new_filename = os.path.join(root, new_basename)
|
49 |
+
os.rename(old_filename, new_filename)
|
50 |
+
shard_id += 1
|
51 |
+
infos.append(info)
|
52 |
+
os.remove(index_filename)
|
53 |
+
os.rmdir(subdir)
|
54 |
+
index_filename = os.path.join(root, 'index.json')
|
55 |
+
obj = {'version': 2, 'shards': infos}
|
56 |
+
text = json.dumps(obj, sort_keys=True)
|
57 |
+
with open(index_filename, 'w') as out:
|
58 |
+
out.write(text)
|
59 |
+
|
60 |
+
class DownloadingIterable:
|
61 |
+
|
62 |
+
def __init__(self, object_names: List[str], output_folder: str, object_store: Optional[ObjectStore]):
|
63 |
+
"""Iterable that downloads files from an object store before yielding.
|
64 |
+
|
65 |
+
If object_store is None, input_folder_prefix is treated as a local path.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
object_names (List[str]): Names of objects to download
|
69 |
+
output_folder (str): Local folder to write downloaded files to
|
70 |
+
object_store (Optiona[ObjectStore]): Object store to download from
|
71 |
+
"""
|
72 |
+
self.object_names = object_names
|
73 |
+
self.object_store = object_store
|
74 |
+
self.output_folder = output_folder
|
75 |
+
|
76 |
+
def __iter__(self):
|
77 |
+
for object_name in self.object_names:
|
78 |
+
output_filename = object_name
|
79 |
+
if self.object_store is not None:
|
80 |
+
output_filename = os.path.join(self.output_folder, object_name.strip('/'))
|
81 |
+
self.object_store.download_object(object_name=object_name, filename=output_filename, overwrite=True)
|
82 |
+
with open(output_filename) as _txt_file:
|
83 |
+
txt = _txt_file.read()
|
84 |
+
yield {'text': txt}
|
dataloader.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from typing import Tuple, Union
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from transformers import PreTrainedTokenizerBase
|
7 |
+
from .collator import Seq2SeqFinetuningCollator, validate_target_settings
|
8 |
+
from .tasks import DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, dataset_constructor
|
9 |
+
from .packing import BinPackCollator, auto_packing_ratio
|
10 |
+
from .text_data import build_streams, get_tokens_per_batch_func
|
11 |
+
from .exceptions import MissingHuggingFaceURLSplitError, NotEnoughDatasetSamplesError
|
12 |
+
log = logging.getLogger(__name__)
|
13 |
+
_HF_IGNORE_INDEX = -100
|
14 |
+
_DEFAULT_TARGET_RESPONSES = 'last'
|
15 |
+
_DEFAULT_TARGET_PROMPTS = 'none'
|
16 |
+
|
17 |
+
def build_finetuning_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataSpec:
|
18 |
+
"""Builds a finetuning dataloader for training or evaluating.
|
19 |
+
|
20 |
+
The underlying dataset can be built through one of two code paths:
|
21 |
+
1. As a HuggingFace dataset, via `datasets.load_dataset(...)`
|
22 |
+
2. As a streaming dataset
|
23 |
+
You will need to set slightly different dataset config fields depending
|
24 |
+
on which you intend to use, as explained below.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
cfg (DictConfig): An omegaconf dictionary used to configure the loader:
|
28 |
+
cfg.name (str): The type of dataloader to build. Must = "finetuning".
|
29 |
+
---
|
30 |
+
*** HuggingFace dataset config fields ***
|
31 |
+
cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset
|
32 |
+
to use. Can also be a remote http(s) directory or object store bucket
|
33 |
+
containing the file {split}.jsonl in the format (prompt, response),
|
34 |
+
in which case the builder will create a HuggingFace dataset.
|
35 |
+
cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to
|
36 |
+
pass to `datasets.load_dataset`, which can be used to load
|
37 |
+
a dataset from local files.
|
38 |
+
cfg.dataset.preprocessing_fn (str, optional): The name/import path of
|
39 |
+
the preprocessing function to use for formatting the data examples.
|
40 |
+
If ``None`` (default), the builder will use the preprocessing function
|
41 |
+
registered under `hf_name` (see `tasks.py`), if one exists,
|
42 |
+
otherwise it will skip preprocessing.
|
43 |
+
If `preprocessing_fn` corresponds to a registered preprocessing
|
44 |
+
function in `tasks.py`, the builder will use that.
|
45 |
+
Otherwise, it will interpret `preprocessing_fn` as a
|
46 |
+
"import.path:function_name" import path; e.g., it will call
|
47 |
+
`from import.path import function_name` and use the imported
|
48 |
+
function as the preprocessing function.
|
49 |
+
*** Streaming dataset config fields ***
|
50 |
+
cfg.dataset.remote (str, optional): Location of a MDS-formatted
|
51 |
+
streaming dataset to use. Setting this will tell the builder
|
52 |
+
to create a streaming dataset rather than a HuggingFace dataset.
|
53 |
+
cfg.dataset.local (str, optional): Local path where remote data
|
54 |
+
will be streamed to. Only valid if `cfg.dataset.remote` has
|
55 |
+
also been set.
|
56 |
+
*** Shared dataset configs fields ***
|
57 |
+
cfg.dataset.max_seq_len (int): The maximum length of sequences
|
58 |
+
in the batch. See :class:`Seq2SeqFinetuningCollator` docstring
|
59 |
+
for details.
|
60 |
+
cfg.dataset.decoder_only_format (bool): Whether to format the
|
61 |
+
examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator`
|
62 |
+
docstring for details.
|
63 |
+
cfg.dataset.target_responses (str): Which responses are used as training targets.
|
64 |
+
Defaults to "last", meaning only the final response in multi-turn examples
|
65 |
+
will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for
|
66 |
+
details.
|
67 |
+
cfg.dataset.target_prompts (str): Which prompts are used as training targets.
|
68 |
+
Defaults to "none", meaning prompts are never used as training targets.
|
69 |
+
See :class:`Seq2SeqFinetuningCollator` docstring for details.
|
70 |
+
cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow
|
71 |
+
the collator to trim padding. See :class:`Seq2SeqFinetuningCollator`
|
72 |
+
docstring for details. Default: ``False``.
|
73 |
+
cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
|
74 |
+
a collator wrapper that packs device_batch_size*packing_ratio
|
75 |
+
raw examples into device_batch_size packed examples. This helps
|
76 |
+
minimize padding while preserving sequence integrity.
|
77 |
+
This adds `sequence_id` to the batch, which indicates which unique
|
78 |
+
sequence each token belongs to.
|
79 |
+
|
80 |
+
If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
|
81 |
+
zero waste is selected.
|
82 |
+
In practice, this may result in > 0 waste because profiling is done on only a portion
|
83 |
+
of the dataset.
|
84 |
+
|
85 |
+
Note: Using this feature will not change device_batch_size but it
|
86 |
+
will determine the number of raw examples consumed by the dataloader
|
87 |
+
per batch. Some examples may be discarded if they do not fit when
|
88 |
+
packing.
|
89 |
+
Select packing_ratio **carefully** based on the dataset
|
90 |
+
statistics, max_seq_len, and tolerance for discarding samples!
|
91 |
+
The script `scripts/misc/profile_packing.py` can help
|
92 |
+
you choose the best packing_ratio.
|
93 |
+
cfg.dataset.shuffle (bool): Whether to shuffle the dataset.
|
94 |
+
___
|
95 |
+
See :class:`StreamingFinetuningDataset` for info on other standard config
|
96 |
+
options within `cfg.dataset` that will be passed as kwargs if
|
97 |
+
using the streaming codepath.
|
98 |
+
---
|
99 |
+
See :class:`DataLoader` for standard argument options to the pytorch
|
100 |
+
dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
|
101 |
+
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
|
102 |
+
prepare the data from raw text. Any missing sentinel tokens will
|
103 |
+
be added by the collator.
|
104 |
+
device_batch_size (int): The size of the batches (number of examples)
|
105 |
+
that the dataloader will produce.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
A pytorch dataloader
|
109 |
+
|
110 |
+
Note:
|
111 |
+
You can run the script inside `scripts/misc/profile_packing.py` to quickly test the
|
112 |
+
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
|
113 |
+
given a starting workload YAML.
|
114 |
+
"""
|
115 |
+
_validate_config(cfg.dataset)
|
116 |
+
if tokenizer.pad_token is None:
|
117 |
+
tokenizer.pad_token = tokenizer.eos_token
|
118 |
+
collate_fn, dataloader_batch_size = _build_collate_fn(cfg, tokenizer, device_batch_size)
|
119 |
+
dataset = None
|
120 |
+
sampler = None
|
121 |
+
if cfg.dataset.get('remote') is not None or cfg.dataset.get('streams') is not None:
|
122 |
+
streams = build_streams(cfg.dataset)
|
123 |
+
dataset = dataset_constructor.build_from_streaming(tokenizer=tokenizer, streams=streams, local=cfg.dataset.get('local', None), remote=cfg.dataset.get('remote', None), split=cfg.dataset.get('split', None), download_retry=cfg.dataset.get('download_retry', 2), download_timeout=cfg.dataset.get('download_timeout', 60), validate_hash=cfg.dataset.get('validate_hash', None), keep_zip=cfg.dataset.get('keep_zip', False), epoch_size=cfg.dataset.get('epoch_size', None), predownload=cfg.dataset.get('predownload', None), cache_limit=cfg.dataset.get('cache_limit', None), partition_algo=cfg.dataset.get('partition_algo', 'relaxed'), num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None), batch_size=device_batch_size, shuffle=cfg.dataset.get('shuffle', False), shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'), shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), shuffle_block_size=cfg.dataset.get('shuffle_block_size', None), sampling_method=cfg.dataset.get('sampling_method', 'balanced'), sampling_granularity=cfg.dataset.get('sampling_granularity', 1), batching_method=cfg.dataset.get('batching_method', 'random'), max_seq_len=cfg.dataset.max_seq_len)
|
124 |
+
else:
|
125 |
+
dataset_name_or_path = cfg.dataset.hf_name
|
126 |
+
split = cfg.dataset.get('split')
|
127 |
+
if split is None:
|
128 |
+
raise MissingHuggingFaceURLSplitError()
|
129 |
+
backend, _, _ = parse_uri(dataset_name_or_path)
|
130 |
+
if backend not in ['', None]:
|
131 |
+
dataset_name_or_path = _download_remote_hf_dataset(remote_path=dataset_name_or_path, split=split)
|
132 |
+
split = split.replace('-', '_')
|
133 |
+
proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn')
|
134 |
+
if isinstance(proto_preprocessing_fn, (dict, DictConfig)):
|
135 |
+
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(dict(proto_preprocessing_fn))
|
136 |
+
else:
|
137 |
+
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(proto_preprocessing_fn, dataset_name_or_path)
|
138 |
+
dataset = dataset_constructor.build_from_hf(dataset_name=dataset_name_or_path, split=split, safe_load=cfg.dataset.get('safe_load', False), max_seq_len=cfg.dataset.max_seq_len, preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, target_prompts=cfg.dataset.get('target_prompts', _DEFAULT_TARGET_PROMPTS), target_responses=cfg.dataset.get('target_responses', _DEFAULT_TARGET_RESPONSES), decoder_only_format=cfg.dataset.decoder_only_format, hf_kwargs=cfg.dataset.get('hf_kwargs', {}))
|
139 |
+
if cfg.drop_last:
|
140 |
+
world_size = dist.get_world_size()
|
141 |
+
minimum_dataset_size = world_size * dataloader_batch_size
|
142 |
+
if hasattr(dataset, '__len__'):
|
143 |
+
full_dataset_size = len(dataset)
|
144 |
+
if full_dataset_size < minimum_dataset_size:
|
145 |
+
raise NotEnoughDatasetSamplesError(dataset_name=cfg.dataset.hf_name, split=split, dataloader_batch_size=dataloader_batch_size, world_size=world_size, full_dataset_size=full_dataset_size, minimum_dataset_size=minimum_dataset_size)
|
146 |
+
sampler = dist.get_sampler(dataset, drop_last=cfg.drop_last, shuffle=cfg.dataset.shuffle)
|
147 |
+
assert dataset is not None
|
148 |
+
dl = DataLoader(dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, drop_last=cfg.drop_last, sampler=sampler, num_workers=cfg.num_workers, pin_memory=cfg.get('pin_memory', True), prefetch_factor=cfg.get('prefetch_factor', 2), persistent_workers=cfg.get('persistent_workers', True), timeout=cfg.get('timeout', 0))
|
149 |
+
token_counting_func = get_tokens_per_batch_func()
|
150 |
+
return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
|
151 |
+
|
152 |
+
def _validate_config(dataset_cfg: DictConfig) -> None:
|
153 |
+
"""Validates the dataset configuration.
|
154 |
+
|
155 |
+
Makes sure that the dataset is properly configured for either
|
156 |
+
a HuggingFace dataset or a streaming dataset. Must be valid for one or
|
157 |
+
the other.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
dataset_cfg (DictConfig): The dataset configuration to be validated.
|
161 |
+
|
162 |
+
Raises:
|
163 |
+
ValueError: If the dataset configuration does not meet the requirements.
|
164 |
+
"""
|
165 |
+
if dataset_cfg.get('hf_name') is not None:
|
166 |
+
illegal_keys = ['local', 'remote']
|
167 |
+
discovered_illegal_keys = []
|
168 |
+
for key in illegal_keys:
|
169 |
+
if dataset_cfg.get(key) is not None:
|
170 |
+
discovered_illegal_keys.append('`' + key + '`')
|
171 |
+
if discovered_illegal_keys:
|
172 |
+
raise ValueError('The dataset config sets a value for `hf_name` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Those keys are used when building from a streaming dataset, but ' + 'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.')
|
173 |
+
elif dataset_cfg.get('remote') is not None:
|
174 |
+
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
|
175 |
+
discovered_illegal_keys = []
|
176 |
+
for key in illegal_keys:
|
177 |
+
if dataset_cfg.get(key) is not None:
|
178 |
+
discovered_illegal_keys.append('`' + key + '`')
|
179 |
+
if discovered_illegal_keys:
|
180 |
+
raise ValueError('The dataset config sets a value for `remote` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Those keys are used when building from a HuggingFace dataset, but ' + 'setting `remote` instructs the dataset to build from a streaming dataset.')
|
181 |
+
if dataset_cfg.get('local') is None:
|
182 |
+
raise ValueError('Using a streaming dataset requires setting both `remote` and `local`, ' + 'but dataset.local is None.')
|
183 |
+
elif dataset_cfg.get('streams') is not None:
|
184 |
+
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
|
185 |
+
discovered_illegal_keys = []
|
186 |
+
for key in illegal_keys:
|
187 |
+
if dataset_cfg.get(key) is not None:
|
188 |
+
discovered_illegal_keys.append('`' + key + '`')
|
189 |
+
if discovered_illegal_keys:
|
190 |
+
raise ValueError('The dataset config sets a value for `streams` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Those keys are used when building from a HuggingFace dataset, but ' + 'setting `streams` instructs the dataset to build from a streaming dataset.')
|
191 |
+
illegal_keys = ['remote', 'local']
|
192 |
+
discovered_illegal_keys = []
|
193 |
+
for key in illegal_keys:
|
194 |
+
if dataset_cfg.get(key) is not None:
|
195 |
+
discovered_illegal_keys.append('`' + key + '`')
|
196 |
+
if discovered_illegal_keys:
|
197 |
+
raise ValueError('The dataset config sets a value for `streams` as well as the ' + f"following keys: {', '.join(discovered_illegal_keys)}.\n" + 'Please either use single stream (set remote/local only) ' + 'or put remote/local under streams')
|
198 |
+
else:
|
199 |
+
raise ValueError('In the dataset config, you must set `hf_name` to use a HuggingFace ' + 'dataset, or set `remote` to use a streaming dataset, or set ' + '`streams` to use multiple streaming datasets, but all were None.')
|
200 |
+
if dataset_cfg.get('max_seq_len') is None:
|
201 |
+
raise ValueError('In the dataset config, you must set the `max_seq_len`')
|
202 |
+
target_responses = str(dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES)).lower()
|
203 |
+
target_prompts = str(dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS)).lower()
|
204 |
+
decoder_only_format = dataset_cfg.decoder_only_format
|
205 |
+
validate_target_settings(target_prompts, target_responses, decoder_only_format)
|
206 |
+
|
207 |
+
def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
|
208 |
+
"""Downloads a dataset from a remote object store.
|
209 |
+
|
210 |
+
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
|
211 |
+
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
|
212 |
+
dataset.
|
213 |
+
|
214 |
+
The function also ensures synchronicity across multiple processes during the file download. It creates a signal
|
215 |
+
file that is used to synchronize the start of the download across different processes. Once the download is
|
216 |
+
completed, the function removes the signal file.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
hf_name (str): The path of the HuggingFace dataset to download.
|
220 |
+
split (str): The dataset split to download (e.g., 'train', 'validation', 'test').
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
A local directory path where the dataset files are stored.
|
224 |
+
|
225 |
+
Raises:
|
226 |
+
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
|
227 |
+
"""
|
228 |
+
hf_formatted_split = split.replace('-', '_')
|
229 |
+
finetune_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_formatted_split if hf_formatted_split != 'data' else 'data_not')
|
230 |
+
os.makedirs(finetune_dir, exist_ok=True)
|
231 |
+
for extension in SUPPORTED_EXTENSIONS:
|
232 |
+
name = f"{remote_path.strip('/')}/{split}{extension}"
|
233 |
+
destination = str(os.path.abspath(os.path.join(finetune_dir, 'data', f'{hf_formatted_split}-00000-of-00001{extension}')))
|
234 |
+
signal_file_path = os.path.join(finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
|
235 |
+
if dist.get_local_rank() == 0:
|
236 |
+
try:
|
237 |
+
get_file(path=name, destination=destination, overwrite=True)
|
238 |
+
except FileNotFoundError as e:
|
239 |
+
if extension == SUPPORTED_EXTENSIONS[-1]:
|
240 |
+
files_searched = [f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}' for ext in SUPPORTED_EXTENSIONS]
|
241 |
+
raise FileNotFoundError(f'Could not find a file with any of ' + f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + f'at {files_searched}') from e
|
242 |
+
else:
|
243 |
+
log.debug(f'Could not find {name}, looking for another extension')
|
244 |
+
continue
|
245 |
+
os.makedirs(os.path.dirname(signal_file_path), exist_ok=True)
|
246 |
+
with open(signal_file_path, 'wb') as f:
|
247 |
+
f.write(b'local_rank0_completed_download')
|
248 |
+
with dist.local_rank_zero_download_and_wait(signal_file_path):
|
249 |
+
dist.barrier()
|
250 |
+
if dist.get_local_rank() == 0:
|
251 |
+
os.remove(signal_file_path)
|
252 |
+
dist.barrier()
|
253 |
+
break
|
254 |
+
return finetune_dir
|
255 |
+
|
256 |
+
def _build_collate_fn(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
|
257 |
+
dataset_cfg = dataloader_cfg.dataset
|
258 |
+
max_seq_len = dataset_cfg.max_seq_len
|
259 |
+
collate_fn = Seq2SeqFinetuningCollator(tokenizer=tokenizer, max_seq_len=max_seq_len, decoder_only_format=dataset_cfg.decoder_only_format, target_responses=dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES), target_prompts=dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS), allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False))
|
260 |
+
packing_ratio = dataset_cfg.get('packing_ratio')
|
261 |
+
max_leftover_bins_to_keep = dataset_cfg.get('max_leftover_bins_to_keep')
|
262 |
+
if packing_ratio is None:
|
263 |
+
if max_leftover_bins_to_keep is not None:
|
264 |
+
raise ValueError('dataset.max_leftover_bins_to_keep has been defined, ' + 'but dataset.packing_ratio has not been set. Please set ' + 'the latter to turn on packing or remove the former from the config.')
|
265 |
+
return (collate_fn, device_batch_size)
|
266 |
+
if packing_ratio == 'auto':
|
267 |
+
packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer, device_batch_size)
|
268 |
+
if isinstance(packing_ratio, str):
|
269 |
+
raise ValueError('dataset.packing_ratio must be a float or "auto", but it was set to ' + f'{packing_ratio}.')
|
270 |
+
log.info(f'Using packing ratio {packing_ratio}')
|
271 |
+
if packing_ratio == 1.0:
|
272 |
+
return (collate_fn, device_batch_size)
|
273 |
+
elif packing_ratio < 1.0:
|
274 |
+
raise ValueError('packing_ratio must be >= 1, if supplied')
|
275 |
+
if not dataset_cfg.decoder_only_format:
|
276 |
+
raise NotImplementedError('On-the-fly packing is currently only supported for decoder-only formats.')
|
277 |
+
collate_fn = BinPackCollator(collator=collate_fn, target_batch_size=device_batch_size, max_seq_len=max_seq_len, pad_token_id=tokenizer.pad_token_id, padding_side=tokenizer.padding_side, max_leftover_bins_to_keep=max_leftover_bins_to_keep)
|
278 |
+
n_examples_to_pack = int(device_batch_size * packing_ratio)
|
279 |
+
return (collate_fn, n_examples_to_pack)
|
280 |
+
if __name__ == '__main__':
|
281 |
+
import torch
|
282 |
+
from .utils import build_tokenizer
|
283 |
+
cfg = om.create({'dataset': {'hf_name': 'tatsu-lab/alpaca', 'preprocessing_fn': 'llmfoundry.data.finetuning.tasks:alpaca_preprocessing_function', 'split': 'train', 'packing_ratio': 18.0, 'max_seq_len': 2048, 'decoder_only_format': True, 'allow_pad_trimming': False, 'num_canonical_nodes': 472, 'shuffle': True, 'target_responses': 'last', 'target_prompts': 'none'}, 'drop_last': False, 'num_workers': 0, 'pin_memory': False, 'prefetch_factor': None, 'persistent_workers': False, 'timeout': 0})
|
284 |
+
tokenizer_name = 'EleutherAI/gpt-neox-20b'
|
285 |
+
tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
|
286 |
+
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
|
287 |
+
device_batch_size = 1
|
288 |
+
dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader
|
289 |
+
packing = cfg.dataset.get('packing_ratio') is not None
|
290 |
+
for i, batch in enumerate(dataloader):
|
291 |
+
if i >= 5:
|
292 |
+
break
|
293 |
+
print(f'-----Batch {i}-----')
|
294 |
+
for k, v in batch.items():
|
295 |
+
if isinstance(v, torch.Tensor):
|
296 |
+
print(k, v.shape)
|
297 |
+
else:
|
298 |
+
print(k, v)
|
299 |
+
for j in range(device_batch_size):
|
300 |
+
print(f'--- Sample {j} ---')
|
301 |
+
if cfg.dataset.decoder_only_format:
|
302 |
+
if packing:
|
303 |
+
for subseq in range(int(batch['sequence_id'][j].max()) + 1):
|
304 |
+
is_subseq = batch['sequence_id'][j] == subseq
|
305 |
+
print('\x1b[93m{}\x1b[00m\n'.format('INPUT IDS:'), tokenizer.decode(batch['input_ids'][j, torch.logical_and(is_subseq, batch['attention_mask'][j] == 1)], skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
306 |
+
print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['input_ids'][j, torch.logical_and(is_subseq, batch['labels'][j] != _HF_IGNORE_INDEX)], skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
307 |
+
else:
|
308 |
+
print('\x1b[93m{}\x1b[00m\n'.format('INPUT IDS:'), tokenizer.decode(batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
309 |
+
print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['input_ids'][j, batch['labels'][j] != _HF_IGNORE_INDEX], skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
310 |
+
else:
|
311 |
+
print('\x1b[92m{}\x1b[00m\n'.format('CONTEXT: '), tokenizer.decode(batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
312 |
+
print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['labels'][j, batch['decoder_attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
313 |
+
print(' ')
|
eval_gauntlet_callback.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Aggregate ICL evals into composite scores."""
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Dict, Optional
|
6 |
+
log = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
class Weighting(Enum):
|
9 |
+
EQUAL = 1
|
10 |
+
SAMPLE_SZ = 2
|
11 |
+
LOG_SAMPLE_SZ = 3
|
12 |
+
|
13 |
+
def calculate_named_averages(average_names: Dict[str, list], category_scores: Dict[str, float]):
|
14 |
+
"""Calculates the named averages based off the raw category scores.
|
15 |
+
|
16 |
+
For each named average, take a simple average of all the category scores associated with that named average.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
average_names (dict[str, list]): Contains a mapping of named averages to which category scores that average should consist of.
|
20 |
+
category_scores (dict[str, float]): Contains the raw scores corresponding to each category.
|
21 |
+
"""
|
22 |
+
average_scores = {}
|
23 |
+
for avg_name, category_list in average_names.items():
|
24 |
+
composite_subset = {category: score for category, score in category_scores.items() if category in category_list}
|
25 |
+
if len(composite_subset.values()) > 0:
|
26 |
+
average_scores[avg_name] = sum(composite_subset.values()) / len(composite_subset.values())
|
27 |
+
else:
|
28 |
+
average_scores[avg_name] = 0
|
29 |
+
return average_scores
|
30 |
+
|
31 |
+
class EvalGauntlet(Callback):
|
32 |
+
"""The EvalGauntlet aggregates ICL eval results.
|
33 |
+
|
34 |
+
After `eval_end`, this callback inspects the logger for different ICL metrics and aggregates the scores according to the aggregation
|
35 |
+
specification provided in the constructor.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
logger_keys (list): These are the exact keys that the individual benchmark metrics will be
|
39 |
+
logged under in the logger after eval
|
40 |
+
categories (dict): This contains the list of categories, as well as the subtasks within them, the
|
41 |
+
random baseline accuracy of each subtask, and the number of fewshot examples
|
42 |
+
used for the task. See `llmfoundry/scripts/eval/yamls/eval_gauntlet_v0.2.yaml` to see the structure.
|
43 |
+
weighting (Weighting): The weighting scheme used to balance different tasks within each category.
|
44 |
+
Either assign them all equal weight, assign them weight proportional
|
45 |
+
to the dataset size, or assign them weight proportional to the log2 of the dataset size.
|
46 |
+
Options are 'EQUAL', 'SAMPLE_SZ', and 'LOG_SAMPLE_SZ'.
|
47 |
+
subtract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy
|
48 |
+
from the performance on each individual benchmark before aggregating.
|
49 |
+
rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark
|
50 |
+
by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0.
|
51 |
+
benchmark_sizes (Optional[dict]): Optional data on benchmark sizes, used when not relying on equal weighting.
|
52 |
+
averages (Optional[dict]): Optional dictionary specifying a mapping from a average names to lists of categories used produce each named average.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, logger_keys: list, categories: dict, weighting: str='EQUAL', subtract_random_baseline: bool=True, rescale_accuracy: bool=True, benchmark_sizes: Optional[dict]=None, averages: Optional[dict]=None):
|
56 |
+
if isinstance(logger_keys, dict):
|
57 |
+
raise ValueError('logger_keys now requires a list type as input, not a dict')
|
58 |
+
if weighting != Weighting.EQUAL and benchmark_sizes is None:
|
59 |
+
raise Exception('When not using equal weighting, you must provide the benchmark sizes.')
|
60 |
+
if rescale_accuracy and (not subtract_random_baseline):
|
61 |
+
raise Exception('Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.')
|
62 |
+
self.categories = categories
|
63 |
+
self.category_names = [conf.get('name') for conf in self.categories]
|
64 |
+
self.weighting = Weighting[weighting]
|
65 |
+
self.subtract_random_baseline = subtract_random_baseline
|
66 |
+
self.rescale_accuracy = rescale_accuracy
|
67 |
+
self.logger_keys = logger_keys
|
68 |
+
for category in self.categories:
|
69 |
+
for benchmark in category['benchmarks']:
|
70 |
+
bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
|
71 |
+
if self.weighting != Weighting.EQUAL:
|
72 |
+
assert benchmark_sizes is not None
|
73 |
+
cumulative_samples = max(sum((count for name, count in benchmark_sizes.items() if name.startswith(bench_name))), 1)
|
74 |
+
else:
|
75 |
+
cumulative_samples = -1
|
76 |
+
weight = None
|
77 |
+
if self.weighting == Weighting.EQUAL:
|
78 |
+
weight = 1
|
79 |
+
elif self.weighting == Weighting.SAMPLE_SZ:
|
80 |
+
weight = cumulative_samples
|
81 |
+
elif self.weighting == Weighting.LOG_SAMPLE_SZ:
|
82 |
+
weight = max(math.log2(cumulative_samples), 1)
|
83 |
+
assert weight is not None
|
84 |
+
benchmark['weighting'] = weight
|
85 |
+
self.averages = {}
|
86 |
+
if averages is not None:
|
87 |
+
self.averages = averages
|
88 |
+
else:
|
89 |
+
self.averages['default_average'] = self.category_names
|
90 |
+
for avg_name in self.averages:
|
91 |
+
if avg_name in self.category_names:
|
92 |
+
raise ValueError(f'Found average name `{avg_name}` used as category name. Average names and category names must be non-overlapping.')
|
93 |
+
|
94 |
+
def extract_metrics_from_state(self, state: State) -> Dict[str, float]:
|
95 |
+
results = {}
|
96 |
+
for key in self.logger_keys:
|
97 |
+
dl_name, metric_name = (key.split('/')[1:-1], key.split('/')[-1])
|
98 |
+
if 'Accuracy' not in metric_name:
|
99 |
+
continue
|
100 |
+
metric = state.eval_metrics.get('/'.join(dl_name), {}).get(metric_name, None)
|
101 |
+
if metric is None:
|
102 |
+
continue
|
103 |
+
val = metric.compute().item()
|
104 |
+
key = '/'.join(dl_name[0:2])
|
105 |
+
if key not in results:
|
106 |
+
results[key] = []
|
107 |
+
results[key].append(val)
|
108 |
+
return {k: sum(v) / len(v) for k, v in results.items()}
|
109 |
+
|
110 |
+
def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
|
111 |
+
computed_metrics = self.extract_metrics_from_state(state)
|
112 |
+
if len(computed_metrics) == 0:
|
113 |
+
return {}
|
114 |
+
category_scores = {}
|
115 |
+
for category in self.categories:
|
116 |
+
missing_metrics = []
|
117 |
+
category_scores[category['name']] = []
|
118 |
+
for benchmark in category['benchmarks']:
|
119 |
+
key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
|
120 |
+
if key not in computed_metrics:
|
121 |
+
log.warning(f'Could not find results for benchmark: {benchmark}.')
|
122 |
+
missing_metrics.append(key)
|
123 |
+
else:
|
124 |
+
score = computed_metrics[key]
|
125 |
+
if self.subtract_random_baseline:
|
126 |
+
score -= benchmark['random_baseline']
|
127 |
+
if self.rescale_accuracy and self.subtract_random_baseline:
|
128 |
+
score /= 1.0 - benchmark['random_baseline']
|
129 |
+
category_scores[category['name']].append({'name': benchmark['name'], 'score': score, 'weighting': benchmark['weighting']})
|
130 |
+
if len(missing_metrics) > 0:
|
131 |
+
log.warning(f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}")
|
132 |
+
del category_scores[category['name']]
|
133 |
+
continue
|
134 |
+
total_weight = sum((k['weighting'] for k in category_scores[category['name']]))
|
135 |
+
category_scores[category['name']] = sum((k['score'] * (k['weighting'] / total_weight) for k in category_scores[category['name']]))
|
136 |
+
named_averages = calculate_named_averages(self.averages, category_scores)
|
137 |
+
category_scores.update(named_averages)
|
138 |
+
category_scores = {f'icl/metrics/eval_gauntlet/{k}': v for k, v in category_scores.items()}
|
139 |
+
if logger is not None:
|
140 |
+
logger.log_metrics(category_scores)
|
141 |
+
return category_scores
|
exceptions.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Custom exceptions for the LLMFoundry."""
|
2 |
+
from collections.abc import Mapping
|
3 |
+
from typing import Any, Dict, List
|
4 |
+
|
5 |
+
class MissingHuggingFaceURLSplitError(ValueError):
|
6 |
+
"""Error thrown when there's no split used in HF dataset config."""
|
7 |
+
|
8 |
+
def __init__(self) -> None:
|
9 |
+
message = 'When using a HuggingFace dataset from a URL, you must set the ' + '`split` key in the dataset config.'
|
10 |
+
super().__init__(message)
|
11 |
+
|
12 |
+
class NotEnoughDatasetSamplesError(ValueError):
|
13 |
+
"""Error thrown when there is not enough data to train a model."""
|
14 |
+
|
15 |
+
def __init__(self, dataset_name: str, split: str, dataloader_batch_size: int, world_size: int, full_dataset_size: int, minimum_dataset_size: int) -> None:
|
16 |
+
self.dataset_name = dataset_name
|
17 |
+
self.split = split
|
18 |
+
self.dataloader_batch_size = dataloader_batch_size
|
19 |
+
self.world_size = world_size
|
20 |
+
self.full_dataset_size = full_dataset_size
|
21 |
+
self.minimum_dataset_size = minimum_dataset_size
|
22 |
+
message = f'Your dataset (name={dataset_name}, split={split}) ' + f'has {full_dataset_size} samples, but your minimum batch size ' + f'is {minimum_dataset_size} because you are running on {world_size} gpus and ' + f'your per device batch size is {dataloader_batch_size}. Please increase the number ' + f'of samples in your dataset to at least {minimum_dataset_size}.'
|
23 |
+
super().__init__(message)
|
24 |
+
|
25 |
+
class UnknownExampleTypeError(KeyError):
|
26 |
+
"""Error thrown when an unknown example type is used in a task."""
|
27 |
+
|
28 |
+
def __init__(self, example: Mapping) -> None:
|
29 |
+
message = f'Unknown example type example={example!r}'
|
30 |
+
super().__init__(message)
|
31 |
+
|
32 |
+
class TooManyKeysInExampleError(ValueError):
|
33 |
+
"""Error thrown when a data sample has too many keys."""
|
34 |
+
|
35 |
+
def __init__(self, desired_keys: set[str], keys: set[str]) -> None:
|
36 |
+
message = f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} Please specify exactly one. Provided keys: {keys}'
|
37 |
+
super().__init__(message)
|
38 |
+
|
39 |
+
class NotEnoughChatDataError(ValueError):
|
40 |
+
"""Error thrown when there is not enough chat data to train a model."""
|
41 |
+
|
42 |
+
def __init__(self) -> None:
|
43 |
+
message = 'Chat example must have at least two messages'
|
44 |
+
super().__init__(message)
|
45 |
+
|
46 |
+
class ConsecutiveRepeatedChatRolesError(ValueError):
|
47 |
+
"""Error thrown when there are consecutive repeated chat roles."""
|
48 |
+
|
49 |
+
def __init__(self, repeated_role: str) -> None:
|
50 |
+
self.repeated_role = repeated_role
|
51 |
+
message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.'
|
52 |
+
super().__init__(message)
|
53 |
+
|
54 |
+
class InvalidLastChatMessageRoleError(ValueError):
|
55 |
+
"""Error thrown when the last message role in a chat example is invalid."""
|
56 |
+
|
57 |
+
def __init__(self, last_role: str, expected_roles: set[str]) -> None:
|
58 |
+
message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}'
|
59 |
+
super().__init__(message)
|
60 |
+
|
61 |
+
class IncorrectMessageKeyQuantityError(ValueError):
|
62 |
+
"""Error thrown when a message has an incorrect number of keys."""
|
63 |
+
|
64 |
+
def __init__(self, keys: List[str]) -> None:
|
65 |
+
self.keys = keys
|
66 |
+
message = f'Expected 2 keys in message, but found {len(keys)}'
|
67 |
+
super().__init__(message)
|
68 |
+
|
69 |
+
class InvalidRoleError(ValueError):
|
70 |
+
"""Error thrown when a role is invalid."""
|
71 |
+
|
72 |
+
def __init__(self, role: str, valid_roles: set[str]) -> None:
|
73 |
+
self.role = role
|
74 |
+
self.valid_roles = valid_roles
|
75 |
+
message = f'Expected role to be one of {valid_roles} but found: {role}'
|
76 |
+
super().__init__(message)
|
77 |
+
|
78 |
+
class InvalidContentTypeError(TypeError):
|
79 |
+
"""Error thrown when the content type is invalid."""
|
80 |
+
|
81 |
+
def __init__(self, content_type: type) -> None:
|
82 |
+
self.content_type = content_type
|
83 |
+
message = f'Expected content to be a string, but found {content_type}'
|
84 |
+
super().__init__(message)
|
85 |
+
|
86 |
+
class InvalidPromptTypeError(TypeError):
|
87 |
+
"""Error thrown when the prompt type is invalid."""
|
88 |
+
|
89 |
+
def __init__(self, prompt_type: type) -> None:
|
90 |
+
self.prompt_type = prompt_type
|
91 |
+
message = f'Expected prompt to be a string, but found {prompt_type}'
|
92 |
+
super().__init__(message)
|
93 |
+
|
94 |
+
class InvalidResponseTypeError(TypeError):
|
95 |
+
"""Error thrown when the response type is invalid."""
|
96 |
+
|
97 |
+
def __init__(self, response_type: type) -> None:
|
98 |
+
self.response_type = response_type
|
99 |
+
message = f'Expected response to be a string, but found {response_type}'
|
100 |
+
super().__init__(message)
|
101 |
+
|
102 |
+
class InvalidPromptResponseKeysError(ValueError):
|
103 |
+
"""Error thrown when missing expected prompt and response keys."""
|
104 |
+
|
105 |
+
def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
|
106 |
+
self.example = example
|
107 |
+
message = f'Expected mapping={mapping!r} to have keys "prompt" and "response".'
|
108 |
+
super().__init__(message)
|
109 |
+
|
110 |
+
class InvalidFileExtensionError(FileNotFoundError):
|
111 |
+
"""Error thrown when a file extension is not a safe extension."""
|
112 |
+
|
113 |
+
def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
|
114 |
+
self.dataset_name = dataset_name
|
115 |
+
self.valid_extensions = valid_extensions
|
116 |
+
message = f'safe_load is set to True. No data files with safe extensions {valid_extensions} ' + f'found for dataset at local path {dataset_name}.'
|
117 |
+
super().__init__(message)
|
118 |
+
|
119 |
+
class UnableToProcessPromptResponseError(ValueError):
|
120 |
+
"""Error thrown when a prompt and response cannot be processed."""
|
121 |
+
|
122 |
+
def __init__(self, input: Dict) -> None:
|
123 |
+
message = f'Unable to extract prompt/response from {input}'
|
124 |
+
super().__init__(message)
|
125 |
+
|
126 |
+
class ClusterDoesNotExistError(ValueError):
|
127 |
+
"""Error thrown when the cluster does not exist."""
|
128 |
+
|
129 |
+
def __init__(self, cluster_id: str) -> None:
|
130 |
+
self.cluster_id = cluster_id
|
131 |
+
message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!'
|
132 |
+
super().__init__(message)
|
133 |
+
|
134 |
+
class FailedToCreateSQLConnectionError(RuntimeError):
|
135 |
+
"""Error thrown when client can't sql connect to Databricks."""
|
136 |
+
|
137 |
+
def __init__(self) -> None:
|
138 |
+
message = 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
|
139 |
+
super().__init__(message)
|
140 |
+
|
141 |
+
class FailedToConnectToDatabricksError(RuntimeError):
|
142 |
+
"""Error thrown when the client fails to connect to Databricks."""
|
143 |
+
|
144 |
+
def __init__(self) -> None:
|
145 |
+
message = 'Failed to create databricks connection. Check hostname and access token!'
|
146 |
+
super().__init__(message)
|
147 |
+
|
148 |
+
class InputFolderMissingDataError(ValueError):
|
149 |
+
"""Error thrown when the input folder is missing data."""
|
150 |
+
|
151 |
+
def __init__(self, input_folder: str) -> None:
|
152 |
+
self.input_folder = input_folder
|
153 |
+
message = f'No text files were found at {input_folder}.'
|
154 |
+
super().__init__(message)
|
155 |
+
|
156 |
+
class OutputFolderNotEmptyError(FileExistsError):
|
157 |
+
"""Error thrown when the output folder is not empty."""
|
158 |
+
|
159 |
+
def __init__(self, output_folder: str) -> None:
|
160 |
+
self.output_folder = output_folder
|
161 |
+
message = f'{output_folder} is not empty. Please remove or empty it and retry.'
|
162 |
+
super().__init__(message)
|
fdiff_callback.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Monitor rate of change of loss."""
|
2 |
+
from __future__ import annotations
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class FDiffMetrics(Callback):
|
6 |
+
"""Rate of change of metrics.
|
7 |
+
|
8 |
+
tracks and plots the rate of change of metrics effectively taking the
|
9 |
+
numerical derivative of the metrics
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, diff_train_metrics: bool=False, diff_eval_metrics: bool=True):
|
13 |
+
self.diff_train_metrics = diff_train_metrics
|
14 |
+
self.diff_eval_metrics = diff_eval_metrics
|
15 |
+
self.train_prev_loss = None
|
16 |
+
self.train_prev_metric = {}
|
17 |
+
self.eval_prev_metric = {}
|
18 |
+
|
19 |
+
def batch_end(self, state: State, logger: Logger) -> None:
|
20 |
+
if self.diff_train_metrics:
|
21 |
+
if not isinstance(state.loss, torch.Tensor):
|
22 |
+
raise NotImplementedError('Multiple losses not supported yet')
|
23 |
+
loss = state.loss.item()
|
24 |
+
if self.train_prev_loss:
|
25 |
+
logger.log_metrics({'loss/train/total_fdiff': loss - self.train_prev_loss})
|
26 |
+
self.train_prev_loss = loss
|
27 |
+
for k in self.train_prev_metric.keys():
|
28 |
+
logger.log_metrics({f'metrics/train/{k}_fdiff': state.train_metric_values[k] - self.train_prev_metric[k]})
|
29 |
+
for k in state.train_metric_values.keys():
|
30 |
+
value = state.train_metric_values[k]
|
31 |
+
self.train_prev_metric[k] = value
|
32 |
+
|
33 |
+
def eval_end(self, state: State, logger: Logger) -> None:
|
34 |
+
if self.diff_eval_metrics:
|
35 |
+
evaluator = state.dataloader_label
|
36 |
+
assert evaluator is not None, 'dataloader should have been set'
|
37 |
+
metrics = list(state.eval_metrics[evaluator].keys())
|
38 |
+
for k in metrics:
|
39 |
+
mkey = '/'.join(['metrics', evaluator, k])
|
40 |
+
if mkey in self.eval_prev_metric.keys():
|
41 |
+
logger.log_metrics({f'{mkey}_fdiff': state.eval_metric_values[k] - self.eval_prev_metric[mkey]})
|
42 |
+
for k in metrics:
|
43 |
+
mkey = '/'.join(['metrics', evaluator, k])
|
44 |
+
self.eval_prev_metric[mkey] = state.eval_metric_values[k]
|
ffn.py
CHANGED
@@ -59,8 +59,7 @@ class MPTMLP(nn.Module):
|
|
59 |
super().__init__()
|
60 |
ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
|
61 |
self.fc_kwargs: dict[str, Any] = {'bias': bias}
|
62 |
-
|
63 |
-
self.fc_kwargs['device'] = device
|
64 |
self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, ffn_hidden_size, **self.fc_kwargs)
|
65 |
self.act = act_fn
|
66 |
self.down_proj = FC_CLASS_REGISTRY[fc_type](ffn_hidden_size, d_model, **self.fc_kwargs)
|
@@ -75,6 +74,7 @@ class MPTGLU(MPTMLP):
|
|
75 |
super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
|
76 |
self.gate_proj = FC_CLASS_REGISTRY[fc_type](d_model, self.up_proj.out_features, **self.fc_kwargs)
|
77 |
|
|
|
78 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
79 |
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
|
80 |
FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP, 'mptglu': MPTGLU}
|
|
|
59 |
super().__init__()
|
60 |
ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
|
61 |
self.fc_kwargs: dict[str, Any] = {'bias': bias}
|
62 |
+
self.fc_kwargs['device'] = device
|
|
|
63 |
self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, ffn_hidden_size, **self.fc_kwargs)
|
64 |
self.act = act_fn
|
65 |
self.down_proj = FC_CLASS_REGISTRY[fc_type](ffn_hidden_size, d_model, **self.fc_kwargs)
|
|
|
74 |
super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
|
75 |
self.gate_proj = FC_CLASS_REGISTRY[fc_type](d_model, self.up_proj.out_features, **self.fc_kwargs)
|
76 |
|
77 |
+
@torch.compile
|
78 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
79 |
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
|
80 |
FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP, 'mptglu': MPTGLU}
|
finetuning.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .collator import Seq2SeqFinetuningCollator
|
2 |
+
from .dataloader import build_finetuning_dataloader
|
hf.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .hf_causal_lm import ComposerHFCausalLM
|
2 |
+
from .hf_fsdp import prepare_hf_causal_lm_model_for_fsdp, prepare_hf_enc_dec_model_for_fsdp, prepare_hf_model_for_fsdp
|
3 |
+
from .hf_t5 import ComposerHFT5
|
hf_causal_lm.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import warnings
|
5 |
+
from typing import TYPE_CHECKING, Any, Dict, Mapping
|
6 |
+
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase
|
7 |
+
from .hf_fsdp import hf_get_init_device
|
8 |
+
from .model_wrapper import HuggingFaceModelWithFSDP
|
9 |
+
from .attention import is_flash_v2_installed
|
10 |
+
from .utils import init_empty_weights
|
11 |
+
from .config_utils import pop_config
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from peft import PeftConfig
|
14 |
+
log = logging.getLogger(__name__)
|
hf_checkpointer.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import tempfile
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Any, Dict, Optional, Sequence, Union
|
10 |
+
import torch
|
11 |
+
from mlflow.transformers import _fetch_model_card, _write_license_information
|
12 |
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
13 |
+
from .mpt import MPTConfig, MPTForCausalLM
|
14 |
+
from .utils import init_empty_weights
|
15 |
+
from .huggingface_hub_utils import edit_files_for_hf_compatibility
|
16 |
+
log = logging.getLogger(__name__)
|
17 |
+
_LICENSE_FILE_PATTERN = re.compile('license(\\.[a-z]+|$)', re.IGNORECASE)
|
18 |
+
|
19 |
+
def _maybe_get_license_filename(local_dir: str, pretrained_model_name: Optional[str]=None) -> Optional[str]:
|
20 |
+
"""Returns the name of the license file if it exists in the local_dir.
|
21 |
+
|
22 |
+
Note: This is intended to be consistent with the code in MLflow.
|
23 |
+
https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
|
24 |
+
|
25 |
+
Since LLM Foundry supports local model files being used rather than fetching the files from the Hugging Face Hub,
|
26 |
+
MLflow's logic to fetch and write the license information on model save is not applicable; it will try to search for
|
27 |
+
a Hugging Face repo named after the local path. However, the user can provide the original pretrained model name,
|
28 |
+
in which case this function will use that to fetch the correct license information.
|
29 |
+
|
30 |
+
If the license file does not exist, returns None.
|
31 |
+
"""
|
32 |
+
try:
|
33 |
+
license_filename = next((file for file in os.listdir(local_dir) if _LICENSE_FILE_PATTERN.search(file)))
|
34 |
+
if pretrained_model_name is not None:
|
35 |
+
log.info(f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub')
|
36 |
+
os.remove(os.path.join(local_dir, license_filename))
|
37 |
+
model_card = _fetch_model_card(pretrained_model_name)
|
38 |
+
local_dir_path = Path(local_dir).absolute()
|
39 |
+
_write_license_information(pretrained_model_name, model_card, local_dir_path)
|
40 |
+
license_filename = next((file for file in os.listdir(local_dir) if _LICENSE_FILE_PATTERN.search(file)))
|
41 |
+
return license_filename
|
42 |
+
except StopIteration:
|
43 |
+
return None
|
44 |
+
|
45 |
+
class HuggingFaceCheckpointer(Callback):
|
46 |
+
"""Save a huggingface formatted checkpoint during training.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
save_folder (str): Top level folder to save checkpoints to (can be a
|
50 |
+
URI). It is likely that this would be the same as your save_folder.
|
51 |
+
save_interval: Union[str, int, Time]: The interval describing how often
|
52 |
+
checkpoints should be saved. If an integer, it will be assumed to be
|
53 |
+
in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either
|
54 |
+
:attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
|
55 |
+
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
|
56 |
+
huggingface_folder_name (str): Folder to save each checkpoint under (can
|
57 |
+
be a format string). Default is ``ba{batch}``.
|
58 |
+
precision: The precision to save the model in. Default is ``float32``.
|
59 |
+
Options are ``bfloat16``, ``float16``, or ``float32``.
|
60 |
+
overwrite (bool): Whether to overwrite previous checkpoints.
|
61 |
+
mlflow_registered_model_name (Optional[str]): The name to register the
|
62 |
+
model under in the MLflow model registry. If ``None``, the model
|
63 |
+
will not be registered. Default is ``None``.
|
64 |
+
mlflow_logging_config (Optional[dict]): A dictionary of config arguments
|
65 |
+
that will get passed along to the MLflow ``save_model`` call.
|
66 |
+
Expected to contain ``metadata`` and ``task`` keys. If either is
|
67 |
+
unspecified, the defaults are ``'text-generation'`` and
|
68 |
+
``{'task': 'llm/v1/completions'}`` respectively. A default input example
|
69 |
+
and signature intended for text generation is also included under the
|
70 |
+
keys ``input_example`` and ``signature``.
|
71 |
+
flatten_imports (Sequence[str]): A sequence of import prefixes that will
|
72 |
+
be flattened when editing MPT files.
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, save_folder: str, save_interval: Union[str, int, Time], huggingface_folder_name: str='ba{batch}', precision: str='float32', overwrite: bool=True, mlflow_registered_model_name: Optional[str]=None, mlflow_logging_config: Optional[dict]=None, flatten_imports: Sequence[str]=('llmfoundry',)):
|
76 |
+
_, _, self.save_dir_format_str = parse_uri(save_folder)
|
77 |
+
self.overwrite = overwrite
|
78 |
+
self.precision = precision
|
79 |
+
self.dtype = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}[precision]
|
80 |
+
self.flatten_imports = flatten_imports
|
81 |
+
self.mlflow_registered_model_name = mlflow_registered_model_name
|
82 |
+
if mlflow_logging_config is None:
|
83 |
+
mlflow_logging_config = {}
|
84 |
+
if self.mlflow_registered_model_name is not None:
|
85 |
+
import numpy as np
|
86 |
+
passed_metadata = mlflow_logging_config.get('metadata', {})
|
87 |
+
mlflow_logging_config['metadata'] = passed_metadata
|
88 |
+
mlflow_logging_config.setdefault('task', 'llm/v1/completions')
|
89 |
+
default_input_example = {'prompt': np.array(['What is Machine Learning?'])}
|
90 |
+
is_chat = mlflow_logging_config['task'].endswith('chat') or mlflow_logging_config['metadata'].get('task', '').endswith('chat')
|
91 |
+
if is_chat:
|
92 |
+
default_input_example = {'messages': np.array([{'role': 'user', 'content': 'What is Machine Learning?'}])}
|
93 |
+
mlflow_logging_config.setdefault('example_no_conversion', True)
|
94 |
+
mlflow_logging_config.setdefault('input_example', default_input_example)
|
95 |
+
self.mlflow_logging_config = mlflow_logging_config
|
96 |
+
self.huggingface_folder_name_fstr = os.path.join('huggingface', huggingface_folder_name)
|
97 |
+
self.save_interval: Time = Time.from_input(save_interval, TimeUnit.EPOCH)
|
98 |
+
self.check_interval = create_interval_scheduler(self.save_interval, include_end_of_training=True)
|
99 |
+
self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers=[])
|
100 |
+
if self.remote_ud is not None:
|
101 |
+
self.remote_ud._num_concurrent_uploads = 4
|
102 |
+
self.last_checkpoint_batch: Optional[Time] = None
|
103 |
+
self.mlflow_loggers = []
|
104 |
+
|
105 |
+
def run_event(self, event: Event, state: State, logger: Logger) -> None:
|
106 |
+
if state.get_elapsed_duration() is not None and self.check_interval(state, event) and (self.last_checkpoint_batch != state.timestamp.batch):
|
107 |
+
self._save_checkpoint(state, logger)
|
108 |
+
elif event == Event.INIT:
|
109 |
+
if not isinstance(state.model, HuggingFaceModel):
|
110 |
+
raise ValueError(f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. ' + f'Got {type(state.model)} instead.')
|
111 |
+
if self.remote_ud is not None:
|
112 |
+
self.remote_ud.init(state, logger)
|
113 |
+
state.callbacks.append(self.remote_ud)
|
114 |
+
if self.mlflow_registered_model_name is not None:
|
115 |
+
self.mlflow_loggers = [logger_destination for logger_destination in logger.destinations if isinstance(logger_destination, MLFlowLogger)]
|
116 |
+
if len(self.mlflow_loggers) == 0:
|
117 |
+
raise ValueError(f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. ' + 'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.')
|
118 |
+
import mlflow
|
119 |
+
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set('5GB')
|
120 |
+
|
121 |
+
def _is_last_batch(self, state: State):
|
122 |
+
elapsed_duration = state.get_elapsed_duration()
|
123 |
+
if elapsed_duration is not None and elapsed_duration >= 1.0:
|
124 |
+
return True
|
125 |
+
assert state.max_duration is not None
|
126 |
+
if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and (state.max_duration.unit == TimeUnit.EPOCH):
|
127 |
+
assert state.dataloader_len is not None
|
128 |
+
return int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0
|
129 |
+
return False
|
130 |
+
|
131 |
+
def _save_checkpoint(self, state: State, logger: Logger):
|
132 |
+
del logger
|
133 |
+
self.last_checkpoint_batch = state.timestamp.batch
|
134 |
+
log.info('Saving HuggingFace formatted checkpoint')
|
135 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
136 |
+
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
|
137 |
+
MPTConfig.register_for_auto_class()
|
138 |
+
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
|
139 |
+
save_dir = format_name_with_dist_and_time(str(Path(self.save_dir_format_str) / self.huggingface_folder_name_fstr), state.run_name, state.timestamp)
|
140 |
+
dir_context_mgr = tempfile.TemporaryDirectory() if self.remote_ud is not None else contextlib.nullcontext(enter_result=save_dir)
|
141 |
+
with dir_context_mgr as temp_save_dir:
|
142 |
+
assert isinstance(temp_save_dir, str)
|
143 |
+
log.debug('Gathering state dict')
|
144 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
145 |
+
if state.is_model_ddp:
|
146 |
+
composer_model = state.model.module
|
147 |
+
original_model: PreTrainedModel = state.model.module.model
|
148 |
+
state_dict_model = state.model.module.model
|
149 |
+
original_tokenizer = state.model.module.tokenizer
|
150 |
+
elif isinstance(state.model.model, FSDP):
|
151 |
+
composer_model = state.model
|
152 |
+
original_model: PreTrainedModel = state.model.model.module
|
153 |
+
state_dict_model = state.model.model
|
154 |
+
original_tokenizer = state.model.tokenizer
|
155 |
+
else:
|
156 |
+
composer_model = state.model
|
157 |
+
original_model: PreTrainedModel = state.model.model
|
158 |
+
state_dict_model = state.model.model
|
159 |
+
original_tokenizer = state.model.tokenizer
|
160 |
+
state_dict_context = fsdp_state_dict_type_context(original_model, state_dict_type='full') if not state.is_model_ddp and isinstance(state_dict_model, FSDP) else contextlib.nullcontext()
|
161 |
+
with state_dict_context:
|
162 |
+
state_dict = state_dict_model.state_dict()
|
163 |
+
for k, v in state_dict.items():
|
164 |
+
if isinstance(v, torch.Tensor):
|
165 |
+
state_dict[k] = v.to(dtype=self.dtype)
|
166 |
+
if dist.get_global_rank() == 0:
|
167 |
+
log.debug('Saving Hugging Face checkpoint in global rank 0')
|
168 |
+
copied_config = copy.deepcopy(original_model.config)
|
169 |
+
if copied_config.model_type == 'mpt':
|
170 |
+
copied_config.attn_config['attn_impl'] = 'torch'
|
171 |
+
copied_config.init_device = 'cpu'
|
172 |
+
log.debug(f'Creating new model instance')
|
173 |
+
if composer_model.using_peft:
|
174 |
+
active_adapter = original_model.active_adapter
|
175 |
+
base_model = original_model.get_base_model()
|
176 |
+
new_base_model_instance = type(base_model)(copied_config)
|
177 |
+
new_model_instance = type(original_model)(new_base_model_instance, original_model.peft_config[active_adapter])
|
178 |
+
new_model_instance.to(dtype=self.dtype)
|
179 |
+
else:
|
180 |
+
with init_empty_weights():
|
181 |
+
new_model_instance = type(original_model)(copied_config)
|
182 |
+
new_model_instance.load_state_dict(state_dict, assign=True)
|
183 |
+
del state_dict
|
184 |
+
log.debug('Saving Hugging Face checkpoint to disk')
|
185 |
+
new_model_instance.save_pretrained(temp_save_dir)
|
186 |
+
if original_tokenizer is not None:
|
187 |
+
assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
|
188 |
+
original_tokenizer.save_pretrained(temp_save_dir)
|
189 |
+
if original_model.config.model_type == 'mpt':
|
190 |
+
log.debug('Editing MPT files for HuggingFace compatibility')
|
191 |
+
edit_files_for_hf_compatibility(temp_save_dir, self.flatten_imports)
|
192 |
+
if self.remote_ud is not None:
|
193 |
+
for filename in os.listdir(temp_save_dir):
|
194 |
+
remote_file_name = os.path.join(save_dir, filename)
|
195 |
+
remote_file_uri = self.remote_ud.remote_backend.get_uri(remote_file_name)
|
196 |
+
log.info(f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}')
|
197 |
+
self.remote_ud.upload_file(state=state, remote_file_name=remote_file_name, file_path=Path(os.path.join(temp_save_dir, filename)), overwrite=self.overwrite)
|
198 |
+
if self.mlflow_registered_model_name and self._is_last_batch(state):
|
199 |
+
components = {'model': new_model_instance}
|
200 |
+
if original_tokenizer is not None:
|
201 |
+
components['tokenizer'] = original_tokenizer
|
202 |
+
log.debug('Logging Hugging Face model to MLFlow')
|
203 |
+
for i, mlflow_logger in enumerate(self.mlflow_loggers):
|
204 |
+
log.debug(f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}')
|
205 |
+
local_save_path = str(Path(temp_save_dir) / f'mlflow_save_{i}')
|
206 |
+
import mlflow
|
207 |
+
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
|
208 |
+
model_saving_kwargs: Dict[str, Any] = {'path': local_save_path}
|
209 |
+
if composer_model.using_peft:
|
210 |
+
model_saving_kwargs['flavor'] = 'peft'
|
211 |
+
model_saving_kwargs['save_pretrained_dir'] = temp_save_dir
|
212 |
+
model_saving_kwargs['metadata'] = self.mlflow_logging_config['metadata']
|
213 |
+
else:
|
214 |
+
model_saving_kwargs['flavor'] = 'transformers'
|
215 |
+
model_saving_kwargs['transformers_model'] = components
|
216 |
+
model_saving_kwargs.update(self.mlflow_logging_config)
|
217 |
+
mlflow_logger.save_model(**model_saving_kwargs)
|
218 |
+
license_filename = _maybe_get_license_filename(local_save_path, self.mlflow_logging_config['metadata'].get('pretrained_model_name', None))
|
219 |
+
if license_filename is not None:
|
220 |
+
mlflow_logger._mlflow_client.log_artifact(mlflow_logger._run_id, os.path.join(local_save_path, license_filename))
|
221 |
+
mlflow_logger.register_model_with_run_id(model_uri=local_save_path, name=self.mlflow_registered_model_name, await_creation_for=3600)
|
hf_fsdp.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
|
3 |
+
from transformers import PreTrainedModel
|
4 |
+
from transformers.models.opt.modeling_opt import OPTDecoder
|
5 |
+
if TYPE_CHECKING:
|
6 |
+
from peft import PeftModel
|
7 |
+
|
8 |
+
def rhasattr(obj: Any, attr: str) -> bool:
|
9 |
+
"""A chain-able attribute version of hasattr.
|
10 |
+
|
11 |
+
For example, to check if
|
12 |
+
`obj` has the attribute `foo.bar.baz`, you can use:
|
13 |
+
`rhasattr(obj, "foo.bar.baz")`
|
14 |
+
Reference: https://stackoverflow.com/a/67303315
|
15 |
+
"""
|
16 |
+
_nested_attrs = attr.split('.')
|
17 |
+
_curr_obj = obj
|
18 |
+
for _a in _nested_attrs[:-1]:
|
19 |
+
if hasattr(_curr_obj, _a):
|
20 |
+
_curr_obj = getattr(_curr_obj, _a)
|
21 |
+
else:
|
22 |
+
return False
|
23 |
+
return hasattr(_curr_obj, _nested_attrs[-1])
|
24 |
+
|
25 |
+
def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any:
|
26 |
+
"""A chain-able attribute version of getattr.
|
27 |
+
|
28 |
+
For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
|
29 |
+
`rgetattr(obj, "foo.bar.baz")`
|
30 |
+
Reference: https://stackoverflow.com/a/31174427
|
31 |
+
"""
|
32 |
+
|
33 |
+
def _getattr(obj: Any, attr: str):
|
34 |
+
return getattr(obj, attr, *args)
|
35 |
+
return functools.reduce(_getattr, [obj] + attr.split('.'))
|
36 |
+
|
37 |
+
def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
|
38 |
+
for attr in attrs:
|
39 |
+
if rhasattr(obj, attr):
|
40 |
+
return rgetattr(obj, attr)
|
41 |
+
return None
|
42 |
+
|
43 |
+
def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
|
44 |
+
"""Returns the causal decoder backbone of the specified HuggingFace model.
|
45 |
+
|
46 |
+
Newer HF models have a `self.get_decoder()` method. Older models do not.
|
47 |
+
|
48 |
+
NOTE: Different model configurations have different causal decoder attribute
|
49 |
+
names.
|
50 |
+
- transformer: (GPT2LMHeadModel, GPTJConfig)
|
51 |
+
- model.decoder: (OPTConfig, BloomConfig)
|
52 |
+
- gpt_neox: (GPTNeoXConfig)
|
53 |
+
"""
|
54 |
+
if hasattr(model, 'get_decoder'):
|
55 |
+
return model.get_decoder()
|
56 |
+
decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox', 'model.transformer')
|
57 |
+
causal_base_model = findattr(model, decoder_attrs)
|
58 |
+
if causal_base_model is None:
|
59 |
+
raise ValueError(f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.')
|
60 |
+
return causal_base_model
|
61 |
+
|
62 |
+
def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
|
63 |
+
"""Returns the hidden layers of the specified model.
|
64 |
+
|
65 |
+
Expects to receive the causal decoder backbone, not he XXForCausalLM wrapper.
|
66 |
+
|
67 |
+
NOTE: Different model configurations have different hidden layer attribute names.
|
68 |
+
- h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
|
69 |
+
- decoder.layers: (OPTForCausalLM)
|
70 |
+
- layers: (GPTNeoXForCausalLM, LlaMaForCausalLM)
|
71 |
+
- blocks: (MPTForCausalLM)
|
72 |
+
"""
|
73 |
+
hidden_layers_attrs = ('h', 'decoder.layers', 'layers', 'block', 'blocks')
|
74 |
+
layers = findattr(model, hidden_layers_attrs)
|
75 |
+
if layers is None:
|
76 |
+
raise ValueError(f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}')
|
77 |
+
return layers
|
78 |
+
|
79 |
+
def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
|
80 |
+
"""Returns the appropriate device to initialize models."""
|
81 |
+
if init_device == 'mixed':
|
82 |
+
if dist.get_local_rank() == 0:
|
83 |
+
return 'cpu'
|
84 |
+
return 'meta'
|
85 |
+
return init_device
|
86 |
+
|
87 |
+
def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
|
88 |
+
"""FSDP wrap a HuggingFace model.
|
89 |
+
|
90 |
+
Call specific functions
|
91 |
+
"""
|
92 |
+
if model.config.is_encoder_decoder:
|
93 |
+
prepare_hf_enc_dec_model_for_fsdp(model, init_device)
|
94 |
+
else:
|
95 |
+
prepare_hf_causal_lm_model_for_fsdp(model, init_device)
|
96 |
+
|
97 |
+
def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, 'PeftModel'], init_device: Optional[str]) -> None:
|
98 |
+
"""FSDP wrap a HuggingFace decoder.
|
99 |
+
|
100 |
+
Wrap any model for FSDP which follows one of the 3 existing conventions from
|
101 |
+
HuggingFace for decoder-only LLMs.
|
102 |
+
"""
|
103 |
+
causal_base_model = hf_get_causal_base_model(model)
|
104 |
+
if isinstance(causal_base_model, OPTDecoder) or model.config.model_type == 'olmo':
|
105 |
+
underlying_model = maybe_get_underlying_model(model)
|
106 |
+
underlying_model.model._fsdp_wrap = False
|
107 |
+
model_block = hf_get_hidden_layers(causal_base_model)
|
108 |
+
lm_head = model.get_output_embeddings()
|
109 |
+
try:
|
110 |
+
tied_embeddings = causal_base_model.get_input_embeddings()
|
111 |
+
except:
|
112 |
+
tied_embeddings = model.get_input_embeddings()
|
113 |
+
modules = {'base_model': causal_base_model, 'model_block': model_block, 'lm_head': lm_head, 'tied_embeddings': tied_embeddings}
|
114 |
+
for mod_name, module in modules.items():
|
115 |
+
if module is None:
|
116 |
+
raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')
|
117 |
+
block_type = type(model_block[0])
|
118 |
+
if model.config.tie_word_embeddings:
|
119 |
+
causal_base_model._fsdp_wrap = False
|
120 |
+
tied_embeddings._fsdp_wrap = False
|
121 |
+
lm_head._fsdp_wrap = False
|
122 |
+
if hasattr(model, 'peft_type') and model.peft_type is not None:
|
123 |
+
peft_type = model.peft_type.lower()
|
124 |
+
active_adapters = [adapter.lower() for adapter in model.active_adapters]
|
125 |
+
for name, module in model.named_modules():
|
126 |
+
if peft_type in name.lower() and any((adapter in name.lower() for adapter in active_adapters)):
|
127 |
+
has_parameters = next(module.parameters(), None) is not None
|
128 |
+
has_buffers = next(module.buffers(), None) is not None
|
129 |
+
if has_parameters or has_buffers:
|
130 |
+
module._fsdp_wrap = True
|
131 |
+
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
|
132 |
+
model.activation_checkpointing_fn = lambda module: isinstance(module, block_type)
|
133 |
+
|
134 |
+
def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
|
135 |
+
"""Wrap an encoder/decoder HF model.
|
136 |
+
|
137 |
+
This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet)
|
138 |
+
You have model.shared, model.encoder, model.decoder and model.lm_head, where
|
139 |
+
model.shared are the embeddings which are tied to model.lm_head, and
|
140 |
+
model.shared == model.encoder.embed_tokens and model.shared ==
|
141 |
+
model.decoder.embed_tokens
|
142 |
+
"""
|
143 |
+
tied_embeddings = model.get_input_embeddings()
|
144 |
+
encoder = model.get_encoder()
|
145 |
+
decoder = model.get_decoder()
|
146 |
+
lm_head = model.get_output_embeddings()
|
147 |
+
encoder_block = hf_get_hidden_layers(encoder)
|
148 |
+
decoder_block = hf_get_hidden_layers(decoder)
|
149 |
+
modules = {'encoder': encoder, 'decoder': decoder, 'encoder_block': encoder_block, 'decoder_block': decoder_block, 'lm_head': lm_head, 'tied_embeddings': tied_embeddings}
|
150 |
+
for mod_name, module in modules.items():
|
151 |
+
if module is None:
|
152 |
+
raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')
|
153 |
+
decoder_block_type = type(decoder_block[0])
|
154 |
+
encoder_block_type = type(encoder_block[0])
|
155 |
+
if model.config.tie_word_embeddings:
|
156 |
+
tied_embeddings._fsdp_wrap = False
|
157 |
+
encoder._fsdp_wrap = False
|
158 |
+
decoder._fsdp_wrap = False
|
159 |
+
lm_head._fsdp_wrap = False
|
160 |
+
model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
|
161 |
+
model.activation_checkpointing_fn = lambda module: isinstance(module, decoder_block_type)
|
162 |
+
if encoder_block_type == decoder_block_type:
|
163 |
+
return
|
164 |
+
model.fsdp_wrap_fn = lambda module: isinstance(module, encoder_block_type)
|
165 |
+
model.activation_checkpointing_fn = lambda module: isinstance(module, encoder_block_type)
|
hf_t5.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implements a Hugging Face T5 wrapped inside a :class:`.ComposerModel`."""
|
2 |
+
from __future__ import annotations
|
3 |
+
from typing import Mapping
|
4 |
+
from transformers import AutoConfig, PreTrainedTokenizerBase, T5ForConditionalGeneration
|
5 |
+
from .hf_fsdp import hf_get_init_device
|
6 |
+
from .model_wrapper import HuggingFaceModelWithFSDP
|
7 |
+
from .utils import init_empty_weights
|
8 |
+
from .warnings import experimental_class
|
huggingface_hub_utils.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import importlib
|
3 |
+
import os
|
4 |
+
from typing import Optional, Sequence
|
5 |
+
|
6 |
+
class DeleteSpecificNodes(ast.NodeTransformer):
|
7 |
+
|
8 |
+
def __init__(self, nodes_to_remove: list[ast.AST]):
|
9 |
+
self.nodes_to_remove = nodes_to_remove
|
10 |
+
|
11 |
+
def visit(self, node: ast.AST) -> Optional[ast.AST]:
|
12 |
+
if node in self.nodes_to_remove:
|
13 |
+
return None
|
14 |
+
return super().visit(node)
|
15 |
+
|
16 |
+
def convert_to_relative_import(module_name: str, original_parent_module_name: Optional[str]) -> str:
|
17 |
+
parts = module_name.split('.')
|
18 |
+
if parts[-1] == original_parent_module_name:
|
19 |
+
return '.'
|
20 |
+
return '.' + parts[-1]
|
21 |
+
|
22 |
+
def find_module_file(module_name: str) -> str:
|
23 |
+
if not module_name:
|
24 |
+
raise ValueError(f'Invalid input: module_name={module_name!r}')
|
25 |
+
module = importlib.import_module(module_name)
|
26 |
+
module_file = module.__file__
|
27 |
+
if module_file is None:
|
28 |
+
raise ValueError(f'Could not find file for module: {module_name}')
|
29 |
+
return module_file
|
30 |
+
|
31 |
+
def _flatten_import(node: ast.ImportFrom, flatten_imports_prefix: Sequence[str]) -> bool:
|
32 |
+
"""Returns True if import should be flattened.
|
33 |
+
|
34 |
+
Checks whether the node starts the same as any of the imports in
|
35 |
+
flatten_imports_prefix.
|
36 |
+
"""
|
37 |
+
for import_prefix in flatten_imports_prefix:
|
38 |
+
if node.module is not None and node.module.startswith(import_prefix):
|
39 |
+
return True
|
40 |
+
return False
|
41 |
+
|
42 |
+
def _remove_import(node: ast.ImportFrom, remove_imports_prefix: Sequence[str]) -> bool:
|
43 |
+
"""Returns True if import should be removed.
|
44 |
+
|
45 |
+
Checks whether the node starts the same as any of the imports in
|
46 |
+
remove_imports_prefix.
|
47 |
+
"""
|
48 |
+
for import_prefix in remove_imports_prefix:
|
49 |
+
if node.module is not None and node.module.startswith(import_prefix):
|
50 |
+
return True
|
51 |
+
return False
|
52 |
+
|
53 |
+
def process_file(file_path: str, folder_path: str, flatten_imports_prefix: Sequence[str], remove_imports_prefix: Sequence[str]) -> list[str]:
|
54 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
55 |
+
source = f.read()
|
56 |
+
parent_module_name = None
|
57 |
+
if os.path.basename(file_path) == '__init__.py':
|
58 |
+
parent_module_name = os.path.basename(os.path.dirname(file_path))
|
59 |
+
tree = ast.parse(source)
|
60 |
+
new_files_to_process = []
|
61 |
+
nodes_to_remove = []
|
62 |
+
for node in ast.walk(tree):
|
63 |
+
if isinstance(node, ast.ImportFrom) and node.module is not None and _remove_import(node, remove_imports_prefix):
|
64 |
+
nodes_to_remove.append(node)
|
65 |
+
elif isinstance(node, ast.ImportFrom) and node.module is not None and _flatten_import(node, flatten_imports_prefix):
|
66 |
+
module_path = find_module_file(node.module)
|
67 |
+
node.module = convert_to_relative_import(node.module, parent_module_name)
|
68 |
+
new_files_to_process.append(module_path)
|
69 |
+
elif isinstance(node, ast.ClassDef) and node.name.startswith('Composer'):
|
70 |
+
nodes_to_remove.append(node)
|
71 |
+
elif isinstance(node, ast.Assign) and len(node.targets) == 1 and isinstance(node.targets[0], ast.Name) and (node.targets[0].id == '__all__'):
|
72 |
+
nodes_to_remove.append(node)
|
73 |
+
transformer = DeleteSpecificNodes(nodes_to_remove)
|
74 |
+
new_tree = transformer.visit(tree)
|
75 |
+
new_filename = os.path.basename(file_path)
|
76 |
+
if new_filename == '__init__.py':
|
77 |
+
new_filename = file_path.split('/')[-2] + '.py'
|
78 |
+
new_file_path = os.path.join(folder_path, new_filename)
|
79 |
+
with open(new_file_path, 'w', encoding='utf-8') as f:
|
80 |
+
assert new_tree is not None
|
81 |
+
f.write(ast.unparse(new_tree))
|
82 |
+
return new_files_to_process
|
83 |
+
|
84 |
+
def edit_files_for_hf_compatibility(folder: str, flatten_imports_prefix: Sequence[str]=('llmfoundry',), remove_imports_prefix: Sequence[str]=('composer', 'omegaconf', 'llmfoundry.metrics')) -> None:
|
85 |
+
"""Edit files to be compatible with Hugging Face Hub.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
folder (str): The folder to process.
|
89 |
+
flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',).
|
90 |
+
remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
|
91 |
+
Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics').
|
92 |
+
"""
|
93 |
+
files_to_process = [os.path.join(folder, filename) for filename in os.listdir(folder) if filename.endswith('.py')]
|
94 |
+
files_processed_and_queued = set(files_to_process)
|
95 |
+
while len(files_to_process) > 0:
|
96 |
+
to_process = files_to_process.pop()
|
97 |
+
if os.path.isfile(to_process) and to_process.endswith('.py'):
|
98 |
+
to_add = process_file(to_process, folder, flatten_imports_prefix, remove_imports_prefix)
|
99 |
+
for file in to_add:
|
100 |
+
if file not in files_processed_and_queued:
|
101 |
+
files_to_process.append(file)
|
102 |
+
files_processed_and_queued.add(file)
|
interfaces.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .callback_with_config import CallbackWithConfig
|
llmfoundry.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='bitsandbytes')
|
3 |
+
import logging
|
4 |
+
from .logging_utils import SpecificWarningFilter
|
5 |
+
hf_dynamic_modules_logger = logging.getLogger('transformers.dynamic_module_utils')
|
6 |
+
new_files_warning_filter = SpecificWarningFilter('A new version of the following files was downloaded from')
|
7 |
+
hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
|
8 |
+
from . import algorithms, callbacks, loggers, optim, registry, utils
|
9 |
+
from .data import ConcatTokensDataset, NoConcatDataset, Seq2SeqFinetuningCollator, build_finetuning_dataloader
|
10 |
+
from .hf import ComposerHFCausalLM, ComposerHFT5
|
11 |
+
from .attention import MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, scaled_multihead_dot_product_attention
|
12 |
+
from .blocks import MPTBlock
|
13 |
+
from .ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
|
14 |
+
from .mpt import ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel
|
15 |
+
from .tokenizers import TiktokenTokenizerWrapper
|
16 |
+
__version__ = '0.7.0'
|
logging_utils.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
class SpecificWarningFilter(logging.Filter):
|
5 |
+
|
6 |
+
def __init__(self, message_to_suppress: str):
|
7 |
+
"""Filter out a specific warning message based on its content.
|
8 |
+
|
9 |
+
This can be useful for filtering out specific warning messages from third party packages.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
message_to_suppress (str): The warning message to suppress.
|
13 |
+
"""
|
14 |
+
super().__init__()
|
15 |
+
self.message_to_suppress = message_to_suppress
|
16 |
+
|
17 |
+
def filter(self, record: logging.LogRecord) -> bool:
|
18 |
+
return self.message_to_suppress not in record.getMessage()
|
19 |
+
|
20 |
+
def get_mosaicml_logger():
|
21 |
+
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR):
|
22 |
+
return MosaicMLLogger()
|
23 |
+
else:
|
24 |
+
return None
|
meta_init_context.py
CHANGED
@@ -95,5 +95,5 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
|
|
95 |
nn.Module.register_parameter = old_register_parameter
|
96 |
if include_buffers:
|
97 |
nn.Module.register_buffer = old_register_buffer
|
98 |
-
for
|
99 |
setattr(torch, torch_function_name, old_torch_function)
|
|
|
95 |
nn.Module.register_parameter = old_register_parameter
|
96 |
if include_buffers:
|
97 |
nn.Module.register_buffer = old_register_buffer
|
98 |
+
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
99 |
setattr(torch, torch_function_name, old_torch_function)
|
model_download_utils.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions for downloading models."""
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import subprocess
|
7 |
+
import time
|
8 |
+
import warnings
|
9 |
+
from http import HTTPStatus
|
10 |
+
from typing import Optional
|
11 |
+
from urllib.parse import urljoin
|
12 |
+
import huggingface_hub as hf_hub
|
13 |
+
import requests
|
14 |
+
import tenacity
|
15 |
+
import yaml
|
16 |
+
from bs4 import BeautifulSoup
|
17 |
+
from requests.packages.urllib3.exceptions import InsecureRequestWarning
|
18 |
+
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
19 |
+
from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
|
20 |
+
from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
|
21 |
+
DEFAULT_IGNORE_PATTERNS = ['*.ckpt', '*.h5', '*.msgpack']
|
22 |
+
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
|
23 |
+
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
|
24 |
+
TOKENIZER_FILES = ['special_tokens_map.json', 'tokenizer.json', 'tokenizer.model', 'tokenizer_config.json']
|
25 |
+
ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
|
26 |
+
ORAS_CLI = 'oras'
|
27 |
+
log = logging.getLogger(__name__)
|
28 |
+
|
29 |
+
@tenacity.retry(retry=tenacity.retry_if_not_exception_type((ValueError, hf_hub.utils.RepositoryNotFoundError)), stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10))
|
30 |
+
def download_from_hf_hub(model: str, save_dir: str, prefer_safetensors: bool=True, tokenizer_only: bool=False, token: Optional[str]=None):
|
31 |
+
"""Downloads model files from a Hugging Face Hub model repo.
|
32 |
+
|
33 |
+
Only supports models stored in Safetensors and PyTorch formats for now. If both formats are available, only the
|
34 |
+
Safetensors weights will be downloaded unless `prefer_safetensors` is set to False.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
repo_id (str): The Hugging Face Hub repo ID.
|
38 |
+
save_dir (str, optional): The local path to the directory where the model files will be downloaded.
|
39 |
+
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
|
40 |
+
available. Defaults to True.
|
41 |
+
tokenizer_only (bool): If true, only download tokenizer files.
|
42 |
+
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
|
43 |
+
`HUGGING_FACE_HUB_TOKEN` environment variable.
|
44 |
+
|
45 |
+
Raises:
|
46 |
+
RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized.
|
47 |
+
ValueError: If the model repo doesn't contain any supported model weights.
|
48 |
+
"""
|
49 |
+
repo_files = set(hf_hub.list_repo_files(model))
|
50 |
+
ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS)
|
51 |
+
safetensors_available = SAFE_WEIGHTS_NAME in repo_files or SAFE_WEIGHTS_INDEX_NAME in repo_files
|
52 |
+
pytorch_available = PYTORCH_WEIGHTS_NAME in repo_files or PYTORCH_WEIGHTS_INDEX_NAME in repo_files
|
53 |
+
if safetensors_available and pytorch_available:
|
54 |
+
if prefer_safetensors:
|
55 |
+
log.info('Safetensors available and preferred. Excluding pytorch weights.')
|
56 |
+
ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN)
|
57 |
+
else:
|
58 |
+
log.info('Pytorch available and preferred. Excluding safetensors weights.')
|
59 |
+
ignore_patterns.append(SAFE_WEIGHTS_PATTERN)
|
60 |
+
elif safetensors_available:
|
61 |
+
log.info('Only safetensors available. Ignoring weights preference.')
|
62 |
+
elif pytorch_available:
|
63 |
+
log.info('Only pytorch available. Ignoring weights preference.')
|
64 |
+
else:
|
65 |
+
raise ValueError(f'No supported model weights found in repo {model}.' + ' Please make sure the repo contains either safetensors or pytorch weights.')
|
66 |
+
allow_patterns = TOKENIZER_FILES if tokenizer_only else None
|
67 |
+
download_start = time.time()
|
68 |
+
hf_hub.snapshot_download(model, local_dir=save_dir, local_dir_use_symlinks=False, ignore_patterns=ignore_patterns, allow_patterns=allow_patterns, token=token)
|
69 |
+
download_duration = time.time() - download_start
|
70 |
+
log.info(f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds')
|
71 |
+
|
72 |
+
def _extract_links_from_html(html: str):
|
73 |
+
"""Extracts links from HTML content.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
html (str): The HTML content
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
list[str]: A list of links to download.
|
80 |
+
"""
|
81 |
+
soup = BeautifulSoup(html, 'html.parser')
|
82 |
+
links = [a['href'] for a in soup.find_all('a')]
|
83 |
+
return links
|
84 |
+
|
85 |
+
def _recursive_download(session: requests.Session, base_url: str, path: str, save_dir: str, ignore_cert: bool=False):
|
86 |
+
"""Downloads all files/subdirectories from a directory on a remote server.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
session: A requests.Session through which to make requests to the remote server.
|
90 |
+
url (str): The base URL where the files are located.
|
91 |
+
path (str): The path from the base URL to the files to download. The full URL for the download is equal to
|
92 |
+
'<base_url>/<path>'.
|
93 |
+
save_dir (str): The directory to save downloaded files to.
|
94 |
+
ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
|
95 |
+
Defaults to False.
|
96 |
+
WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
|
97 |
+
|
98 |
+
Raises:
|
99 |
+
PermissionError: If the remote server returns a 401 Unauthorized status code.
|
100 |
+
ValueError: If the remote server returns a 404 Not Found status code.
|
101 |
+
RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
|
102 |
+
"""
|
103 |
+
url = urljoin(base_url, path)
|
104 |
+
print(url)
|
105 |
+
response = session.get(url, verify=not ignore_cert)
|
106 |
+
if response.status_code == HTTPStatus.UNAUTHORIZED:
|
107 |
+
raise PermissionError(f'Not authorized to download file from {url}. Received status code {response.status_code}. ')
|
108 |
+
elif response.status_code == HTTPStatus.NOT_FOUND:
|
109 |
+
raise ValueError(f'Could not find file at {url}. Received status code {response.status_code}')
|
110 |
+
elif response.status_code != HTTPStatus.OK:
|
111 |
+
raise RuntimeError(f'Could not download file from {url}. Received unexpected status code {response.status_code}')
|
112 |
+
if not url.endswith('/'):
|
113 |
+
save_path = os.path.join(save_dir, path)
|
114 |
+
parent_dir = os.path.dirname(save_path)
|
115 |
+
if not os.path.exists(parent_dir):
|
116 |
+
os.makedirs(parent_dir)
|
117 |
+
with open(save_path, 'wb') as f:
|
118 |
+
f.write(response.content)
|
119 |
+
log.info(f'Downloaded file {save_path}')
|
120 |
+
return
|
121 |
+
child_links = _extract_links_from_html(response.content.decode())
|
122 |
+
print(child_links)
|
123 |
+
for child_link in child_links:
|
124 |
+
_recursive_download(session, base_url, urljoin(path, child_link), save_dir, ignore_cert=ignore_cert)
|
125 |
+
|
126 |
+
@tenacity.retry(retry=tenacity.retry_if_not_exception_type((PermissionError, ValueError)), stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10))
|
127 |
+
def download_from_http_fileserver(url: str, save_dir: str, ignore_cert: bool=False):
|
128 |
+
"""Downloads files from a remote HTTP file server.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
url (str): The base URL where the files are located.
|
132 |
+
save_dir (str): The directory to save downloaded files to.
|
133 |
+
ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
|
134 |
+
Defaults to False.
|
135 |
+
WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
|
136 |
+
"""
|
137 |
+
with requests.Session() as session:
|
138 |
+
with warnings.catch_warnings():
|
139 |
+
if ignore_cert:
|
140 |
+
warnings.simplefilter('ignore', category=InsecureRequestWarning)
|
141 |
+
_recursive_download(session, url, '', save_dir, ignore_cert=ignore_cert)
|
142 |
+
|
143 |
+
def download_from_oras(model: str, config_file: str, credentials_dir: str, save_dir: str, tokenizer_only: bool=False, concurrency: int=10):
|
144 |
+
"""Download from an OCI-compliant registry using oras.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
model (str): The name of the model to download.
|
148 |
+
config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
|
149 |
+
credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
|
150 |
+
files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
|
151 |
+
save_dir (str): Path to the directory where files will be downloaded.
|
152 |
+
tokenizer_only (bool): If true, only download the tokenzier files.
|
153 |
+
concurrency (int): The number of concurrent downloads to run.
|
154 |
+
"""
|
155 |
+
if shutil.which(ORAS_CLI) is None:
|
156 |
+
raise Exception(f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ')
|
157 |
+
|
158 |
+
def _read_secrets_file(secret_file_path: str):
|
159 |
+
try:
|
160 |
+
with open(secret_file_path, encoding='utf-8') as f:
|
161 |
+
return f.read().strip()
|
162 |
+
except Exception as error:
|
163 |
+
raise ValueError(f'secrets file {secret_file_path} failed to be read') from error
|
164 |
+
secrets = {}
|
165 |
+
for secret in ['username', 'password', 'registry']:
|
166 |
+
secrets[secret] = _read_secrets_file(os.path.join(credentials_dir, secret))
|
167 |
+
with open(config_file, 'r', encoding='utf-8') as f:
|
168 |
+
configs = yaml.safe_load(f.read())
|
169 |
+
config_type = 'tokenizers' if tokenizer_only else 'models'
|
170 |
+
path = configs[config_type][model]
|
171 |
+
registry = secrets['registry']
|
172 |
+
|
173 |
+
def get_oras_cmd(username: Optional[str]=None, password: Optional[str]=None):
|
174 |
+
cmd = [ORAS_CLI, 'pull', f'{registry}/{path}', '-o', save_dir, '--verbose', '--concurrency', str(concurrency)]
|
175 |
+
if username is not None:
|
176 |
+
cmd.extend(['--username', username])
|
177 |
+
if password is not None:
|
178 |
+
cmd.extend(['--password', password])
|
179 |
+
return cmd
|
180 |
+
cmd_without_creds = get_oras_cmd()
|
181 |
+
log.info(f"CMD for oras cli to run: {' '.join(cmd_without_creds)}")
|
182 |
+
cmd_to_run = get_oras_cmd(username=secrets['username'], password=secrets['password'])
|
183 |
+
try:
|
184 |
+
subprocess.run(cmd_to_run, check=True)
|
185 |
+
except subprocess.CalledProcessError as e:
|
186 |
+
raise subprocess.CalledProcessError(e.returncode, cmd_without_creds, e.output, e.stderr)
|
model_wrapper.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Re-usable :class:`.ComposerModel` for LLM HF Models."""
|
2 |
+
from __future__ import annotations
|
3 |
+
from collections import UserDict
|
4 |
+
from typing import TYPE_CHECKING, List, Mapping, Optional
|
5 |
+
import transformers
|
6 |
+
from torchmetrics import Metric
|
7 |
+
from transformers import PreTrainedTokenizerBase
|
8 |
+
from transformers.utils.generic import ModelOutput
|
9 |
+
from .hf_fsdp import prepare_hf_model_for_fsdp
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from peft import PeftConfig
|
12 |
+
_HF_IGNORE_INDEX = -100
|
13 |
+
|
14 |
+
class HuggingFaceModelWithFSDP(HuggingFaceModel):
|
15 |
+
"""Wrapper around HuggingFaceModel.
|
16 |
+
|
17 |
+
Handles preparation for FSDP wrapping.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, model: transformers.PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase]=None, metrics: Optional[List[Metric]]=None, eval_metrics: Optional[List[Metric]]=None, shift_labels: bool=False, init_device: Optional[str]=None, peft_config: Optional['PeftConfig']=None):
|
21 |
+
super().__init__(model, tokenizer, use_logits=True, metrics=metrics, eval_metrics=eval_metrics, shift_labels=shift_labels, peft_config=peft_config, should_save_peft_only=True)
|
22 |
+
prepare_hf_model_for_fsdp(self.model, init_device)
|
23 |
+
self.model.param_init_fn = lambda module: self.model._init_weights(module)
|
24 |
+
|
25 |
+
def forward(self, batch: Mapping):
|
26 |
+
if isinstance(batch, dict) or isinstance(batch, UserDict):
|
27 |
+
batch = {k: v for k, v in batch.items() if k in self.model_forward_args}
|
28 |
+
output = self.model(**batch)
|
29 |
+
else:
|
30 |
+
raise ValueError('Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model')
|
31 |
+
return output
|
32 |
+
|
33 |
+
def loss(self, outputs: ModelOutput, batch: Mapping):
|
34 |
+
if self.config.use_return_dict:
|
35 |
+
return outputs['loss']
|
36 |
+
return outputs[:2]
|
modeling_mpt.py
CHANGED
@@ -9,40 +9,27 @@ from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Un
|
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
import torch.nn.functional as F
|
12 |
-
from .attention import
|
|
|
13 |
if is_flash_v2_installed():
|
14 |
try:
|
15 |
from flash_attn import bert_padding
|
16 |
from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
|
17 |
except Exception as e:
|
18 |
raise e
|
19 |
-
if is_flash_v1_installed():
|
20 |
-
try:
|
21 |
-
from flash_attn import bert_padding
|
22 |
-
except Exception as e:
|
23 |
-
raise e
|
24 |
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
25 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
26 |
from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
|
27 |
from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
|
28 |
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFRotaryEmbedding
|
29 |
-
from .attention import
|
30 |
from .blocks import MPTBlock
|
31 |
from .custom_embedding import SharedEmbedding
|
32 |
-
from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
|
33 |
-
from .ffn import FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
|
34 |
-
from .ffn import MPTMLP as MPTMLP
|
35 |
from .ffn import build_ffn as build_ffn
|
36 |
-
from .norm import NORM_CLASS_REGISTRY
|
37 |
from .configuration_mpt import MPTConfig
|
38 |
-
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
|
39 |
-
from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
|
40 |
from .meta_init_context import init_empty_weights
|
41 |
from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
|
42 |
-
|
43 |
-
from .flash_attn_triton import flash_attn_func as flash_attn_func
|
44 |
-
except:
|
45 |
-
pass
|
46 |
import logging
|
47 |
log = logging.getLogger(__name__)
|
48 |
|
@@ -140,9 +127,9 @@ def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: tor
|
|
140 |
key_padding_mask = attention_mask_in_length
|
141 |
query_padding_mask = attention_mask_in_length
|
142 |
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
flash_attn_padding_info['indices_q'] = indices_q
|
147 |
flash_attn_padding_info['indices_k'] = indices_k
|
148 |
flash_attn_padding_info['indices_v'] = indices_v
|
@@ -176,7 +163,6 @@ class MPTModel(MPTPreTrainedModel):
|
|
176 |
config._validate_config()
|
177 |
super().__init__(config)
|
178 |
self.attn_impl = config.attn_config['attn_impl']
|
179 |
-
self.prefix_lm = config.attn_config['prefix_lm']
|
180 |
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
181 |
self.alibi = config.attn_config['alibi']
|
182 |
self.alibi_bias_max = config.attn_config['alibi_bias_max']
|
@@ -196,6 +182,10 @@ class MPTModel(MPTPreTrainedModel):
|
|
196 |
self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
197 |
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
198 |
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
|
|
|
|
|
|
|
|
199 |
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
200 |
self.rope = config.attn_config['rope']
|
201 |
self.rope_impl = None
|
@@ -205,10 +195,10 @@ class MPTModel(MPTPreTrainedModel):
|
|
205 |
if config.init_device != 'meta':
|
206 |
log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
|
207 |
self.apply(self.param_init_fn)
|
208 |
-
self.is_causal =
|
209 |
self._attn_bias_initialized = False
|
210 |
self.attn_bias = None
|
211 |
-
self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi,
|
212 |
if config.no_bias:
|
213 |
for module in self.modules():
|
214 |
if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
|
@@ -227,7 +217,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
227 |
self.wte = value
|
228 |
|
229 |
@torch.no_grad()
|
230 |
-
def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None,
|
231 |
if not self._attn_bias_initialized:
|
232 |
if self.attn_bias_shape:
|
233 |
self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
|
@@ -238,10 +228,6 @@ class MPTModel(MPTPreTrainedModel):
|
|
238 |
if self.attn_bias is not None:
|
239 |
self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
|
240 |
attn_bias = self.attn_bias
|
241 |
-
if self.prefix_lm:
|
242 |
-
assert isinstance(attn_bias, torch.Tensor)
|
243 |
-
assert isinstance(prefix_mask, torch.Tensor)
|
244 |
-
attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
|
245 |
if self.attn_uses_sequence_id and sequence_id is not None:
|
246 |
assert isinstance(attn_bias, torch.Tensor)
|
247 |
attn_bias = apply_sequence_id(attn_bias, sequence_id, self.config.max_seq_len)
|
@@ -252,43 +238,22 @@ class MPTModel(MPTPreTrainedModel):
|
|
252 |
else:
|
253 |
_s_k = max(0, attn_bias.size(-1) - s_k)
|
254 |
attn_bias = attn_bias[:, :, :, _s_k:]
|
255 |
-
if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
|
256 |
-
raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
|
257 |
min_val = torch.finfo(attn_bias.dtype).min
|
258 |
attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
|
259 |
return (attn_bias, attention_mask)
|
260 |
|
261 |
-
def
|
262 |
-
(s_k, s_q) = attn_bias.shape[-2:]
|
263 |
-
if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
|
264 |
-
raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
|
265 |
-
seq_len = prefix_mask.shape[-1]
|
266 |
-
if seq_len > self.config.max_seq_len:
|
267 |
-
raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
|
268 |
-
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
269 |
-
causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
|
270 |
-
prefix = prefix_mask.view(-1, 1, 1, seq_len)
|
271 |
-
cannot_attend = ~torch.logical_or(causal, prefix.bool())
|
272 |
-
min_val = torch.finfo(attn_bias.dtype).min
|
273 |
-
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
274 |
-
return attn_bias
|
275 |
-
|
276 |
-
def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
|
277 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
278 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
279 |
if attention_mask is not None:
|
280 |
attention_mask = attention_mask.bool()
|
281 |
-
if prefix_mask is not None:
|
282 |
-
prefix_mask = prefix_mask.bool()
|
283 |
if not return_dict:
|
284 |
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
285 |
if output_attentions:
|
286 |
if self.attn_impl != 'torch':
|
287 |
-
raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash
|
288 |
if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
|
289 |
raise NotImplementedError('MPT does not support training with left padding.')
|
290 |
-
if self.prefix_lm and prefix_mask is None:
|
291 |
-
raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
|
292 |
if self.training:
|
293 |
if self.attn_uses_sequence_id and sequence_id is None:
|
294 |
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
@@ -336,7 +301,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
336 |
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
337 |
assert isinstance(self.emb_drop, nn.Module)
|
338 |
x = self.emb_drop(x_shrunk)
|
339 |
-
|
340 |
attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask)
|
341 |
alibi_slopes = None
|
342 |
if self.alibi and self.attn_impl == 'flash':
|
@@ -349,12 +314,12 @@ class MPTModel(MPTPreTrainedModel):
|
|
349 |
flash_attn_padding_info = {}
|
350 |
if self.attn_impl == 'flash':
|
351 |
flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
|
352 |
-
for
|
353 |
if output_hidden_states:
|
354 |
assert all_hidden_states is not None
|
355 |
all_hidden_states = all_hidden_states + (x,)
|
356 |
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
357 |
-
|
358 |
if presents is not None:
|
359 |
presents += (present,)
|
360 |
if output_attentions:
|
@@ -422,7 +387,8 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
422 |
self.transformer.set_input_embeddings(new_embeddings)
|
423 |
|
424 |
def tie_weights(self) -> None:
|
425 |
-
self.
|
|
|
426 |
|
427 |
def set_decoder(self, decoder: MPTModel) -> None:
|
428 |
self.transformer = decoder
|
@@ -430,10 +396,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
430 |
def get_decoder(self) -> MPTModel:
|
431 |
return self.transformer
|
432 |
|
433 |
-
def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None,
|
434 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
435 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
436 |
-
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask,
|
437 |
if self.lm_head is not None:
|
438 |
logits = self.lm_head(outputs.last_hidden_state)
|
439 |
else:
|
@@ -459,29 +425,48 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
459 |
return _fsdp_wrap_fn(self, module)
|
460 |
|
461 |
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
|
486 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> Dict[str, Any]:
|
487 |
attention_mask = kwargs['attention_mask'].bool()
|
@@ -493,17 +478,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
493 |
sequence_id = None
|
494 |
if past_key_values is not None:
|
495 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
496 |
-
if self.transformer.prefix_lm:
|
497 |
-
prefix_mask = torch.ones_like(attention_mask)
|
498 |
-
if kwargs.get('use_cache') == False:
|
499 |
-
raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
|
500 |
-
else:
|
501 |
-
prefix_mask = None
|
502 |
if inputs_embeds is not None and past_key_values is None:
|
503 |
model_inputs = {'inputs_embeds': inputs_embeds}
|
504 |
else:
|
505 |
model_inputs = {'input_ids': input_ids}
|
506 |
-
model_inputs.update({'attention_mask': attention_mask, '
|
507 |
return model_inputs
|
508 |
|
509 |
@staticmethod
|
|
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
import torch.nn.functional as F
|
12 |
+
from .attention import is_flash_v2_installed
|
13 |
+
from .norm import NORM_CLASS_REGISTRY
|
14 |
if is_flash_v2_installed():
|
15 |
try:
|
16 |
from flash_attn import bert_padding
|
17 |
from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
|
18 |
except Exception as e:
|
19 |
raise e
|
|
|
|
|
|
|
|
|
|
|
20 |
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
21 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
22 |
from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
|
23 |
from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
|
24 |
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFRotaryEmbedding
|
25 |
+
from .attention import attn_bias_shape, build_attn_bias, gen_slopes
|
26 |
from .blocks import MPTBlock
|
27 |
from .custom_embedding import SharedEmbedding
|
|
|
|
|
|
|
28 |
from .ffn import build_ffn as build_ffn
|
|
|
29 |
from .configuration_mpt import MPTConfig
|
|
|
|
|
30 |
from .meta_init_context import init_empty_weights
|
31 |
from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
|
32 |
+
from .act_ckpt import pass_on_block_idx, build_act_ckpt_mod_to_blocks, check_mapping_blocks_overlap
|
|
|
|
|
|
|
33 |
import logging
|
34 |
log = logging.getLogger(__name__)
|
35 |
|
|
|
127 |
key_padding_mask = attention_mask_in_length
|
128 |
query_padding_mask = attention_mask_in_length
|
129 |
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
|
130 |
+
_, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
|
131 |
+
_, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
|
132 |
+
_, indices_v, _, _ = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
|
133 |
flash_attn_padding_info['indices_q'] = indices_q
|
134 |
flash_attn_padding_info['indices_k'] = indices_k
|
135 |
flash_attn_padding_info['indices_v'] = indices_v
|
|
|
163 |
config._validate_config()
|
164 |
super().__init__(config)
|
165 |
self.attn_impl = config.attn_config['attn_impl']
|
|
|
166 |
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
167 |
self.alibi = config.attn_config['alibi']
|
168 |
self.alibi_bias_max = config.attn_config['alibi_bias_max']
|
|
|
182 |
self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
183 |
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
184 |
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
185 |
+
for i, block in enumerate(self.blocks):
|
186 |
+
block.block_idx = i
|
187 |
+
block.max_block_idx = config.n_layers - 1
|
188 |
+
pass_on_block_idx(block)
|
189 |
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
190 |
self.rope = config.attn_config['rope']
|
191 |
self.rope_impl = None
|
|
|
195 |
if config.init_device != 'meta':
|
196 |
log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
|
197 |
self.apply(self.param_init_fn)
|
198 |
+
self.is_causal = True
|
199 |
self._attn_bias_initialized = False
|
200 |
self.attn_bias = None
|
201 |
+
self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
|
202 |
if config.no_bias:
|
203 |
for module in self.modules():
|
204 |
if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
|
|
|
217 |
self.wte = value
|
218 |
|
219 |
@torch.no_grad()
|
220 |
+
def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
|
221 |
if not self._attn_bias_initialized:
|
222 |
if self.attn_bias_shape:
|
223 |
self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
|
|
|
228 |
if self.attn_bias is not None:
|
229 |
self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
|
230 |
attn_bias = self.attn_bias
|
|
|
|
|
|
|
|
|
231 |
if self.attn_uses_sequence_id and sequence_id is not None:
|
232 |
assert isinstance(attn_bias, torch.Tensor)
|
233 |
attn_bias = apply_sequence_id(attn_bias, sequence_id, self.config.max_seq_len)
|
|
|
238 |
else:
|
239 |
_s_k = max(0, attn_bias.size(-1) - s_k)
|
240 |
attn_bias = attn_bias[:, :, :, _s_k:]
|
|
|
|
|
241 |
min_val = torch.finfo(attn_bias.dtype).min
|
242 |
attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
|
243 |
return (attn_bias, attention_mask)
|
244 |
|
245 |
+
def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
247 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
248 |
if attention_mask is not None:
|
249 |
attention_mask = attention_mask.bool()
|
|
|
|
|
250 |
if not return_dict:
|
251 |
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
252 |
if output_attentions:
|
253 |
if self.attn_impl != 'torch':
|
254 |
+
raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash`.')
|
255 |
if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
|
256 |
raise NotImplementedError('MPT does not support training with left padding.')
|
|
|
|
|
257 |
if self.training:
|
258 |
if self.attn_uses_sequence_id and sequence_id is None:
|
259 |
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
|
|
301 |
x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
|
302 |
assert isinstance(self.emb_drop, nn.Module)
|
303 |
x = self.emb_drop(x_shrunk)
|
304 |
+
attn_bias, attention_mask = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, sequence_id=sequence_id)
|
305 |
attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask)
|
306 |
alibi_slopes = None
|
307 |
if self.alibi and self.attn_impl == 'flash':
|
|
|
314 |
flash_attn_padding_info = {}
|
315 |
if self.attn_impl == 'flash':
|
316 |
flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
|
317 |
+
for b_idx, block in enumerate(self.blocks):
|
318 |
if output_hidden_states:
|
319 |
assert all_hidden_states is not None
|
320 |
all_hidden_states = all_hidden_states + (x,)
|
321 |
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
322 |
+
x, attn_weights, present = block(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
|
323 |
if presents is not None:
|
324 |
presents += (present,)
|
325 |
if output_attentions:
|
|
|
387 |
self.transformer.set_input_embeddings(new_embeddings)
|
388 |
|
389 |
def tie_weights(self) -> None:
|
390 |
+
if getattr(self.config, 'tie_word_embeddings', True):
|
391 |
+
self.lm_head = None
|
392 |
|
393 |
def set_decoder(self, decoder: MPTModel) -> None:
|
394 |
self.transformer = decoder
|
|
|
396 |
def get_decoder(self) -> MPTModel:
|
397 |
return self.transformer
|
398 |
|
399 |
+
def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None) -> CausalLMOutputWithPast:
|
400 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
401 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
402 |
+
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
|
403 |
if self.lm_head is not None:
|
404 |
logits = self.lm_head(outputs.last_hidden_state)
|
405 |
else:
|
|
|
425 |
return _fsdp_wrap_fn(self, module)
|
426 |
|
427 |
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
|
428 |
+
"""The MPT activation checkpointing (act ckpt) function.
|
429 |
+
|
430 |
+
When `activation_checkpointing` in fsdp_config is set to true, this function will be called on all the modules in the FSDP wrapped model and determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`) which can be specified as below:
|
431 |
+
1. null (or no such field): The whole MPTBlock will be activation checkpointed on all layers
|
432 |
+
2. a list of modules to act ckpt on all layers, e.g.,
|
433 |
+
activation_checkpointing_target:
|
434 |
+
- grouped_query_attention
|
435 |
+
- mptmlp
|
436 |
+
3. a dictionary of module name with target_blocks, e.g.,
|
437 |
+
activation_checkpointing_target:
|
438 |
+
{
|
439 |
+
"mptblock": target_blocks_1,
|
440 |
+
"grouped_query_attention": target_blocks_2
|
441 |
+
}
|
442 |
+
target_blocks (target_blocks_1, target_blocks_2 above) can be:
|
443 |
+
- a single integer n: the first n transformer block will be activation checkpointed
|
444 |
+
- a string of first-n, middle-m, last-k, range-i-j: the first n, the middle m, the last k, or the range [i, j) layers will be activation checkpointed. E.g, 'first-2, last-2' means the first 2 and last 2 transformer blocks will be activation checkpointed
|
445 |
+
middle-m is range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)``
|
446 |
+
- a list of integers corresponds to the list of transformer block ids, e.g., [2] means the second transformer block will be activation checkpointed. [2, 3] means the second and third transformer blocks will be activation checkpointed
|
447 |
+
- a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j
|
448 |
+
|
449 |
+
An example in yaml config file:
|
450 |
+
fsdp_config:
|
451 |
+
activation_checkpointing: true
|
452 |
+
model:
|
453 |
+
activation_checkpointing_target:
|
454 |
+
{
|
455 |
+
"mptblock": 'first-5',
|
456 |
+
"grouped_query_attention": 'last-35'
|
457 |
+
}
|
458 |
+
"""
|
459 |
+
if not hasattr(module, 'block_idx'):
|
460 |
+
log.debug(f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer block or its submodules are eligible for activation checkpointing.')
|
461 |
+
return False
|
462 |
+
act_ckpt_target = getattr(self.config, 'activation_checkpointing_target', None)
|
463 |
+
act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks(act_ckpt_target, MPTBlock, module.max_block_idx)
|
464 |
+
check_mapping_blocks_overlap(act_ckpt_mod_to_blocks, module.max_block_idx)
|
465 |
+
for k in act_ckpt_mod_to_blocks.keys():
|
466 |
+
if isinstance(module, k):
|
467 |
+
blocks = act_ckpt_mod_to_blocks[k]
|
468 |
+
return True if blocks == -1 else module.block_idx in blocks
|
469 |
+
return False
|
470 |
|
471 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> Dict[str, Any]:
|
472 |
attention_mask = kwargs['attention_mask'].bool()
|
|
|
478 |
sequence_id = None
|
479 |
if past_key_values is not None:
|
480 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
if inputs_embeds is not None and past_key_values is None:
|
482 |
model_inputs = {'inputs_embeds': inputs_embeds}
|
483 |
else:
|
484 |
model_inputs = {'input_ids': input_ids}
|
485 |
+
model_inputs.update({'attention_mask': attention_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)})
|
486 |
return model_inputs
|
487 |
|
488 |
@staticmethod
|
monolithic_ckpt_callback.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
from pathlib import Path
|
5 |
+
import torch
|
6 |
+
|
7 |
+
class MonolithicCheckpointSaver(Callback):
|
8 |
+
"""Save a monolithic checkpoint every N batches.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
save_folder (str): Folder to save checkpoints to (can be a URI)
|
12 |
+
batch_interval (int): Number of batches between checkpoints.
|
13 |
+
filename (str): Filename to save checkpoints to.
|
14 |
+
overwrite (bool): Whether to overwrite previous checkpoints.
|
15 |
+
keep_optimizers (bool): Whether to save the optimizer state in the monolithic checkpoint.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, save_folder: str, batch_interval: int, filename: str='ep{epoch}-ba{batch}.pt', overwrite: bool=False, keep_optimizers: bool=False):
|
19 |
+
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(save_folder)
|
20 |
+
self.filename_format_str = filename
|
21 |
+
self.batch_interval = batch_interval
|
22 |
+
self.upload_to_object_store = self.backend != ''
|
23 |
+
self.overwrite = overwrite
|
24 |
+
self.keep_optimizers = keep_optimizers
|
25 |
+
if self.upload_to_object_store:
|
26 |
+
self.remote_ud = RemoteUploaderDownloader(bucket_uri=f'{self.backend}://{self.bucket_name}')
|
27 |
+
else:
|
28 |
+
self.remote_ud = None
|
29 |
+
|
30 |
+
def init(self, state: State, logger: Logger) -> None:
|
31 |
+
if self.upload_to_object_store and self.remote_ud is not None:
|
32 |
+
self.remote_ud.init(state, logger)
|
33 |
+
state.callbacks.append(self.remote_ud)
|
34 |
+
|
35 |
+
def batch_checkpoint(self, state: State, logger: Logger) -> None:
|
36 |
+
if state.timestamp.batch.value % self.batch_interval == 0:
|
37 |
+
self._save_checkpoint(state, logger)
|
38 |
+
|
39 |
+
def fit_end(self, state: State, logger: Logger) -> None:
|
40 |
+
if state.timestamp.batch.value % self.batch_interval != 0:
|
41 |
+
self._save_checkpoint(state, logger)
|
42 |
+
|
43 |
+
def _save_checkpoint(self, state: State, logger: Logger) -> None:
|
44 |
+
del logger
|
45 |
+
filename = format_name_with_dist_and_time(self.filename_format_str, state.run_name, state.timestamp)
|
46 |
+
save_dir = format_name_with_dist_and_time(self.save_dir_format_str, state.run_name, state.timestamp)
|
47 |
+
dir_context_mgr = tempfile.TemporaryDirectory() if self.upload_to_object_store else contextlib.nullcontext(enter_result=save_dir)
|
48 |
+
with dir_context_mgr as temp_save_dir:
|
49 |
+
assert isinstance(temp_save_dir, str)
|
50 |
+
save_path = str(Path(temp_save_dir) / Path(filename))
|
51 |
+
dirname = os.path.dirname(save_path)
|
52 |
+
if dirname:
|
53 |
+
os.makedirs(dirname, exist_ok=True)
|
54 |
+
state_dict = {'state': state.state_dict(), 'rng': reproducibility.get_rng_state()}
|
55 |
+
state_dict['state'].pop('optimizers')
|
56 |
+
state_dict['state'].pop('model')
|
57 |
+
with fsdp_state_dict_type_context(state.model, state_dict_type='full'):
|
58 |
+
state_dict['state']['model'] = state.model.state_dict()
|
59 |
+
if self.keep_optimizers:
|
60 |
+
optimizer = state.optimizers[0]
|
61 |
+
state_dict['state']['optimizers'] = {type(optimizer).__qualname__: fsdp_get_optim_state_dict(state.model, optimizer, state_dict_type='full')}
|
62 |
+
if dist.get_global_rank() == 0:
|
63 |
+
torch.save(state_dict, save_path)
|
64 |
+
if self.upload_to_object_store and self.remote_ud is not None and (dist.get_global_rank() == 0):
|
65 |
+
remote_file_name = str(Path(save_dir) / Path(filename))
|
66 |
+
self.remote_ud.upload_file(state=state, remote_file_name=remote_file_name, file_path=Path(save_path), overwrite=self.overwrite)
|
mosaicml_logger_utils.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from typing import Any, Dict, List, Optional, Union
|
4 |
+
_MODEL_KEYS_TO_LOG = ['pretrained_model_name_or_path', 'pretrained', 'vocab_size', 'd_model', 'n_heads', 'n_layers', 'expansion_ratio', 'max_seq_len']
|
5 |
+
|
6 |
+
def maybe_create_mosaicml_logger() -> Optional[MosaicMLLogger]:
|
7 |
+
"""Creates a MosaicMLLogger if the run was sent from the Mosaic platform."""
|
8 |
+
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR):
|
9 |
+
return MosaicMLLogger()
|
10 |
+
|
11 |
+
def find_mosaicml_logger(loggers: List[LoggerDestination]) -> Optional[MosaicMLLogger]:
|
12 |
+
"""Returns the first MosaicMLLogger from a list, and None otherwise."""
|
13 |
+
return next((logger for logger in loggers if isinstance(logger, MosaicMLLogger)), None)
|
14 |
+
|
15 |
+
def log_eval_analytics(mosaicml_logger: MosaicMLLogger, model_configs: ListConfig, icl_tasks: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]]):
|
16 |
+
"""Logs analytics for runs using the `eval.py` script."""
|
17 |
+
metrics: Dict[str, Any] = {'llmfoundry/script': 'eval'}
|
18 |
+
metrics['llmfoundry/gauntlet_configured'] = eval_gauntlet_config is not None
|
19 |
+
metrics['llmfoundry/icl_configured'] = isinstance(icl_tasks, str) or len(icl_tasks) > 0
|
20 |
+
metrics['llmfoundry/model_configs'] = []
|
21 |
+
for model_config in model_configs:
|
22 |
+
nested_model_config = model_config.get('model', {})
|
23 |
+
model_config_data = {}
|
24 |
+
for key in _MODEL_KEYS_TO_LOG:
|
25 |
+
if nested_model_config.get(key, None) is not None:
|
26 |
+
model_config_data[key] = nested_model_config.get(key)
|
27 |
+
if len(model_config_data) > 0:
|
28 |
+
metrics['llmfoundry/model_configs'].append(json.dumps(model_config_data, sort_keys=True))
|
29 |
+
mosaicml_logger.log_metrics(metrics)
|
30 |
+
mosaicml_logger._flush_metadata(force_flush=True)
|
31 |
+
|
32 |
+
def log_train_analytics(mosaicml_logger: MosaicMLLogger, model_config: DictConfig, train_loader_config: DictConfig, eval_loader_config: Optional[Union[DictConfig, ListConfig]], callback_configs: Optional[DictConfig], tokenizer_name: str, load_path: Optional[str], icl_tasks_config: Optional[Union[ListConfig, str]], eval_gauntlet: Optional[Union[DictConfig, str]]):
|
33 |
+
"""Logs analytics for runs using the `train.py` script."""
|
34 |
+
train_loader_dataset = train_loader_config.get('dataset', {})
|
35 |
+
metrics: Dict[str, Any] = {'llmfoundry/tokenizer_name': tokenizer_name, 'llmfoundry/script': 'train', 'llmfoundry/train_loader_name': train_loader_config.get('name')}
|
36 |
+
if callback_configs is not None:
|
37 |
+
metrics['llmfoundry/callbacks'] = [name for name, _ in callback_configs.items()]
|
38 |
+
metrics['llmfoundry/gauntlet_configured'] = eval_gauntlet is not None
|
39 |
+
metrics['llmfoundry/icl_configured'] = icl_tasks_config is not None and (isinstance(icl_tasks_config, str) or len(icl_tasks_config) > 0)
|
40 |
+
if train_loader_dataset.get('hf_name', None) is not None:
|
41 |
+
metrics['llmfoundry/train_dataset_hf_name'] = train_loader_dataset.get('hf_name', None)
|
42 |
+
if train_loader_config.get('name') == 'finetuning':
|
43 |
+
metrics['llmfoundry/train_task_type'] = 'INSTRUCTION_FINETUNE'
|
44 |
+
elif train_loader_config.get('name') == 'text':
|
45 |
+
if load_path is not None or model_config.get('pretrained') == True:
|
46 |
+
metrics['llmfoundry/train_task_type'] = 'CONTINUED_PRETRAIN'
|
47 |
+
else:
|
48 |
+
metrics['llmfoundry/train_task_type'] = 'PRETRAIN'
|
49 |
+
if eval_loader_config is not None:
|
50 |
+
metrics['llmfoundry/eval_loaders'] = []
|
51 |
+
if isinstance(eval_loader_config, ListConfig):
|
52 |
+
eval_loader_configs: ListConfig = eval_loader_config
|
53 |
+
else:
|
54 |
+
eval_loader_configs = ListConfig([eval_loader_config])
|
55 |
+
for loader_config in eval_loader_configs:
|
56 |
+
eval_loader_info = {}
|
57 |
+
eval_loader_dataset = loader_config.get('dataset', {})
|
58 |
+
eval_loader_info['name'] = loader_config.get('name')
|
59 |
+
if eval_loader_dataset.get('hf_name', None) is not None:
|
60 |
+
eval_loader_info['dataset_hf_name'] = eval_loader_dataset.get('hf_name')
|
61 |
+
metrics['llmfoundry/eval_loaders'].append(json.dumps(eval_loader_info, sort_keys=True))
|
62 |
+
model_config_data = {}
|
63 |
+
for key in _MODEL_KEYS_TO_LOG:
|
64 |
+
if model_config.get(key, None) is not None:
|
65 |
+
model_config_data[f'llmfoundry/{key}'] = model_config.get(key)
|
66 |
+
if len(model_config_data) > 0:
|
67 |
+
metrics.update(model_config_data)
|
68 |
+
mosaicml_logger.log_metrics(metrics)
|
69 |
+
mosaicml_logger._flush_metadata(force_flush=True)
|
mpt.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .configuration_mpt import MPTConfig
|
2 |
+
from .modeling_mpt import ComposerMPTCausalLM, MPTForCausalLM, MPTModel, MPTPreTrainedModel
|
packing.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import tempfile
|
3 |
+
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from transformers import PreTrainedTokenizerBase
|
7 |
+
log = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
class BinPackCollator:
|
10 |
+
"""Utility collator for packing to reduce padding."""
|
11 |
+
|
12 |
+
def __init__(self, collator: Callable, target_batch_size: int, max_seq_len: int, pad_token_id: int, padding_side: Literal['left', 'right'], max_leftover_bins_to_keep: Optional[int]=None):
|
13 |
+
self.base_collator = collator
|
14 |
+
self.out_size = int(target_batch_size)
|
15 |
+
self.max_seq_len = int(max_seq_len)
|
16 |
+
self.pad_token_id = int(pad_token_id)
|
17 |
+
self.padding_side = padding_side
|
18 |
+
if self.out_size <= 0:
|
19 |
+
raise ValueError(f'target_batch_size={target_batch_size!r} must be >0.')
|
20 |
+
if self.max_seq_len <= 0:
|
21 |
+
raise ValueError(f'max_seq_len={max_seq_len!r} must be >0.')
|
22 |
+
if self.pad_token_id < 0:
|
23 |
+
raise ValueError(f'pad_token_id={pad_token_id!r} must be >=0.')
|
24 |
+
if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0:
|
25 |
+
raise ValueError(f'max_leftover_bins_to_keep={max_leftover_bins_to_keep!r} must be >=0 or None.')
|
26 |
+
self.max_leftover_bins_to_keep = max_leftover_bins_to_keep
|
27 |
+
self.n_packed_tokens = 0
|
28 |
+
self.n_total_tokens = 0
|
29 |
+
self.n_packed_examples = 0
|
30 |
+
self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []
|
31 |
+
|
32 |
+
@property
|
33 |
+
def waste(self) -> float:
|
34 |
+
return 1 - self.n_packed_tokens / self.n_total_tokens
|
35 |
+
|
36 |
+
@property
|
37 |
+
def efficiency(self) -> float:
|
38 |
+
return self.n_packed_tokens / (self.max_seq_len * self.n_packed_examples)
|
39 |
+
|
40 |
+
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
41 |
+
batch = self.base_collator(examples)
|
42 |
+
return self.pack(batch)
|
43 |
+
|
44 |
+
def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
45 |
+
assert 'attention_mask' in batch
|
46 |
+
assert 'input_ids' in batch
|
47 |
+
for key in batch.keys():
|
48 |
+
assert key in ['input_ids', 'labels', 'attention_mask', 'sequence_id']
|
49 |
+
sizes, trimmed_examples = _trim_batch(batch)
|
50 |
+
return self._pack_trimmed_examples(trimmed_examples, sizes)
|
51 |
+
|
52 |
+
def _pack_trimmed_examples(self, trimmed_examples: List[Dict[str, torch.Tensor]], sizes: List[int]) -> Dict[str, torch.Tensor]:
|
53 |
+
"""Packs trimmed examples into fixed-size bins and repads them.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
trimmed_examples (List[Dict[str, torch.Tensor]]): A list of trimmed examples.
|
57 |
+
sizes (List[int]): The sizes of the trimmed examples.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
Dict[str, torch.Tensor]: A batch of repadded examples ready for processing
|
61 |
+
"""
|
62 |
+
packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing(sizes=sizes, examples=trimmed_examples, num_bins=self.out_size, max_bin_size=self.max_seq_len, existing_bins=self._leftover_bins)
|
63 |
+
self.n_packed_tokens += n_packed_tokens
|
64 |
+
self.n_total_tokens += n_total_tokens
|
65 |
+
self.n_packed_examples += self.out_size
|
66 |
+
self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
|
67 |
+
batch = _repad(packed_examples, max_seq_len=self.max_seq_len, pad_token_id=self.pad_token_id, padding_side=self.padding_side)
|
68 |
+
return batch
|
69 |
+
|
70 |
+
def _trim_batch(batch: Dict[str, torch.Tensor]) -> Tuple[List[int], List[Dict[str, torch.Tensor]]]:
|
71 |
+
"""Trims padding off all examples in batch.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
batch (Dict[str, torch.Tensor]): Batch of padded data with tensors as values.
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
A tuple with unpadded lengths of examples and a list of each trimmed example from the batch.
|
78 |
+
"""
|
79 |
+
sizes, trimmed_examples = ([], [])
|
80 |
+
for idx in range(batch['attention_mask'].shape[0]):
|
81 |
+
size, trimmed_example = _extract_trim_batch_idx(batch, idx)
|
82 |
+
sizes.append(size)
|
83 |
+
trimmed_examples.append(trimmed_example)
|
84 |
+
return (sizes, trimmed_examples)
|
85 |
+
|
86 |
+
def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
|
87 |
+
example = {k: v[idx] for k, v in batch.items()}
|
88 |
+
keep = example['attention_mask'] == 1
|
89 |
+
size = int(keep.sum())
|
90 |
+
trim_example = {k: v[keep] for k, v in example.items()}
|
91 |
+
trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])
|
92 |
+
return (size, trim_example)
|
93 |
+
|
94 |
+
def _combine_in_place(example: Dict[str, torch.Tensor], add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
95 |
+
if 'labels' in add_on:
|
96 |
+
add_on['labels'][0] = -100
|
97 |
+
for k in example.keys():
|
98 |
+
if k == 'sequence_id':
|
99 |
+
example[k] = torch.cat([example[k], add_on[k] + 1 + torch.max(example[k])])
|
100 |
+
else:
|
101 |
+
example[k] = torch.cat([example[k], add_on[k]])
|
102 |
+
return example
|
103 |
+
|
104 |
+
def _first_fit_bin_packing(sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int, max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[str, torch.Tensor]]]]:
|
105 |
+
bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins
|
106 |
+
starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins])
|
107 |
+
sizes_and_examples = [(size, example) for size, example in zip(sizes, examples)]
|
108 |
+
sorted_sizes_and_examples = sorted(sizes_and_examples, key=lambda x: x[0], reverse=True)
|
109 |
+
required_num_examples = max(0, num_bins - len(bins))
|
110 |
+
num_examples = len(sizes)
|
111 |
+
if num_examples < required_num_examples:
|
112 |
+
for size, example in sorted_sizes_and_examples:
|
113 |
+
bins.append((size, example))
|
114 |
+
total_bin_sizes = sum([bin_size for bin_size, _ in bins])
|
115 |
+
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
|
116 |
+
total_example_sizes = sum(sizes)
|
117 |
+
if total_new_bin_sizes != total_example_sizes:
|
118 |
+
raise AssertionError(f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal total_new_bin_sizes={total_new_bin_sizes!r}.')
|
119 |
+
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
|
120 |
+
bin_sizes, packed_examples = ([], [])
|
121 |
+
for bin_size, packed_example in sorted_bins:
|
122 |
+
bin_sizes.append(bin_size)
|
123 |
+
packed_examples.append(packed_example)
|
124 |
+
return (packed_examples[:num_bins], sum(bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:])
|
125 |
+
for i, (size, example) in enumerate(sorted_sizes_and_examples):
|
126 |
+
required_num_examples = max(0, num_bins - len(bins))
|
127 |
+
n_remaining = num_examples - i
|
128 |
+
assert n_remaining >= required_num_examples
|
129 |
+
if n_remaining == required_num_examples:
|
130 |
+
bins.append((size, example))
|
131 |
+
continue
|
132 |
+
added = False
|
133 |
+
for bidx in range(len(bins)):
|
134 |
+
if bins[bidx][0] + size <= max_bin_size:
|
135 |
+
bin_size, packed_example = bins.pop(bidx)
|
136 |
+
bin_size = bin_size + size
|
137 |
+
packed_example = _combine_in_place(packed_example, example)
|
138 |
+
bins.append((bin_size, packed_example))
|
139 |
+
added = True
|
140 |
+
break
|
141 |
+
if not added:
|
142 |
+
bins.append((size, example))
|
143 |
+
total_bin_sizes = sum([bin_size for bin_size, _ in bins])
|
144 |
+
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
|
145 |
+
total_example_sizes = sum(sizes)
|
146 |
+
if total_new_bin_sizes != total_example_sizes:
|
147 |
+
raise AssertionError(f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal total_new_bin_sizes={total_new_bin_sizes!r}.')
|
148 |
+
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
|
149 |
+
bin_sizes, packed_examples = ([], [])
|
150 |
+
for bin_size, packed_example in sorted_bins:
|
151 |
+
bin_sizes.append(bin_size)
|
152 |
+
packed_examples.append(packed_example)
|
153 |
+
return (packed_examples[:num_bins], sum(bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:])
|
154 |
+
|
155 |
+
def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
|
156 |
+
|
157 |
+
def pad_tensor(tensor: torch.Tensor, pad_value: int):
|
158 |
+
if len(tensor) == max_seq_len:
|
159 |
+
return tensor
|
160 |
+
t = torch.full((max_seq_len,), pad_value, dtype=tensor.dtype, device=tensor.device)
|
161 |
+
if padding_side == 'left':
|
162 |
+
t[-len(tensor):] = tensor
|
163 |
+
elif padding_side == 'right':
|
164 |
+
t[:len(tensor)] = tensor
|
165 |
+
else:
|
166 |
+
raise ValueError(f'Unknown padding_side={padding_side!r}')
|
167 |
+
return t
|
168 |
+
pad_vals = {'input_ids': pad_token_id, 'labels': -100, 'attention_mask': 0, 'sequence_id': -1}
|
169 |
+
keys = packed_examples[0].keys()
|
170 |
+
batch = {}
|
171 |
+
for key in keys:
|
172 |
+
batch[key] = torch.stack([pad_tensor(example[key], pad_vals[key]) for example in packed_examples])
|
173 |
+
return batch
|
174 |
+
|
175 |
+
def auto_packing_ratio(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, num_packing_ratios: int=20) -> float:
|
176 |
+
"""Find a packing ratio that minimizes padding with zero waste.
|
177 |
+
|
178 |
+
By packing examples, we can increase training efficiency, training on more data with less batches.
|
179 |
+
However, in practice, the selected packing_ratio may produce some waste because profiling is done on only
|
180 |
+
a subset of the dataset.
|
181 |
+
|
182 |
+
We select a min_ratio of 1 and a max_ratio that is the max_seq_len / 100, and profile up to
|
183 |
+
num_packing_ratios packing ratios between min_ratio and max_ratio, inclusive.
|
184 |
+
When a packing_ratio with non-zero waste is found, we stop and select the previous ratio,
|
185 |
+
which has zero waste.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
dataloader_cfg (DictConfig): The dataloader configuration for profiling.
|
189 |
+
tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
|
190 |
+
device_batch_size (int): The size of the batches (number of examples) per device.
|
191 |
+
num_packing_ratio (int): The number of packing ratios to try.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
A packing ratio that minimizes padding while maintaining zero waste.
|
195 |
+
"""
|
196 |
+
rng_state = reproducibility.get_rng_state()
|
197 |
+
reproducibility.seed_all(0)
|
198 |
+
max_seq_len = dataloader_cfg.dataset.max_seq_len
|
199 |
+
if max_seq_len <= 100:
|
200 |
+
return 1
|
201 |
+
min_ratio = 1
|
202 |
+
max_ratio = max_seq_len / 100
|
203 |
+
profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, max_ratio, num_packing_ratios, device_batch_size)
|
204 |
+
packing_ratio = 1
|
205 |
+
for packing_ratio_candidate, _, waste in profiling_results:
|
206 |
+
if waste is None or waste > 0:
|
207 |
+
break
|
208 |
+
packing_ratio = packing_ratio_candidate
|
209 |
+
if dist.is_available() and dist.is_initialized():
|
210 |
+
device = get_device(None)
|
211 |
+
packing_ratio_tensor = device.tensor_to_device(torch.tensor(packing_ratio))
|
212 |
+
dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN')
|
213 |
+
packing_ratio = packing_ratio_tensor.item()
|
214 |
+
reproducibility.load_rng_state(rng_state)
|
215 |
+
return packing_ratio
|
216 |
+
|
217 |
+
def profile_packing(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, min_ratio: float, max_ratio: float, num_packing_ratios: int, device_batch_size: int) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
|
218 |
+
"""Generator function that profiles example packing across packing ratios.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
dataloader_cfg (DictConfig): The dataloader configuration for profiling.
|
222 |
+
tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
|
223 |
+
min_ratio (float): Smallest packing_ratio to test. Must be >=1.
|
224 |
+
max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
|
225 |
+
num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try.
|
226 |
+
device_batch_size (int): The size of the batches (number of examples) per device.
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
An iterable of tuples of packing ratio, padding, and waste, sorted by smallest to largest packing ratio.
|
230 |
+
"""
|
231 |
+
import copy
|
232 |
+
from .dataloader import build_dataloader
|
233 |
+
max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
|
234 |
+
max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', None)
|
235 |
+
dataloader_cfg = copy.deepcopy(dataloader_cfg)
|
236 |
+
dataloader_cfg.dataset.packing_ratio = 1.0
|
237 |
+
dataloader_cfg.drop_last = False
|
238 |
+
dataloader_cfg.num_workers = 0
|
239 |
+
dataloader_cfg.prefetch_factor = None
|
240 |
+
dataloader_cfg.persistent_workers = False
|
241 |
+
if dataloader_cfg.dataset.get('remote') is not None:
|
242 |
+
dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name
|
243 |
+
packing_ratios, raw_batch_sizes = ([], [])
|
244 |
+
for packing_ratio in np.linspace(min_ratio, max_ratio, num_packing_ratios, endpoint=True):
|
245 |
+
packing_ratio = np.round(10 * packing_ratio) / 10
|
246 |
+
raw_batch_size = int(packing_ratio * device_batch_size)
|
247 |
+
if raw_batch_size not in raw_batch_sizes:
|
248 |
+
packing_ratios.append(packing_ratio)
|
249 |
+
raw_batch_sizes.append(raw_batch_size)
|
250 |
+
n_profile_examples = max(raw_batch_sizes) * 100
|
251 |
+
train_dataspec = build_dataloader(dataloader_cfg, tokenizer, n_profile_examples)
|
252 |
+
train_dataloader = train_dataspec.dataloader
|
253 |
+
big_batch = next(iter(train_dataloader))
|
254 |
+
sizes, trimmed_examples = _trim_batch(big_batch)
|
255 |
+
|
256 |
+
def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
|
257 |
+
trimmed_examples_copy = [te.copy() for te in trimmed_examples]
|
258 |
+
packer = BinPackCollator(collator=lambda x: x, target_batch_size=device_batch_size, max_seq_len=max_seq_len, pad_token_id=0, padding_side='left', max_leftover_bins_to_keep=max_leftovers_to_keep)
|
259 |
+
for idx in range(0, len(trimmed_examples_copy), raw_batch_size):
|
260 |
+
batch = trimmed_examples_copy[idx:idx + raw_batch_size]
|
261 |
+
if len(batch) < device_batch_size:
|
262 |
+
continue
|
263 |
+
packer._pack_trimmed_examples(batch, sizes[idx:idx + raw_batch_size])
|
264 |
+
if packer.n_packed_examples == 0:
|
265 |
+
log.debug('No examples packed during profiling. Dataset is smaller than device batch size.')
|
266 |
+
return (None, None)
|
267 |
+
padding_percent = 100 * (1 - packer.efficiency)
|
268 |
+
waste_percent = 100 * packer.waste
|
269 |
+
return (padding_percent, waste_percent)
|
270 |
+
for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
|
271 |
+
padding, waste = profile(raw_batch_size)
|
272 |
+
yield (packing_ratio, padding, waste)
|
param_init_fns.py
CHANGED
@@ -22,9 +22,9 @@ def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
|
|
22 |
if _fused is None:
|
23 |
raise RuntimeError(f'Internal logic error')
|
24 |
assert isinstance(module.weight, torch.Tensor)
|
25 |
-
|
26 |
splits = (0, *splits, module.weight.size(dim))
|
27 |
-
for
|
28 |
slice_indices = [slice(None)] * module.weight.ndim
|
29 |
slice_indices[dim] = slice(s, e)
|
30 |
init_fn_(module.weight[slice_indices])
|
@@ -71,7 +71,7 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
|
|
71 |
if lim == 0:
|
72 |
warnings.warn(f'Embedding layer initialized to 0.')
|
73 |
lim = [-lim, lim]
|
74 |
-
|
75 |
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
76 |
else:
|
77 |
emb_init_fn_ = init_fn_
|
@@ -88,7 +88,7 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
|
|
88 |
assert d_model is not None
|
89 |
_d = d_model
|
90 |
splits = (0, _d, 2 * _d, 3 * _d)
|
91 |
-
for
|
92 |
init_fn_(module.in_proj_weight[s:e])
|
93 |
else:
|
94 |
assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
|
|
|
22 |
if _fused is None:
|
23 |
raise RuntimeError(f'Internal logic error')
|
24 |
assert isinstance(module.weight, torch.Tensor)
|
25 |
+
dim, splits = _fused
|
26 |
splits = (0, *splits, module.weight.size(dim))
|
27 |
+
for s, e in zip(splits[:-1], splits[1:]):
|
28 |
slice_indices = [slice(None)] * module.weight.ndim
|
29 |
slice_indices[dim] = slice(s, e)
|
30 |
init_fn_(module.weight[slice_indices])
|
|
|
71 |
if lim == 0:
|
72 |
warnings.warn(f'Embedding layer initialized to 0.')
|
73 |
lim = [-lim, lim]
|
74 |
+
a, b = lim
|
75 |
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
76 |
else:
|
77 |
emb_init_fn_ = init_fn_
|
|
|
88 |
assert d_model is not None
|
89 |
_d = d_model
|
90 |
splits = (0, _d, 2 * _d, 3 * _d)
|
91 |
+
for s, e in zip(splits[:-1], splits[1:]):
|
92 |
init_fn_(module.in_proj_weight[s:e])
|
93 |
else:
|
94 |
assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
|
prompt_files.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Optional
|
3 |
+
PROMPTFILE_PREFIX = 'file::'
|
4 |
+
|
5 |
+
def load_prompts(prompts: List[str], prompt_delimiter: Optional[str]=None) -> List[str]:
|
6 |
+
"""Loads a set of prompts, both free text and from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
prompts (List[str]): List of free text prompts and prompt files
|
10 |
+
prompt_delimiter (Optional str): Delimiter for text file
|
11 |
+
If not provided, assumes the prompt file is a single prompt (non-delimited)
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
List of prompt string(s)
|
15 |
+
"""
|
16 |
+
prompt_strings = []
|
17 |
+
for prompt in prompts:
|
18 |
+
if prompt.startswith(PROMPTFILE_PREFIX):
|
19 |
+
prompts = load_prompts_from_file(prompt, prompt_delimiter)
|
20 |
+
prompt_strings.extend(prompts)
|
21 |
+
else:
|
22 |
+
prompt_strings.append(prompt)
|
23 |
+
return prompt_strings
|
24 |
+
|
25 |
+
def load_prompts_from_file(prompt_path: str, prompt_delimiter: Optional[str]=None) -> List[str]:
|
26 |
+
"""Load a set of prompts from a text fie.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
prompt_path (str): Path for text file
|
30 |
+
prompt_delimiter (Optional str): Delimiter for text file
|
31 |
+
If not provided, assumes the prompt file is a single prompt (non-delimited)
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
List of prompt string(s)
|
35 |
+
"""
|
36 |
+
if not prompt_path.startswith(PROMPTFILE_PREFIX):
|
37 |
+
raise ValueError(f'prompt_path_str must start with {PROMPTFILE_PREFIX}')
|
38 |
+
_, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1)
|
39 |
+
prompt_file_path = os.path.expanduser(prompt_file_path)
|
40 |
+
if not os.path.isfile(prompt_file_path):
|
41 |
+
raise FileNotFoundError(f'prompt_file_path={prompt_file_path!r} does not match any existing files.')
|
42 |
+
with open(prompt_file_path, 'r') as f:
|
43 |
+
prompt_string = f.read()
|
44 |
+
if prompt_delimiter is None:
|
45 |
+
return [prompt_string]
|
46 |
+
return [i for i in prompt_string.split(prompt_delimiter) if i]
|
registry.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Type
|
2 |
+
from torch.optim import Optimizer
|
3 |
+
from torchmetrics import Metric
|
4 |
+
from transformers import PreTrainedTokenizerBase
|
5 |
+
from .interfaces import CallbackWithConfig
|
6 |
+
from .registry_utils import create_registry
|
7 |
+
_loggers_description = 'The loggers registry is used to register classes that implement the LoggerDestination interface. ' + 'These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers ' + 'will be constructed by directly passing along the specified kwargs to the constructor.'
|
8 |
+
loggers = create_registry('llmfoundry', 'loggers', generic_type=Type[LoggerDestination], entry_points=True, description=_loggers_description)
|
9 |
+
_callbacks_description = 'The callbacks registry is used to register classes that implement the Callback interface. ' + 'These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. ' + 'The callbacks will be constructed by directly passing along the specified kwargs to the constructor.'
|
10 |
+
callbacks = create_registry('llmfoundry', 'callbacks', generic_type=Type[Callback], entry_points=True, description=_callbacks_description)
|
11 |
+
_callbacks_with_config_description = 'The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. ' + 'These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.'
|
12 |
+
callbacks_with_config = create_registry('llm_foundry.callbacks_with_config', generic_type=Type[CallbackWithConfig], entry_points=True, description=_callbacks_with_config_description)
|
13 |
+
_optimizers_description = 'The optimizers registry is used to register classes that implement the Optimizer interface. ' + 'The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the ' + 'specified kwargs to the constructor, along with the model parameters.'
|
14 |
+
optimizers = create_registry('llmfoundry', 'optimizers', generic_type=Type[Optimizer], entry_points=True, description=_optimizers_description)
|
15 |
+
_algorithms_description = 'The algorithms registry is used to register classes that implement the Algorithm interface. ' + 'The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the ' + 'specified kwargs to the constructor.'
|
16 |
+
algorithms = create_registry('llmfoundry', 'algorithms', generic_type=Type[Algorithm], entry_points=True, description=_algorithms_description)
|
17 |
+
_schedulers_description = 'The schedulers registry is used to register classes that implement the ComposerScheduler interface. ' + 'The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the ' + 'specified kwargs to the constructor.'
|
18 |
+
schedulers = create_registry('llmfoundry', 'schedulers', generic_type=Type[ComposerScheduler], entry_points=True, description=_schedulers_description)
|
19 |
+
_models_description = 'The models registry is used to register classes that implement the ComposerModel interface. The model\nconstructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`.\nNote: This will soon be updated to take in named kwargs instead of a config directly.'
|
20 |
+
models = create_registry('llmfoundry', 'models', generic_type=Type[ComposerModel], entry_points=True, description=_models_description)
|
21 |
+
_dataloaders_description = 'The dataloaders registry is used to register functions that create a DataSpec. The function should take\na DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.'
|
22 |
+
dataloaders = create_registry('llmfoundry', 'dataloaders', generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec], entry_points=True, description=_dataloaders_description)
|
23 |
+
_metrics_description = 'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.'
|
24 |
+
metrics = create_registry('llmfoundry', 'metrics', generic_type=Type[Metric], entry_points=True, description=_metrics_description)
|
registry_utils.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import importlib.util
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
from types import ModuleType
|
6 |
+
from typing import Any, Callable, Dict, Generic, Optional, Sequence, Type, TypeVar, Union
|
7 |
+
import catalogue
|
8 |
+
T = TypeVar('T')
|
9 |
+
|
10 |
+
class TypedRegistry(catalogue.Registry, Generic[T]):
|
11 |
+
"""A thin wrapper around catalogue.Registry to add static typing and.
|
12 |
+
|
13 |
+
descriptions.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, namespace: Sequence[str], entry_points: bool=False, description: str='') -> None:
|
17 |
+
super().__init__(namespace, entry_points=entry_points)
|
18 |
+
self.description = description
|
19 |
+
|
20 |
+
def __call__(self, name: str, func: Optional[T]=None) -> Callable[[T], T]:
|
21 |
+
return super().__call__(name, func)
|
22 |
+
|
23 |
+
def register(self, name: str, *, func: Optional[T]=None) -> T:
|
24 |
+
return super().register(name, func=func)
|
25 |
+
|
26 |
+
def get(self, name: str) -> T:
|
27 |
+
return super().get(name)
|
28 |
+
|
29 |
+
def get_all(self) -> Dict[str, T]:
|
30 |
+
return super().get_all()
|
31 |
+
|
32 |
+
def get_entry_point(self, name: str, default: Optional[T]=None) -> T:
|
33 |
+
return super().get_entry_point(name, default=default)
|
34 |
+
|
35 |
+
def get_entry_points(self) -> Dict[str, T]:
|
36 |
+
return super().get_entry_points()
|
37 |
+
S = TypeVar('S')
|
38 |
+
|
39 |
+
def create_registry(*namespace: str, generic_type: Type[S], entry_points: bool=False, description: str='') -> 'TypedRegistry[S]':
|
40 |
+
"""Create a new registry.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
namespace (str): The namespace, e.g. "llmfoundry.loggers"
|
44 |
+
generic_type (Type[S]): The type of the registry.
|
45 |
+
entry_points (bool): Accept registered functions from entry points.
|
46 |
+
description (str): A description of the registry.
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
The TypedRegistry object.
|
50 |
+
"""
|
51 |
+
if catalogue.check_exists(*namespace):
|
52 |
+
raise catalogue.RegistryError(f'Namespace already exists: {namespace}')
|
53 |
+
return TypedRegistry[generic_type](namespace, entry_points=entry_points, description=description)
|
54 |
+
|
55 |
+
def construct_from_registry(name: str, registry: TypedRegistry, partial_function: bool=True, pre_validation_function: Optional[Union[Callable[[Any], None], type]]=None, post_validation_function: Optional[Callable[[Any], None]]=None, kwargs: Optional[Dict[str, Any]]=None) -> Any:
|
56 |
+
"""Helper function to build an item from the registry.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
name (str): The name of the registered item
|
60 |
+
registry (catalogue.Registry): The registry to fetch the item from
|
61 |
+
partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True.
|
62 |
+
pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called
|
63 |
+
before constructing the item to return. This should throw an exception if validation fails. Defaults to None.
|
64 |
+
post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after
|
65 |
+
constructing the item to return. This should throw an exception if validation fails. Defaults to None.
|
66 |
+
|
67 |
+
Raises:
|
68 |
+
ValueError: If the validation functions failed or the registered item is invalid
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
Any: The constructed item from the registry
|
72 |
+
"""
|
73 |
+
if kwargs is None:
|
74 |
+
kwargs = {}
|
75 |
+
registered_constructor = registry.get(name)
|
76 |
+
if pre_validation_function is not None:
|
77 |
+
if isinstance(pre_validation_function, type):
|
78 |
+
if not issubclass(registered_constructor, pre_validation_function):
|
79 |
+
raise ValueError(f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}')
|
80 |
+
elif isinstance(pre_validation_function, Callable):
|
81 |
+
pre_validation_function(registered_constructor)
|
82 |
+
else:
|
83 |
+
raise ValueError(f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}')
|
84 |
+
if isinstance(registered_constructor, type) or (callable(registered_constructor) and (not partial_function)):
|
85 |
+
constructed_item = registered_constructor(**kwargs)
|
86 |
+
elif callable(registered_constructor):
|
87 |
+
constructed_item = functools.partial(registered_constructor, **kwargs)
|
88 |
+
else:
|
89 |
+
raise ValueError(f'Expected {name} to be a class or function, but got {type(registered_constructor)}')
|
90 |
+
if post_validation_function is not None:
|
91 |
+
post_validation_function(registered_constructor)
|
92 |
+
return constructed_item
|
93 |
+
|
94 |
+
def import_file(loc: Union[str, Path]) -> ModuleType:
|
95 |
+
"""Import module from a file.
|
96 |
+
|
97 |
+
Used to run arbitrary python code.
|
98 |
+
Args:
|
99 |
+
name (str): Name of module to load.
|
100 |
+
loc (str / Path): Path to the file.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
ModuleType: The module object.
|
104 |
+
"""
|
105 |
+
if not os.path.exists(loc):
|
106 |
+
raise FileNotFoundError(f'File {loc} does not exist.')
|
107 |
+
spec = importlib.util.spec_from_file_location('python_code', str(loc))
|
108 |
+
assert spec is not None
|
109 |
+
assert spec.loader is not None
|
110 |
+
module = importlib.util.module_from_spec(spec)
|
111 |
+
try:
|
112 |
+
spec.loader.exec_module(module)
|
113 |
+
except Exception as e:
|
114 |
+
raise RuntimeError(f'Error executing {loc}') from e
|
115 |
+
return module
|
resumption_callbacks.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import List
|
3 |
+
log = logging.getLogger(__name__)
|
4 |
+
|
5 |
+
class GlobalLRScaling(Callback):
|
6 |
+
"""GlobalLRScaling.
|
7 |
+
|
8 |
+
This callback can be applied upon resuming a model checkpoint. Upon
|
9 |
+
fit_start it will multiply the base LR by `lr_scale` and set the WD to be.
|
10 |
+
|
11 |
+
`wd_pct` * `lr`.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
lr_scale (float): Multiplicative factor to scale LR by
|
15 |
+
wd_pct (float): Percentage of LR to set weight decay to.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, lr_scale: float, wd_pct: float=0.0):
|
19 |
+
self.lr_scale = lr_scale
|
20 |
+
self.wd_pct = wd_pct
|
21 |
+
|
22 |
+
def fit_start(self, state: State, logger: Logger) -> None:
|
23 |
+
del logger
|
24 |
+
if hasattr(state, 'optimizer') and state.optimizers is None:
|
25 |
+
raise Exception('No optimizers defined')
|
26 |
+
for optimizer in state.optimizers:
|
27 |
+
for group in optimizer.param_groups:
|
28 |
+
group['lr'] *= self.lr_scale
|
29 |
+
group['weight_decay'] = group['lr'] * self.wd_pct
|
30 |
+
if 'initial_lr' in group:
|
31 |
+
group['initial_lr'] *= self.lr_scale
|
32 |
+
log.info(f"Set LR and WD to {group['lr']}, {group['weight_decay']}")
|
33 |
+
for scheduler in state.schedulers:
|
34 |
+
scheduler.base_lrs = [self.lr_scale * lr for lr in scheduler.base_lrs]
|
35 |
+
|
36 |
+
class LayerFreezing(Callback):
|
37 |
+
"""LayerFreezing.
|
38 |
+
|
39 |
+
This callback can be applied upon resuming a model checkpoint. Upon
|
40 |
+
fit_start it freeze the layers specified in `layer_names`. If using
|
41 |
+
activation checkpointing, please set the
|
42 |
+
`activation_checkpointing_reentrant` flag in `fsdp_config` to false.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
layer_names (float): Names of layers to freeze.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(self, layer_names: List[str]):
|
49 |
+
self.layer_names = set(layer_names)
|
50 |
+
|
51 |
+
def fit_start(self, state: State, logger: Logger) -> None:
|
52 |
+
del logger
|
53 |
+
model_layers = set((name for name, _ in state.model.named_parameters()))
|
54 |
+
for layer in self.layer_names:
|
55 |
+
if layer not in model_layers:
|
56 |
+
raise Exception(f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}')
|
57 |
+
successful_freeze = False
|
58 |
+
for name, p in state.model.named_parameters():
|
59 |
+
if p.requires_grad and name in self.layer_names:
|
60 |
+
p.requires_grad = False
|
61 |
+
log.debug(f'Froze layer: {name}\nParam: {p}')
|
62 |
+
successful_freeze = True
|
63 |
+
if not successful_freeze:
|
64 |
+
raise Exception(f"Tried to run LayerFreezing but didn't freeze any layers")
|
scheduled_gc_callback.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from typing import Optional
|
3 |
+
import torch
|
4 |
+
|
5 |
+
def gc_cuda():
|
6 |
+
"""Garbage collect Torch (CUDA) memory."""
|
7 |
+
gc.collect()
|
8 |
+
if torch.cuda.is_available():
|
9 |
+
torch.cuda.empty_cache()
|
10 |
+
|
11 |
+
class ScheduledGarbageCollector(Callback):
|
12 |
+
"""Disable automatic garbage collection and collect garbage at interval.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
batch_interval (int): Number of batches between calls to gc.collect()
|
16 |
+
gen_1_batch_interval(int, optional): Number of batches between calls to gc.collect(1)
|
17 |
+
eval_keep_disabled (bool): keep gc disabled during eval (default: False)
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, batch_interval: int, gen_1_batch_interval: Optional[int]=None, eval_keep_disabled: bool=False):
|
21 |
+
self.batch_interval = batch_interval
|
22 |
+
self.gen_1_batch_interval = gen_1_batch_interval
|
23 |
+
self.eval_keep_disabled = eval_keep_disabled
|
24 |
+
self.gc_init_state = None
|
25 |
+
|
26 |
+
def fit_start(self, state: State, logger: Logger) -> None:
|
27 |
+
del state, logger
|
28 |
+
self.gc_init_state = gc.isenabled()
|
29 |
+
gc.disable()
|
30 |
+
gc_cuda()
|
31 |
+
|
32 |
+
def fit_end(self, state: State, logger: Logger) -> None:
|
33 |
+
del state, logger
|
34 |
+
gc_cuda()
|
35 |
+
if self.gc_init_state:
|
36 |
+
gc.enable()
|
37 |
+
else:
|
38 |
+
gc.disable()
|
39 |
+
|
40 |
+
def before_dataloader(self, state: State, logger: Logger) -> None:
|
41 |
+
del logger
|
42 |
+
if self.gen_1_batch_interval is not None and state.timestamp.batch.value % self.gen_1_batch_interval == 0:
|
43 |
+
gc.collect(1)
|
44 |
+
if state.timestamp.batch.value % self.batch_interval == 0:
|
45 |
+
gc_cuda()
|
46 |
+
|
47 |
+
def eval_start(self, state: State, logger: Logger) -> None:
|
48 |
+
del state, logger
|
49 |
+
gc_cuda()
|
50 |
+
if not self.eval_keep_disabled:
|
51 |
+
gc.enable()
|
52 |
+
|
53 |
+
def eval_end(self, state: State, logger: Logger) -> None:
|
54 |
+
del state, logger
|
55 |
+
if not self.eval_keep_disabled:
|
56 |
+
gc.disable()
|
57 |
+
gc_cuda()
|
tasks.py
ADDED
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Includes code for task-specific seq-to-seq data formatting.
|
2 |
+
|
3 |
+
This file provides some templates/examples of preprocessing functions
|
4 |
+
that format examples for use in seq-to-seq finetuning tasks.
|
5 |
+
These preprocessing functions take individual examples that contain raw
|
6 |
+
text and process them into formatted examples.
|
7 |
+
|
8 |
+
These functions have this basic structure:
|
9 |
+
|
10 |
+
def preprocessing_fn(example: Dict) -> Dict[str, str]:
|
11 |
+
# code to extract prompt/response from `example`
|
12 |
+
...
|
13 |
+
return {
|
14 |
+
'prompt': <prompt>,
|
15 |
+
'response': <response>,
|
16 |
+
}
|
17 |
+
|
18 |
+
where `<prompt>` is a placeholder for the prompt text string that you
|
19 |
+
extracted from the input example, and '<response>' is a placeholder for
|
20 |
+
the response text string.
|
21 |
+
|
22 |
+
Just to be clear, "prompt" represents the text you would give the model
|
23 |
+
at inference time, and "response" represents the text you are training
|
24 |
+
it to produce given the prompt.
|
25 |
+
|
26 |
+
The key requirement of these functions is that they return a dictionary
|
27 |
+
with "prompt" and "response" keys, and that the values associated with
|
28 |
+
those keys are strings (i.e. text).
|
29 |
+
"""
|
30 |
+
import importlib
|
31 |
+
import logging
|
32 |
+
import os
|
33 |
+
import warnings
|
34 |
+
from collections.abc import Mapping
|
35 |
+
from functools import partial
|
36 |
+
from pathlib import Path
|
37 |
+
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
|
38 |
+
import datasets as hf_datasets
|
39 |
+
import huggingface_hub as hf_hub
|
40 |
+
import numpy as np
|
41 |
+
from streaming import Stream, StreamingDataset
|
42 |
+
from transformers import PreTrainedTokenizerBase
|
43 |
+
from .collator import _HF_IGNORE_INDEX, stitch_turns_decoder_only, stitch_turns_encoder_decoder
|
44 |
+
from .exceptions import ConsecutiveRepeatedChatRolesError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, InvalidFileExtensionError, InvalidLastChatMessageRoleError, InvalidPromptResponseKeysError, InvalidPromptTypeError, InvalidResponseTypeError, InvalidRoleError, NotEnoughChatDataError, TooManyKeysInExampleError, UnableToProcessPromptResponseError, UnknownExampleTypeError
|
45 |
+
from .logging_utils import SpecificWarningFilter
|
46 |
+
log = logging.getLogger(__name__)
|
47 |
+
_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
|
48 |
+
_ALLOWED_PROMPT_KEYS = {'prompt'}
|
49 |
+
_ALLOWED_MESSAGES_KEYS = {'messages'}
|
50 |
+
_ALLOWED_ROLE_KEYS = {'role'}
|
51 |
+
_ALLOWED_CONTENT_KEYS = {'content'}
|
52 |
+
_ALLOWED_ROLES = {'user', 'assistant', 'system'}
|
53 |
+
_ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
|
54 |
+
DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, '.downloaded_finetuning'))
|
55 |
+
SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']
|
56 |
+
PromptResponseDict = Mapping[str, str]
|
57 |
+
ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
|
58 |
+
Example = Union[PromptResponseDict, ChatFormattedDict]
|
59 |
+
ExampleType = Literal['prompt_response', 'chat']
|
60 |
+
TokenizedExample = Dict[str, List[Dict[str, List[int]]]]
|
61 |
+
|
62 |
+
def _get_example_type(example: Example) -> ExampleType:
|
63 |
+
"""Determines the type of the input example.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
example (Example): The input example, which can be a multi-way chat formatted conversation or an instruction-response pair.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
ExampleType: The type of the input example, which can be either 'chat' for multi-way chat formatted conversation or 'prompt_response' for instruction-response pair.
|
70 |
+
|
71 |
+
Raises:
|
72 |
+
KeyError: If the example type is unknown.
|
73 |
+
"""
|
74 |
+
if not isinstance(example, Mapping):
|
75 |
+
raise TypeError(f'Expected example to be a Mapping, but found {type(example)}')
|
76 |
+
if any((allowed_message_key in example for allowed_message_key in _ALLOWED_MESSAGES_KEYS)):
|
77 |
+
return 'chat'
|
78 |
+
elif any((p in example for p in _ALLOWED_PROMPT_KEYS)) and any((r in example for r in _ALLOWED_RESPONSE_KEYS)):
|
79 |
+
return 'prompt_response'
|
80 |
+
else:
|
81 |
+
raise UnknownExampleTypeError(example)
|
82 |
+
|
83 |
+
def _is_empty_or_nonexistent(dirpath: str) -> bool:
|
84 |
+
"""Check if a directory is empty or non-existent.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
dirpath (str): Directory path to check.
|
88 |
+
|
89 |
+
Returns
|
90 |
+
True if directory is empty or non-existent. False otherwise.
|
91 |
+
"""
|
92 |
+
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0
|
93 |
+
|
94 |
+
def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]):
|
95 |
+
if not isinstance(dictionary, Mapping):
|
96 |
+
raise TypeError(f'Expected dictionary to be a mapping, but found {type(dictionary)}')
|
97 |
+
desired_keys = allowed_keys.intersection(dictionary.keys())
|
98 |
+
if len(desired_keys) != 1:
|
99 |
+
raise TooManyKeysInExampleError(allowed_keys, desired_keys)
|
100 |
+
return list(desired_keys)[0]
|
101 |
+
|
102 |
+
def _validate_chat_formatted_example(example: ChatFormattedDict):
|
103 |
+
if not isinstance(example, Mapping):
|
104 |
+
raise TypeError(f'Expected example to be a mapping, but found {type(example)}')
|
105 |
+
messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
|
106 |
+
if not isinstance(messages, List):
|
107 |
+
raise TypeError(f'Expected messages to be an iterable, but found {type(messages)}')
|
108 |
+
if len(messages) <= 1:
|
109 |
+
raise NotEnoughChatDataError()
|
110 |
+
last_message = messages[-1]
|
111 |
+
role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS)
|
112 |
+
last_role = last_message[role_key]
|
113 |
+
if last_role not in _ALLOWED_LAST_MESSAGE_ROLES:
|
114 |
+
raise InvalidLastChatMessageRoleError(last_role, _ALLOWED_LAST_MESSAGE_ROLES)
|
115 |
+
last_message_role = None
|
116 |
+
for message in messages:
|
117 |
+
role_key, content_key = (_get_key(message, _ALLOWED_ROLE_KEYS), _get_key(message, _ALLOWED_CONTENT_KEYS))
|
118 |
+
if len(message.keys()) != 2:
|
119 |
+
raise IncorrectMessageKeyQuantityError(list(message.keys()))
|
120 |
+
if message[role_key] not in _ALLOWED_ROLES:
|
121 |
+
raise InvalidRoleError(message[role_key], _ALLOWED_ROLES)
|
122 |
+
if not isinstance(message[content_key], str):
|
123 |
+
raise InvalidContentTypeError(type(message[content_key]))
|
124 |
+
if last_message_role is not None and last_message_role == message[role_key]:
|
125 |
+
raise ConsecutiveRepeatedChatRolesError(last_message_role)
|
126 |
+
last_message_role = message[role_key]
|
127 |
+
|
128 |
+
def _slice_chat_formatted_example(example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, str]]:
|
129 |
+
"""Slices chat example into a list of templated prompt, response turns.
|
130 |
+
|
131 |
+
Note: Assistant messages mark the end of chat turns. So there are as many turns as there are
|
132 |
+
assistant messages in the chat example.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
example (ChatFormattedDict): The chat example containing the messages.
|
136 |
+
tokenizer (PreTrainedTokenizerBase): The tokenizer to apply the chat template.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
List[Tuple[str, str]]: A list of templated prompt and response string pairs, one pair per chat turn.
|
140 |
+
|
141 |
+
Raises:
|
142 |
+
ValueError: If any chat turn in the example has less than two messages or if the last message is not from the assistant.
|
143 |
+
KeyError: If a message does not have a role or content.
|
144 |
+
"""
|
145 |
+
_validate_chat_formatted_example(example)
|
146 |
+
messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
|
147 |
+
last_message = messages[-1]
|
148 |
+
if last_message['role'] != 'assistant':
|
149 |
+
raise InvalidLastChatMessageRoleError(last_message['role'], set(['assistant']))
|
150 |
+
|
151 |
+
def slice_out_last_turn(messages_through_current_turn: List[Dict[str, str]], conversation_through_previous_turn: str) -> Tuple[str, str]:
|
152 |
+
full_conversation = tokenizer.apply_chat_template(messages_through_current_turn, tokenize=False)
|
153 |
+
prompt_with_history = tokenizer.apply_chat_template(messages_through_current_turn[:-1], tokenize=False, add_generation_prompt=True)
|
154 |
+
if conversation_through_previous_turn != full_conversation[:len(conversation_through_previous_turn)]:
|
155 |
+
raise ValueError(f'The full conversation must start with the conversation through the previous turn. conversation_through_previous_turn={conversation_through_previous_turn!r}, full_conversation={full_conversation!r}')
|
156 |
+
if conversation_through_previous_turn != prompt_with_history[:len(conversation_through_previous_turn)]:
|
157 |
+
raise ValueError(f'The prompt_with_histry must start with the conversation through the previous turn. conversation_through_previous_turn={conversation_through_previous_turn!r}, prompt_with_history={prompt_with_history!r}')
|
158 |
+
if prompt_with_history != full_conversation[:len(prompt_with_history)]:
|
159 |
+
raise ValueError(f'prompt_with_history must be the first part of the full conversation. prompt_with_history={prompt_with_history!r}, full_conversation={full_conversation!r}')
|
160 |
+
prompt = prompt_with_history[len(conversation_through_previous_turn):]
|
161 |
+
response = full_conversation[len(prompt_with_history):]
|
162 |
+
return (prompt, response)
|
163 |
+
templated_prompt_response_turns: List[Tuple[str, str]] = []
|
164 |
+
conversation_through_previous_turn = ''
|
165 |
+
for idx, message in enumerate(messages):
|
166 |
+
if message['role'] == 'assistant':
|
167 |
+
prompt, response = slice_out_last_turn(messages[:idx + 1], conversation_through_previous_turn)
|
168 |
+
templated_prompt_response_turns.append((prompt, response))
|
169 |
+
conversation_through_previous_turn += prompt
|
170 |
+
conversation_through_previous_turn += response
|
171 |
+
return templated_prompt_response_turns
|
172 |
+
|
173 |
+
def _tokenize_with_bos_removal(tokenizer: PreTrainedTokenizerBase, text: str, text_target: str) -> Dict[str, List[int]]:
|
174 |
+
"""Tokenizes the prompt and response using the provided tokenizer.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
|
178 |
+
text (str): The prompt to tokenize.
|
179 |
+
text_target (str): The response to tokenize.
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
Dict[str, List[int]]: The tokenized text and text_target.
|
183 |
+
"""
|
184 |
+
tokenized_sample = tokenizer(text=text, text_target=text_target, padding=False, truncation=False)
|
185 |
+
if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
|
186 |
+
if tokenizer.bos_token_id is not None and tokenized_sample['labels'][0] == tokenizer.bos_token_id:
|
187 |
+
tokenized_sample['labels'] = tokenized_sample['labels'][1:]
|
188 |
+
return tokenized_sample
|
189 |
+
|
190 |
+
def _tokenize_chat_formatted_example(example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
|
191 |
+
"""Tokenizes a chat-formatted example using the provided tokenizer.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
example (ChatFormattedDict): The chat-formatted example to tokenize.
|
195 |
+
tokenizer (PreTrainedTokenizerBase): The tokenizer to use for tokenization.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
TokenizedExample: The tokenized example.
|
199 |
+
"""
|
200 |
+
return {'turns': [tokenizer(text=prompt, text_target=response, add_special_tokens=False, padding=False, truncation=False) for prompt, response in _slice_chat_formatted_example(example, tokenizer)]}
|
201 |
+
|
202 |
+
def _tokenize_prompt_response_formatted_example(example: PromptResponseDict, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
|
203 |
+
"""Tokenize a formatted example and validate expected keys."""
|
204 |
+
example_keys = set(example.keys())
|
205 |
+
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
|
206 |
+
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)
|
207 |
+
if len(prompt_keys) != 1:
|
208 |
+
raise TooManyKeysInExampleError(_ALLOWED_PROMPT_KEYS, prompt_keys)
|
209 |
+
if len(response_keys) != 1:
|
210 |
+
raise TooManyKeysInExampleError(_ALLOWED_RESPONSE_KEYS, response_keys)
|
211 |
+
prompt_key = prompt_keys.pop()
|
212 |
+
response_key = response_keys.pop()
|
213 |
+
prompt = example[prompt_key]
|
214 |
+
response = example[response_key]
|
215 |
+
if not isinstance(prompt, str):
|
216 |
+
raise InvalidPromptTypeError(type(prompt))
|
217 |
+
if not isinstance(response, str):
|
218 |
+
raise InvalidResponseTypeError(type(response))
|
219 |
+
return {'turns': [_tokenize_with_bos_removal(tokenizer=tokenizer, text=prompt, text_target=response)]}
|
220 |
+
|
221 |
+
def tokenize_formatted_example(example: Example, tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
|
222 |
+
"""Tokenizes a formatted example using the provided tokenizer.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
example (Example): The input example to be tokenized.
|
226 |
+
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenization.
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
TokenizedExample: The tokenized example.
|
230 |
+
|
231 |
+
Raises:
|
232 |
+
ValueError: If the example format is unknown.
|
233 |
+
"""
|
234 |
+
example_format = _get_example_type(example)
|
235 |
+
if example_format == 'chat':
|
236 |
+
chat_example = cast(ChatFormattedDict, example)
|
237 |
+
return _tokenize_chat_formatted_example(chat_example, tokenizer)
|
238 |
+
elif example_format == 'prompt_response':
|
239 |
+
prompt_response_example: PromptResponseDict = cast(PromptResponseDict, example)
|
240 |
+
return _tokenize_prompt_response_formatted_example(prompt_response_example, tokenizer)
|
241 |
+
else:
|
242 |
+
raise UnknownExampleTypeError(example)
|
243 |
+
|
244 |
+
def is_valid_ift_example(max_seq_len: int, target_prompts: str, target_responses: str, decoder_only_format: bool, example: TokenizedExample) -> bool:
|
245 |
+
"""Check if the example is a valid ift example.
|
246 |
+
|
247 |
+
This function confirms that none of the ``input_ids`` and ``labels`` fields
|
248 |
+
are empty in any of the turns within the example.
|
249 |
+
|
250 |
+
This function also prepares the final input_ids and labels
|
251 |
+
of the (potentially multi-turn) example, using the target settings
|
252 |
+
and format, and checks whether they are suitable for training at max_seq_len.
|
253 |
+
The example is not valid if (1) after truncation (if necessary),
|
254 |
+
the training targets contain no loss-generating tokens, or (2) either the
|
255 |
+
input_ids and labels are empty.
|
256 |
+
|
257 |
+
The token sequences in ``example`` are assumed to not have had
|
258 |
+
any padding or truncation applied already.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
max_seq_len (int): Maximum sequence length.
|
262 |
+
target_prompts (str): The prompts that are used as targets.
|
263 |
+
target_responses (str): The responses that are used as targets.
|
264 |
+
decoder_only_format (bool): Whether the data will be formatted
|
265 |
+
for a decoder-only model.
|
266 |
+
example (Dict): The input example after tokenization, which has
|
267 |
+
a list of dicts, each with ``input_ids`` and ``labels`` fields.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
bool: Indicator of whether the input example is valid
|
271 |
+
"""
|
272 |
+
for turn in example['turns']:
|
273 |
+
if len(turn['input_ids']) == 0:
|
274 |
+
return False
|
275 |
+
if len(turn['labels']) == 0:
|
276 |
+
return False
|
277 |
+
if decoder_only_format:
|
278 |
+
input_ids, labels = stitch_turns_decoder_only(example_turns=example['turns'], target_prompts=target_prompts, target_responses=target_responses)
|
279 |
+
else:
|
280 |
+
input_ids, labels = stitch_turns_encoder_decoder(example_turns=example['turns'])
|
281 |
+
input_ids = input_ids[:max_seq_len]
|
282 |
+
labels = labels[:max_seq_len]
|
283 |
+
if len(input_ids) == 0:
|
284 |
+
return False
|
285 |
+
if len([label for label in labels if label != _HF_IGNORE_INDEX]) == 0:
|
286 |
+
return False
|
287 |
+
return True
|
288 |
+
|
289 |
+
def _stream_remote_local_validate(remote: Optional[str], local: Optional[str], split: Optional[str]):
|
290 |
+
if remote is None or local == remote:
|
291 |
+
if local is not None and os.path.isdir(local):
|
292 |
+
contents = set(os.listdir(local))
|
293 |
+
if split is not None and split not in contents:
|
294 |
+
raise ValueError(f'Local directory {local} does not contain split {split}')
|
295 |
+
|
296 |
+
class StreamingFinetuningDataset(StreamingDataset):
|
297 |
+
"""Finetuning dataset with flexible tokenization using StreamingDataset.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
|
301 |
+
tokenize samples.
|
302 |
+
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
|
303 |
+
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
|
304 |
+
``remote``/``local``. Defaults to ``None``.
|
305 |
+
local (str): Local dataset directory where shards are cached by split.
|
306 |
+
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
|
307 |
+
its data must exist locally. StreamingDataset uses either ``streams`` or
|
308 |
+
``remote``/``local``. Defaults to ``None``.
|
309 |
+
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
|
310 |
+
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
|
311 |
+
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
|
312 |
+
download_timeout (float): Number of seconds to wait for a shard to download before raising
|
313 |
+
an exception. Defaults to ``60``.
|
314 |
+
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
|
315 |
+
shards. Defaults to ``None``.
|
316 |
+
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
|
317 |
+
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
|
318 |
+
`False``.
|
319 |
+
epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
|
320 |
+
streams. If ``None``, takes its value from the total number of underlying samples.
|
321 |
+
Provide this field if you are weighting streams relatively to target a larger or
|
322 |
+
smaller epoch size. Defaults to ``None``.
|
323 |
+
predownload (int, optional): Target number of samples ahead to download the shards of while
|
324 |
+
iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
|
325 |
+
cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's
|
326 |
+
shard cache. Before downloading a shard, the least recently used resident shard(s) may
|
327 |
+
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
|
328 |
+
to disable shard eviction. Supports integer bytes as well as string human-readable
|
329 |
+
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
|
330 |
+
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
|
331 |
+
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
|
332 |
+
resumption. If ``None``, this is interpreted as 64 times the number of physical
|
333 |
+
nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
|
334 |
+
number of physical nodes of the initial run otherwise. Defaults to ``None``.
|
335 |
+
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
|
336 |
+
partitioned over the workers. Defaults to ``None``.
|
337 |
+
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
|
338 |
+
``False``.
|
339 |
+
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
|
340 |
+
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
|
341 |
+
shuffle_block_size (int): Unit of shuffle. If ``None``, its value is calculated as
|
342 |
+
``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to ``None``.
|
343 |
+
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
|
344 |
+
Defaults to ``balanced``.
|
345 |
+
sampling_granularity (int): When picking samples for a stream's final partial repeat,
|
346 |
+
how many samples to pick from the same shard at a time (``1`` for evenly balanced
|
347 |
+
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
|
348 |
+
Defaults to ``1``.
|
349 |
+
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
|
350 |
+
``per_stream``. Defaults to ``random``.
|
351 |
+
"""
|
352 |
+
|
353 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase, streams: Optional[Sequence[Stream]]=None, local: Optional[str]=None, remote: Optional[str]=None, split: Optional[str]=None, download_retry: int=2, download_timeout: float=60, validate_hash: Optional[str]=None, keep_zip: bool=False, epoch_size: Optional[Union[int, str]]=None, predownload: Optional[int]=None, cache_limit: Optional[Union[int, str]]=None, partition_algo: str='relaxed', num_canonical_nodes: Optional[int]=None, batch_size: Optional[int]=None, shuffle: bool=False, shuffle_algo: str='py1e', shuffle_seed: int=9176, shuffle_block_size: Optional[int]=None, sampling_method: str='balanced', sampling_granularity: int=1, batching_method: str='random', max_seq_len: int=2048, **kwargs: Any):
|
354 |
+
if len(kwargs) > 0:
|
355 |
+
raise ValueError(f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}')
|
356 |
+
if streams is None:
|
357 |
+
_stream_remote_local_validate(remote, local, split)
|
358 |
+
else:
|
359 |
+
for stream in streams:
|
360 |
+
_stream_remote_local_validate(stream.remote, stream.local, split)
|
361 |
+
super().__init__(streams=streams, local=local, remote=remote, split=split, download_retry=download_retry, download_timeout=download_timeout, validate_hash=validate_hash, keep_zip=keep_zip, epoch_size=epoch_size, predownload=predownload, cache_limit=cache_limit, partition_algo=partition_algo, num_canonical_nodes=num_canonical_nodes, batch_size=batch_size, shuffle=shuffle, shuffle_algo=shuffle_algo, shuffle_seed=shuffle_seed, shuffle_block_size=shuffle_block_size, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method)
|
362 |
+
self.tokenizer = tokenizer
|
363 |
+
self.max_seq_len = max_seq_len
|
364 |
+
|
365 |
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
366 |
+
sample = super().__getitem__(idx)
|
367 |
+
if 'turns' in sample:
|
368 |
+
return sample
|
369 |
+
if 'input_ids' in sample:
|
370 |
+
if isinstance(sample['input_ids'], bytes):
|
371 |
+
sample['input_ids'] = np.frombuffer(sample['input_ids'], dtype=np.int64)[:self.max_seq_len].tolist().copy()
|
372 |
+
sample['labels'] = np.frombuffer(sample['labels'], dtype=np.int64)[:self.max_seq_len].tolist().copy()
|
373 |
+
elif isinstance(sample['input_ids'], np.ndarray):
|
374 |
+
sample['input_ids'] = sample['input_ids'][:self.max_seq_len].tolist().copy()
|
375 |
+
sample['labels'] = sample['labels'][:self.max_seq_len].tolist().copy()
|
376 |
+
else:
|
377 |
+
raise ValueError(f"Expect input_ids to be bytes or numpy.ndarray type, but got {type(sample['input_ids'])}")
|
378 |
+
return {'turns': [sample]}
|
379 |
+
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)
|
380 |
+
|
381 |
+
class DatasetConstructor:
|
382 |
+
|
383 |
+
def __init__(self):
|
384 |
+
self._task_preprocessing_registry: Dict[str, Callable] = {}
|
385 |
+
|
386 |
+
def register(self, *names: str) -> Callable[[Callable], Callable]:
|
387 |
+
"""Decorator for registering preprocessing functions."""
|
388 |
+
|
389 |
+
def _register_func(name: str, func: Callable) -> None:
|
390 |
+
if name in self._task_preprocessing_registry:
|
391 |
+
raise ValueError(f'A tokenization function has already been registered with name={name!r}.')
|
392 |
+
self._task_preprocessing_registry[name] = func
|
393 |
+
return
|
394 |
+
|
395 |
+
def wrapper(func: Callable) -> Callable:
|
396 |
+
for name in names:
|
397 |
+
_register_func(name, func)
|
398 |
+
return func
|
399 |
+
return wrapper
|
400 |
+
|
401 |
+
def print_registered_tasks(self) -> None:
|
402 |
+
tasks = sorted(self._task_preprocessing_registry.keys())
|
403 |
+
log.info('\n'.join(tasks))
|
404 |
+
|
405 |
+
def get_preprocessing_fn_from_dict(self, mapping: Dict[str, str]) -> Callable[[Dict[str, Any]], Dict[str, str]]:
|
406 |
+
"""Get a preprocessing function from a dictionary.
|
407 |
+
|
408 |
+
The dictionary maps column names in the dataset to "prompt" and "response".
|
409 |
+
For example,
|
410 |
+
```yaml
|
411 |
+
preprocessing_fn:
|
412 |
+
prompt: text
|
413 |
+
response: summary
|
414 |
+
```
|
415 |
+
would map the `text` column as to prompt and the `summary` column as the response.
|
416 |
+
|
417 |
+
Args:
|
418 |
+
mapping (dict): A dictionary mapping column names to "prompt" and "response".
|
419 |
+
|
420 |
+
Returns:
|
421 |
+
Callable: The preprocessing function.
|
422 |
+
|
423 |
+
Raises:
|
424 |
+
ValueError: If the mapping does not have keys "prompt" and "response".
|
425 |
+
"""
|
426 |
+
|
427 |
+
def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
|
428 |
+
if list(mapping.keys()) != ['prompt', 'response']:
|
429 |
+
raise InvalidPromptResponseKeysError(mapping, example)
|
430 |
+
return {'prompt': example[mapping['prompt']], 'response': example[mapping['response']]}
|
431 |
+
return _preprocessor
|
432 |
+
|
433 |
+
def get_preprocessing_fn_from_str(self, preprocessor: Optional[str], dataset_name: Optional[str]=None) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
|
434 |
+
"""Get a preprocessing function from a string.
|
435 |
+
|
436 |
+
String can be either a registered function or an import path.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
preprocessor (Optional[str]): The name of the preprocessing function, or an import path.
|
440 |
+
dataset_name (Optional[str]): The dataset name to look up in the registry.
|
441 |
+
|
442 |
+
Returns:
|
443 |
+
Callable: The preprocessing function or None if not found.
|
444 |
+
|
445 |
+
Raises:
|
446 |
+
ValueError: If the preprocessing function import from the provided string fails.
|
447 |
+
"""
|
448 |
+
if preprocessor is None:
|
449 |
+
if dataset_name is None:
|
450 |
+
return None
|
451 |
+
if dataset_name in self._task_preprocessing_registry:
|
452 |
+
log.info(f'Re-formatting dataset with "{dataset_name}" preprocessing function.')
|
453 |
+
return self._task_preprocessing_registry[dataset_name]
|
454 |
+
else:
|
455 |
+
log.info('No preprocessor was supplied and no preprocessing function ' + f'is registered for dataset name "{dataset_name}". No additional ' + 'preprocessing will be applied. If the dataset is already formatted ' + 'correctly, you can ignore this message.')
|
456 |
+
return None
|
457 |
+
if preprocessor in self._task_preprocessing_registry:
|
458 |
+
log.info(f'Re-formatting dataset with "{preprocessor}" preprocessing function.')
|
459 |
+
return self._task_preprocessing_registry[preprocessor]
|
460 |
+
try:
|
461 |
+
import_path, function_name = preprocessor.split(':', maxsplit=1)
|
462 |
+
module = importlib.import_module(import_path)
|
463 |
+
preprocessing_fn = getattr(module, function_name)
|
464 |
+
except Exception as e:
|
465 |
+
raise ValueError(f'Failed to import preprocessing function from string = {preprocessor}.') from e
|
466 |
+
return preprocessing_fn
|
467 |
+
|
468 |
+
def build_from_hf(self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int, preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]], tokenizer: PreTrainedTokenizerBase, target_prompts: str, target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str, Any]) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset, hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
|
469 |
+
"""Load a HuggingFace Datasets, preprocess, and tokenize.
|
470 |
+
|
471 |
+
Note: This function will drop examples where the prompt is longer than the max_seq_len
|
472 |
+
|
473 |
+
Args:
|
474 |
+
cfg (DictConfig): The dataset configuration.
|
475 |
+
max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped.
|
476 |
+
tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset.
|
477 |
+
|
478 |
+
Returns:
|
479 |
+
Dataset: The tokenized dataset.
|
480 |
+
"""
|
481 |
+
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
|
482 |
+
if dist.get_local_rank() != 0:
|
483 |
+
log.debug('Waiting for local_rank 0 to finish data prep')
|
484 |
+
with dist.local_rank_zero_download_and_wait(signal_file_path):
|
485 |
+
pass
|
486 |
+
hf_tokenization_logger = logging.getLogger('transformers.tokenization_utils_base')
|
487 |
+
sequence_length_warning_filter = SpecificWarningFilter('Token indices sequence length is longer than the specified maximum sequence length')
|
488 |
+
hf_tokenization_logger.addFilter(sequence_length_warning_filter)
|
489 |
+
error: Optional[Exception] = None
|
490 |
+
filtered_dataset = None
|
491 |
+
try:
|
492 |
+
if safe_load:
|
493 |
+
if not os.path.isdir(dataset_name):
|
494 |
+
local_dataset_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, dataset_name)
|
495 |
+
if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
|
496 |
+
hf_hub.snapshot_download(dataset_name, repo_type='dataset', allow_patterns=['*' + ext for ext in SUPPORTED_EXTENSIONS], token=hf_kwargs.get('token', None), revision=hf_kwargs.get('revision', None), local_dir_use_symlinks=False, local_dir=local_dataset_dir)
|
497 |
+
if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
|
498 |
+
raise InvalidFileExtensionError(dataset_name, SUPPORTED_EXTENSIONS)
|
499 |
+
dataset_name = local_dataset_dir
|
500 |
+
dataset_name = os.path.abspath(dataset_name)
|
501 |
+
dataset_files = [f for _, _, files in os.walk(dataset_name) for f in files]
|
502 |
+
if not all((Path(f).suffix in SUPPORTED_EXTENSIONS for f in dataset_files)):
|
503 |
+
raise InvalidFileExtensionError(dataset_name, SUPPORTED_EXTENSIONS)
|
504 |
+
dataset = hf_datasets.load_dataset(dataset_name, split=split, **hf_kwargs)
|
505 |
+
|
506 |
+
def dataset_mapper(example: Dict):
|
507 |
+
if preprocessing_fn is not None:
|
508 |
+
example = preprocessing_fn(example)
|
509 |
+
return tokenize_formatted_example(example, tokenizer)
|
510 |
+
detected_cpu_count = os.cpu_count() or 1
|
511 |
+
detected_cpus_with_margin = detected_cpu_count - 8
|
512 |
+
num_cpus_to_use = max(1, detected_cpus_with_margin)
|
513 |
+
columns_to_remove = list(dataset[0].keys())
|
514 |
+
tokenized_dataset = dataset.map(dataset_mapper, batched=False, remove_columns=columns_to_remove, num_proc=num_cpus_to_use, desc='Tokenizing dataset')
|
515 |
+
filtered_dataset = tokenized_dataset.filter(partial(is_valid_ift_example, max_seq_len, target_prompts, target_responses, decoder_only_format), num_proc=num_cpus_to_use, desc='Filtering out long prompts')
|
516 |
+
examples_removed = len(tokenized_dataset) - len(filtered_dataset)
|
517 |
+
if examples_removed > 0:
|
518 |
+
warnings.warn(f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, ' + 'the prompt or response was empty, or the response was all padding tokens.')
|
519 |
+
except Exception as e:
|
520 |
+
error = e
|
521 |
+
if dist.get_local_rank() == 0:
|
522 |
+
log.debug('Local rank 0 finished data prep')
|
523 |
+
with open(signal_file_path, 'wb') as f:
|
524 |
+
f.write(b'local_rank0_completed_data_prep')
|
525 |
+
dist.barrier()
|
526 |
+
if dist.get_local_rank() == 0:
|
527 |
+
os.remove(signal_file_path)
|
528 |
+
if error is not None:
|
529 |
+
log.error('Error during data prep')
|
530 |
+
raise error
|
531 |
+
log.debug('All ranks finished data prep')
|
532 |
+
hf_tokenization_logger.removeFilter(sequence_length_warning_filter)
|
533 |
+
assert filtered_dataset is not None
|
534 |
+
return filtered_dataset
|
535 |
+
|
536 |
+
def build_from_streaming(self, *args: Any, **kwargs: Any) -> StreamingFinetuningDataset:
|
537 |
+
return StreamingFinetuningDataset(*args, **kwargs)
|
538 |
+
dataset_constructor = DatasetConstructor()
|
539 |
+
|
540 |
+
@dataset_constructor.register('tatsu-lab/alpaca')
|
541 |
+
def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
|
542 |
+
"""Split out prompt/response from text."""
|
543 |
+
try:
|
544 |
+
prompt, response = inp['text'].split('### Response:')
|
545 |
+
prompt += '### Response:'
|
546 |
+
except Exception as e:
|
547 |
+
raise UnableToProcessPromptResponseError(inp) from e
|
548 |
+
return {'prompt': prompt, 'response': response}
|
549 |
+
|
550 |
+
@dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
|
551 |
+
def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
|
552 |
+
"""Format the text string."""
|
553 |
+
PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
|
554 |
+
try:
|
555 |
+
if inp['input'] != '':
|
556 |
+
instruction = inp['instruction'] + '\n' + inp['input']
|
557 |
+
else:
|
558 |
+
instruction = inp['instruction']
|
559 |
+
prompt = PROMPT_FORMAT.format(instruction=instruction)
|
560 |
+
response = inp['output']
|
561 |
+
except Exception as e:
|
562 |
+
raise UnableToProcessPromptResponseError(inp) from e
|
563 |
+
return {'prompt': prompt, 'response': response}
|
564 |
+
|
565 |
+
@dataset_constructor.register('bigscience/P3')
|
566 |
+
def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
|
567 |
+
"""Format the already-split example."""
|
568 |
+
return {'prompt': inp['inputs'] + ':', 'response': inp['targets']}
|
569 |
+
|
570 |
+
@dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
|
571 |
+
def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
|
572 |
+
"""Format the already-split example."""
|
573 |
+
try:
|
574 |
+
prompt: str = inp['inputs']
|
575 |
+
response: str = inp['targets']
|
576 |
+
transitions = (' ', '\n', '\t')
|
577 |
+
if not (prompt.endswith(transitions) or response.startswith(transitions)):
|
578 |
+
response = ' ' + response
|
579 |
+
except Exception as e:
|
580 |
+
raise UnableToProcessPromptResponseError(inp) from e
|
581 |
+
return {'prompt': prompt, 'response': response}
|
text_data.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Build a StreamingTextDataset dataset and dataloader for training."""
|
2 |
+
import os
|
3 |
+
from itertools import islice
|
4 |
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import transformers
|
8 |
+
from streaming import Stream, StreamingDataset
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from transformers import PreTrainedTokenizerBase
|
11 |
+
|
12 |
+
class StreamingTextDataset(StreamingDataset):
|
13 |
+
"""Generic text dataset using MosaicML's StreamingDataset.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
tokenizer (Tokenizer): HuggingFace tokenizer to
|
17 |
+
tokenize samples.
|
18 |
+
max_seq_len (int): The max sequence length of each sample.
|
19 |
+
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
|
20 |
+
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
|
21 |
+
``remote``/``local``. Defaults to ``None``.
|
22 |
+
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
|
23 |
+
its data must exist locally. StreamingDataset uses either ``streams`` or
|
24 |
+
``remote``/``local``. Defaults to ``None``.
|
25 |
+
local (str, optional): Local working directory to download shards to. This is where shards
|
26 |
+
are cached while they are being used. Uses a temp directory if not set.
|
27 |
+
StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
|
28 |
+
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
|
29 |
+
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
|
30 |
+
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
|
31 |
+
download_timeout (float): Number of seconds to wait for a shard to download before raising
|
32 |
+
an exception. Defaults to ``60``.
|
33 |
+
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
|
34 |
+
shards. Defaults to ``None``.
|
35 |
+
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
|
36 |
+
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
|
37 |
+
`False``.
|
38 |
+
epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all
|
39 |
+
streams. If ``None``, takes its value from the total number of underlying samples.
|
40 |
+
Provide this field if you are weighting streams relatively to target a larger or
|
41 |
+
smaller epoch size. Defaults to ``None``.
|
42 |
+
predownload (int, optional): Target number of samples ahead to download the shards of while
|
43 |
+
iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
|
44 |
+
cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's
|
45 |
+
shard cache. Before downloading a shard, the least recently used resident shard(s) may
|
46 |
+
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
|
47 |
+
to disable shard eviction. Supports integer bytes as well as string human-readable
|
48 |
+
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
|
49 |
+
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
|
50 |
+
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
|
51 |
+
resumption. If ``None``, this is interpreted as 64 times the number of physical
|
52 |
+
nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
|
53 |
+
number of physical nodes of the initial run otherwise. Defaults to ``None``.
|
54 |
+
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
|
55 |
+
partitioned over the workers. Defaults to ``None``.
|
56 |
+
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
|
57 |
+
``False``.
|
58 |
+
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
|
59 |
+
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
|
60 |
+
shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
|
61 |
+
into blocks of this size, and samples within each block are shuffled. If ``None``, its
|
62 |
+
value is calculated as ``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to
|
63 |
+
``None``.
|
64 |
+
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
|
65 |
+
Defaults to ``balanced``.
|
66 |
+
sampling_granularity (int): When picking samples for a stream's final partial repeat,
|
67 |
+
how many samples to pick from the same shard at a time (``1`` for evenly balanced
|
68 |
+
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
|
69 |
+
Defaults to ``1``.
|
70 |
+
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
|
71 |
+
``per_stream``. Defaults to ``random``.
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase, max_seq_len: int, streams: Optional[Sequence[Stream]]=None, remote: Optional[str]=None, local: Optional[str]=None, split: Optional[str]=None, download_retry: int=2, download_timeout: float=60, validate_hash: Optional[str]=None, keep_zip: bool=False, epoch_size: Optional[Union[int, str]]=None, predownload: Optional[int]=None, cache_limit: Optional[Union[int, str]]=None, partition_algo: str='relaxed', num_canonical_nodes: Optional[int]=None, batch_size: Optional[int]=None, shuffle: bool=False, shuffle_algo: str='py1e', shuffle_seed: int=9176, shuffle_block_size: Optional[int]=None, sampling_method: str='balanced', sampling_granularity: int=1, batching_method: str='random', **kwargs: Any):
|
75 |
+
if len(kwargs) > 0:
|
76 |
+
raise ValueError(f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}')
|
77 |
+
if local is not None and (remote is None or local == remote):
|
78 |
+
if os.path.isdir(local):
|
79 |
+
contents = set(os.listdir(local))
|
80 |
+
if split not in contents:
|
81 |
+
raise ValueError(f'local directory {local} does not contain split {split}')
|
82 |
+
if isinstance(shuffle_block_size, float):
|
83 |
+
shuffle_block_size = int(shuffle_block_size)
|
84 |
+
super().__init__(streams=streams, remote=remote, local=local, split=split, download_retry=download_retry, download_timeout=download_timeout, validate_hash=validate_hash, keep_zip=keep_zip, epoch_size=epoch_size, predownload=predownload, cache_limit=cache_limit, partition_algo=partition_algo, num_canonical_nodes=num_canonical_nodes, batch_size=batch_size, shuffle=shuffle, shuffle_algo=shuffle_algo, shuffle_seed=shuffle_seed, shuffle_block_size=shuffle_block_size, sampling_method=sampling_method, sampling_granularity=sampling_granularity, batching_method=batching_method)
|
85 |
+
self.tokenizer = tokenizer
|
86 |
+
self.max_seq_len = max_seq_len
|
87 |
+
|
88 |
+
def _tokenize(self, text_sample: Mapping) -> Dict[str, List[int]]:
|
89 |
+
if self.tokenizer._pad_token is None:
|
90 |
+
raise RuntimeError('If tokenizing on-the-fly, tokenizer must have a pad_token_id')
|
91 |
+
return self.tokenizer(text_sample['text'], truncation=True, padding='max_length', max_length=self.max_seq_len)
|
92 |
+
|
93 |
+
def _read_binary_tokenized_sample(self, sample: Dict[str, Any]) -> torch.Tensor:
|
94 |
+
return torch.from_numpy(np.frombuffer(sample['tokens'], dtype=np.int64)[:self.max_seq_len].copy())
|
95 |
+
|
96 |
+
def __getitem__(self, idx: int) -> Union[Dict[str, List[int]], torch.Tensor]:
|
97 |
+
sample = super().__getitem__(idx)
|
98 |
+
if 'text' in sample:
|
99 |
+
token_sample = self._tokenize(sample)
|
100 |
+
elif 'tokens' in sample:
|
101 |
+
token_sample = self._read_binary_tokenized_sample(sample)
|
102 |
+
else:
|
103 |
+
raise RuntimeError('StreamingTextDataset needs samples to have a `text` or `tokens` column')
|
104 |
+
return token_sample
|
105 |
+
|
106 |
+
class ConcatenatedSequenceCollatorWrapper:
|
107 |
+
"""Collator wrapper to add sequence_id to batch."""
|
108 |
+
|
109 |
+
def __init__(self, base_collator: Callable, eos_token_id: Optional[int]=None, bos_token_id: Optional[int]=None):
|
110 |
+
self.base_collator = base_collator
|
111 |
+
if eos_token_id is None and bos_token_id is None:
|
112 |
+
raise ValueError('Must supply a value for either eos_token_id or bos_token_id, but got None for both.')
|
113 |
+
if eos_token_id is not None and bos_token_id is not None:
|
114 |
+
raise ValueError('Cannot use *both* EOS and BOS tokens for detecting sequence boundaries. ' + 'Please supply `eos_token_id` if sequences end with an EOS token, or use ' + '`bos_token_id` if sequences start with a BOS token.')
|
115 |
+
if eos_token_id is None:
|
116 |
+
self.split_token_id = cast(int, bos_token_id)
|
117 |
+
self.bos_mode = True
|
118 |
+
else:
|
119 |
+
self.split_token_id = eos_token_id
|
120 |
+
self.bos_mode = False
|
121 |
+
|
122 |
+
def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
|
123 |
+
batch = self.base_collator(examples)
|
124 |
+
batch['sequence_id'] = self.get_sequence_id_from_batch(batch)
|
125 |
+
return batch
|
126 |
+
|
127 |
+
def get_sequence_id_from_batch(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
128 |
+
is_separator = torch.eq(batch['input_ids'], self.split_token_id)
|
129 |
+
cumulative_sep = torch.cumsum(is_separator, dim=1).to(batch['input_ids'].dtype)
|
130 |
+
if self.bos_mode:
|
131 |
+
return cumulative_sep
|
132 |
+
left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1))
|
133 |
+
return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
|
134 |
+
|
135 |
+
def build_streams(dataset_cfg: DictConfig):
|
136 |
+
streams_dict = dataset_cfg.pop('streams', None)
|
137 |
+
streams = None
|
138 |
+
if streams_dict is not None:
|
139 |
+
streams = []
|
140 |
+
for _, stream in streams_dict.items():
|
141 |
+
streams.append(Stream(**stream))
|
142 |
+
return streams
|
143 |
+
|
144 |
+
def build_text_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataSpec:
|
145 |
+
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
|
146 |
+
mlm_probability = cfg.dataset.pop('mlm_probability', None)
|
147 |
+
eos_token_id = cfg.dataset.pop('eos_token_id', None)
|
148 |
+
bos_token_id = cfg.dataset.pop('bos_token_id', None)
|
149 |
+
streams = build_streams(cfg.dataset)
|
150 |
+
dataset = StreamingTextDataset(tokenizer=tokenizer, streams=streams, batch_size=device_batch_size, **cfg.dataset)
|
151 |
+
collate_fn = transformers.DataCollatorForLanguageModeling(tokenizer=dataset.tokenizer, mlm=mlm_probability is not None, mlm_probability=mlm_probability)
|
152 |
+
if eos_token_id is not None or bos_token_id is not None:
|
153 |
+
collate_fn = ConcatenatedSequenceCollatorWrapper(base_collator=collate_fn, eos_token_id=eos_token_id, bos_token_id=bos_token_id)
|
154 |
+
dl = DataLoader(dataset, collate_fn=collate_fn, batch_size=device_batch_size, drop_last=cfg.drop_last, num_workers=cfg.num_workers, pin_memory=cfg.get('pin_memory', True), prefetch_factor=cfg.get('prefetch_factor', 2), persistent_workers=cfg.get('persistent_workers', True), timeout=cfg.get('timeout', 0))
|
155 |
+
token_counting_func = None
|
156 |
+
if tokenizer.pad_token_id is not None:
|
157 |
+
token_counting_func = get_tokens_per_batch_func()
|
158 |
+
return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)
|
159 |
+
|
160 |
+
def get_tokens_per_batch_func(decoder_only: bool=True) -> Callable[[Batch], int]:
|
161 |
+
"""Returns a callable that counts the number of tokens in a batch.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
pad_token_id (int): The id of the padding token.
|
165 |
+
decoder_only (bool, optional): Whether to expect the batch to just contain ``input_ids`` (decoder only)
|
166 |
+
or to also contain ``decoder_input_ids`` (encoder decoder). Defaults to ``True``.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
|
170 |
+
"""
|
171 |
+
|
172 |
+
def get_num_samples_in_batch(batch: Batch) -> int:
|
173 |
+
if not isinstance(batch, Mapping) or ('attention_mask' not in batch and 'input_ids' not in batch):
|
174 |
+
raise ValueError('get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key')
|
175 |
+
if not decoder_only and 'decoder_attention_mask' not in batch:
|
176 |
+
raise ValueError('get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key')
|
177 |
+
if 'attention_mask' in batch:
|
178 |
+
input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
|
179 |
+
else:
|
180 |
+
input_ids_tokens = batch['input_ids'].numel()
|
181 |
+
decoder_input_ids_tokens = 0
|
182 |
+
if not decoder_only:
|
183 |
+
decoder_input_ids_tokens = int(torch.sum(batch['decoder_attention_mask']).item())
|
184 |
+
return input_ids_tokens + decoder_input_ids_tokens
|
185 |
+
return get_num_samples_in_batch
|
186 |
+
if __name__ == '__main__':
|
187 |
+
import argparse
|
188 |
+
from .builders import build_tokenizer
|
189 |
+
parser = argparse.ArgumentParser()
|
190 |
+
parser.add_argument('--tokenizer', type=str, default='EleutherAI/gpt-neox-20b', help='the name of the tokenizer to use')
|
191 |
+
parser.add_argument('--local_path', type=str, required=True, help='the path to the local copy of the dataset')
|
192 |
+
parser.add_argument('--remote_path', type=str, default=None, help='the path to the remote copy to stream from (optional)')
|
193 |
+
parser.add_argument('--split', type=str, default='val', help='which split of the dataset to use')
|
194 |
+
parser.add_argument('--max_seq_len', type=int, default=32, help='max sequence length to test')
|
195 |
+
args = parser.parse_args()
|
196 |
+
if args.remote_path is not None:
|
197 |
+
print(f'Reading {args.split} split from {args.local_path} <- streamed from <- {args.remote_path}')
|
198 |
+
else:
|
199 |
+
print(f'Reading {args.split} split from {args.local_path}')
|
200 |
+
cfg = {'name': 'text', 'dataset': {'local': args.local_path, 'remote': args.remote_path, 'split': args.split, 'shuffle': False, 'max_seq_len': args.max_seq_len, 'keep_zip': True}, 'drop_last': False, 'num_workers': 4}
|
201 |
+
cfg = om.create(cfg)
|
202 |
+
device_batch_size = 2
|
203 |
+
tokenizer_name = args.tokenizer
|
204 |
+
tokenizer_kwargs = {'model_max_length': args.max_seq_len}
|
205 |
+
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
|
206 |
+
loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader
|
207 |
+
assert isinstance(loader, DataLoader)
|
208 |
+
assert isinstance(loader.dataset, StreamingTextDataset)
|
209 |
+
tokenizer = loader.dataset.tokenizer
|
210 |
+
for batch_ix, batch in enumerate(islice(loader, 5)):
|
211 |
+
print('\n')
|
212 |
+
print('#' * 20, f'Batch {batch_ix}', '#' * 20)
|
213 |
+
for k, v in batch.items():
|
214 |
+
print(k, v.shape, v.dtype)
|
215 |
+
for sample_ix, token_sample in enumerate(batch['input_ids']):
|
216 |
+
print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
|
217 |
+
print(tokenizer.decode(token_sample))
|
tiktoken.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
3 |
+
from transformers import PreTrainedTokenizer
|
4 |
+
DEFAULT_SYSTEM_PROMPT = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.'
|
5 |
+
|
6 |
+
@lru_cache()
|
7 |
+
def bytes_to_unicode():
|
8 |
+
"""Returns list of utf-8 byte and a mapping to unicode strings.
|
9 |
+
|
10 |
+
We specifically avoids mapping to whitespace/control characters the bpe code
|
11 |
+
barfs on.
|
12 |
+
|
13 |
+
The reversible bpe codes work on unicode strings. This means you need a
|
14 |
+
large # of unicode characters in your vocab if you want to avoid UNKs. When
|
15 |
+
you're at something like a 10B token dataset you end up needing around 5K
|
16 |
+
for decent coverage. This is a significant percentage of your normal, say,
|
17 |
+
32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
|
18 |
+
unicode strings.
|
19 |
+
"""
|
20 |
+
bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
|
21 |
+
cs = bs[:]
|
22 |
+
n = 0
|
23 |
+
for b in range(2 ** 8):
|
24 |
+
if b not in bs:
|
25 |
+
bs.append(b)
|
26 |
+
cs.append(2 ** 8 + n)
|
27 |
+
n += 1
|
28 |
+
cs = [chr(n) for n in cs]
|
29 |
+
return dict(zip(bs, cs))
|
30 |
+
|
31 |
+
class TiktokenTokenizerWrapper(PreTrainedTokenizer):
|
32 |
+
"""A thin wrapper around tiktoken to make it compatible with Hugging Face.
|
33 |
+
|
34 |
+
tokenizers.
|
35 |
+
|
36 |
+
See HuggingFace for further documentation on general tokenizer methods.
|
37 |
+
"""
|
38 |
+
model_input_names = ['input_ids', 'attention_mask']
|
39 |
+
|
40 |
+
def __init__(self, model_name: Optional[str]=None, encoding_name: Optional[str]=None, add_bos_token: bool=False, add_eos_token: bool=False, use_default_system_prompt: bool=False, unk_token: Optional[str]='<|endoftext|>', eos_token: Optional[str]='<|endoftext|>', bos_token: Optional[str]='<|endoftext|>', pad_token: Optional[str]=None, errors: str='replace', **kwargs: Any):
|
41 |
+
"""Constructor creates a tiktoken tokenizer to use as the underlying.
|
42 |
+
|
43 |
+
tokenizer.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
model_name (Optional[str], optional): The name of the model to load from tiktoken. Defaults to None.
|
47 |
+
Either model_name or encoding_name must be set, but not both.
|
48 |
+
encoding_name (Optional[str], optional): The name of the encoding to load from tiktoken. Defaults to None.
|
49 |
+
Either model_name or encoding_name must be set, but not both.
|
50 |
+
add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
|
51 |
+
add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
|
52 |
+
use_default_system_prompt (bool, optional): Use the default system prompt or not. Defaults to False.
|
53 |
+
unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
|
54 |
+
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
|
55 |
+
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
|
56 |
+
pad_token (Optional[str], optional): The pad token. Defaults to None.
|
57 |
+
errors (str, optional): Paradigm to follow when decoding bytes to UTF-8. See
|
58 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
59 |
+
Defaults to `"replace"`.
|
60 |
+
"""
|
61 |
+
try:
|
62 |
+
import tiktoken
|
63 |
+
except:
|
64 |
+
raise ImportError('You need to install tiktoken to use TiktokenTokenizerWrapper.')
|
65 |
+
import copyreg
|
66 |
+
import functools
|
67 |
+
from tiktoken import Encoding
|
68 |
+
|
69 |
+
def pickle_Encoding(enc: Encoding):
|
70 |
+
return (functools.partial(Encoding, enc.name, pat_str=enc._pat_str, mergeable_ranks=enc._mergeable_ranks, special_tokens=enc._special_tokens), ())
|
71 |
+
copyreg.pickle(Encoding, pickle_Encoding)
|
72 |
+
if model_name is not None and encoding_name is not None:
|
73 |
+
raise ValueError('You need to specify either model_name or encoding_name, not both.')
|
74 |
+
self.model_name = model_name
|
75 |
+
self.encoding_name = encoding_name
|
76 |
+
if self.model_name is not None:
|
77 |
+
self.encoding = tiktoken.encoding_for_model(self.model_name)
|
78 |
+
elif self.encoding_name is not None:
|
79 |
+
self.encoding = tiktoken.get_encoding(self.encoding_name)
|
80 |
+
else:
|
81 |
+
raise ValueError('You need to specify either model_name or encoding_name.')
|
82 |
+
self.add_bos_token = add_bos_token
|
83 |
+
self.add_eos_token = add_eos_token
|
84 |
+
self.use_default_system_prompt = use_default_system_prompt
|
85 |
+
self.byte_encoder = bytes_to_unicode()
|
86 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
87 |
+
self.errors = errors
|
88 |
+
self.decoder: Dict[int, str] = {}
|
89 |
+
for i in range(self.encoding.n_vocab):
|
90 |
+
try:
|
91 |
+
self.encoding.decode_single_token_bytes(i)
|
92 |
+
except KeyError:
|
93 |
+
continue
|
94 |
+
decoding = ''.join([bytes_to_unicode()[ord(char)] for char in self.encoding.decode_single_token_bytes(i).decode('latin-1')])
|
95 |
+
self.decoder[i] = decoding
|
96 |
+
self.encoder: Dict[str, int] = {}
|
97 |
+
for i in range(self.encoding.n_vocab):
|
98 |
+
if i in self.decoder:
|
99 |
+
self.encoder[self.decoder[i]] = i
|
100 |
+
super().__init__(model_name=model_name, encoding_name=encoding_name, add_bos_token=add_bos_token, add_eos_token=add_eos_token, use_default_system_prompt=use_default_system_prompt, unk_token=unk_token, eos_token=eos_token, bos_token=bos_token, pad_token=pad_token, errors=errors, **kwargs)
|
101 |
+
|
102 |
+
@property
|
103 |
+
def vocab_size(self) -> int:
|
104 |
+
"""Returns vocab size."""
|
105 |
+
return self.encoding.n_vocab
|
106 |
+
|
107 |
+
@property
|
108 |
+
def is_fast(self) -> bool:
|
109 |
+
return False
|
110 |
+
|
111 |
+
@property
|
112 |
+
def default_chat_template(self):
|
113 |
+
"""Chat ML Template for User/Assistant.
|
114 |
+
|
115 |
+
Pinning default Chat ML template in case defaults change.
|
116 |
+
"""
|
117 |
+
template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not 'system' in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_PROMPT' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message.strip() + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}"
|
118 |
+
template = template.replace('USE_DEFAULT_PROMPT', 'true' if self.use_default_system_prompt else 'false')
|
119 |
+
template = template.replace('DEFAULT_SYSTEM_PROMPT', DEFAULT_SYSTEM_PROMPT)
|
120 |
+
return template
|
121 |
+
|
122 |
+
def get_vocab(self) -> Dict[str, int]:
|
123 |
+
"""Returns vocab as a dict."""
|
124 |
+
vocab_clone = self.encoder.copy()
|
125 |
+
extra_id_index = 0
|
126 |
+
candidate_extra_id = f'<extra_id_{extra_id_index}>'
|
127 |
+
indices_to_fill_in = {i for i in range(self.vocab_size)} - set(vocab_clone.values())
|
128 |
+
for index_to_add in indices_to_fill_in:
|
129 |
+
while candidate_extra_id in vocab_clone:
|
130 |
+
extra_id_index += 1
|
131 |
+
candidate_extra_id = f'<extra_id_{extra_id_index}>'
|
132 |
+
vocab_clone[candidate_extra_id] = index_to_add
|
133 |
+
return vocab_clone
|
134 |
+
|
135 |
+
def _tokenize(self, text: str) -> List[str]:
|
136 |
+
"""Returns a tokenized string."""
|
137 |
+
if not isinstance(text, str):
|
138 |
+
raise ValueError(f'Expected a string input to _tokenize but got {type(text)}.')
|
139 |
+
tokens = [self.decoder[t] for t in self.encoding.encode(text, allowed_special='all')]
|
140 |
+
return tokens
|
141 |
+
|
142 |
+
def _convert_token_to_id(self, token: str) -> Optional[int]:
|
143 |
+
"""Converts a token (str) in an id using the vocab."""
|
144 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
145 |
+
|
146 |
+
def _convert_id_to_token(self, index: int) -> Optional[str]:
|
147 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
148 |
+
return self.decoder.get(index, '')
|
149 |
+
|
150 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
151 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
152 |
+
text = ''.join(tokens)
|
153 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
154 |
+
return text
|
155 |
+
|
156 |
+
def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]]=None) -> List[int]:
|
157 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
158 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
159 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
160 |
+
if token_ids_1 is not None:
|
161 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
162 |
+
return output
|
163 |
+
|
164 |
+
def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]]=None, already_has_special_tokens: bool=False) -> List[int]:
|
165 |
+
"""Retrieves sequence ids from a token list that has no special tokens.
|
166 |
+
|
167 |
+
Function copied from
|
168 |
+
https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/models/gpt2/tokenization_gpt2.py#L265-L295
|
169 |
+
|
170 |
+
added. This method is called when adding special tokens using the
|
171 |
+
tokenizer `prepare_for_model` or `encode_plus` methods.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
token_ids_0 (`List[int]`):
|
175 |
+
List of IDs.
|
176 |
+
token_ids_1 (`List[int]`, *optional*):
|
177 |
+
Optional second list of IDs for sequence pairs.
|
178 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
179 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
183 |
+
"""
|
184 |
+
if already_has_special_tokens:
|
185 |
+
return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
|
186 |
+
bos_token_id = [1] if self.add_bos_token else []
|
187 |
+
eos_token_id = [1] if self.add_eos_token else []
|
188 |
+
if token_ids_1 is None:
|
189 |
+
return bos_token_id + [0] * len(token_ids_0) + eos_token_id
|
190 |
+
return bos_token_id + [0] * len(token_ids_0) + eos_token_id + bos_token_id + [0] * len(token_ids_1) + eos_token_id
|
191 |
+
|
192 |
+
def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]]=None) -> List[int]:
|
193 |
+
sep = [self.sep_token_id]
|
194 |
+
if token_ids_1 is None:
|
195 |
+
return len(token_ids_0 + sep) * [0]
|
196 |
+
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
197 |
+
|
198 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str]=None) -> Tuple[str]:
|
199 |
+
return (None, None)
|
200 |
+
|
201 |
+
def sanitize_special_tokens(self) -> int:
|
202 |
+
"""Make sure that all the special tokens attributes of the tokenizer.
|
203 |
+
|
204 |
+
(`tokenizer.mask_token`, `tokenizer.cls_token`, etc.) are in the
|
205 |
+
vocabulary.
|
206 |
+
|
207 |
+
Add the missing ones to the vocabulary if needed.
|
208 |
+
|
209 |
+
Return:
|
210 |
+
`int`: The number of tokens added in the vocabulary during the operation.
|
211 |
+
"""
|
212 |
+
actual_new_tokens = []
|
213 |
+
for token in self.all_special_tokens_extended:
|
214 |
+
encoded = self.encoding.encode(token, allowed_special='all')
|
215 |
+
if len(encoded) > 1:
|
216 |
+
actual_new_tokens.append(token)
|
217 |
+
return self.add_tokens(actual_new_tokens, special_tokens=True)
|
218 |
+
TiktokenTokenizerWrapper.register_for_auto_class()
|
tokenizers.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .tiktoken import TiktokenTokenizerWrapper
|
utils.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .builders import build_algorithm, build_callback, build_logger, build_optimizer, build_scheduler, build_tokenizer
|
2 |
+
from .checkpoint_conversion_helpers import convert_and_save_ft_weights, get_hf_tokenizer_from_composer_state_dict, load_tokenizer
|
3 |
+
from .config_utils import calculate_batch_size_info, log_config, pop_config, process_init_device, update_batch_size_info
|
4 |
+
from .data_prep_utils import DownloadingIterable, merge_shard_groups, with_id
|
5 |
+
from .huggingface_hub_utils import edit_files_for_hf_compatibility
|
6 |
+
from .logging_utils import SpecificWarningFilter
|
7 |
+
from .model_download_utils import download_from_hf_hub, download_from_http_fileserver, download_from_oras
|
8 |
+
from .mosaicml_logger_utils import find_mosaicml_logger, log_eval_analytics, log_train_analytics, maybe_create_mosaicml_logger
|
9 |
+
from .prompt_files import load_prompts, load_prompts_from_file
|
10 |
+
from .registry_utils import TypedRegistry, construct_from_registry, create_registry
|
11 |
+
from .warnings import VersionedDeprecationWarning
|
warnings.py
CHANGED
@@ -1,4 +1,8 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
2 |
"""A custom deprecation warning class that includes version information.
|
3 |
|
4 |
Attributes:
|
@@ -10,7 +14,7 @@ class VersionedDeprecationWarning(DeprecationWarning):
|
|
10 |
... warnings.warn(
|
11 |
... VersionedDeprecationWarning(
|
12 |
... "Function XYZ is deprecated.",
|
13 |
-
...
|
14 |
... )
|
15 |
... )
|
16 |
...
|
@@ -19,4 +23,49 @@ class VersionedDeprecationWarning(DeprecationWarning):
|
|
19 |
"""
|
20 |
|
21 |
def __init__(self, message: str, remove_version: str) -> None:
|
22 |
-
super().__init__(message + f' It will be removed in version {remove_version}.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import warnings
|
3 |
+
from typing import Any, Callable, Type, TypeVar, cast
|
4 |
+
|
5 |
+
class VersionedDeprecationWarning(UserWarning):
|
6 |
"""A custom deprecation warning class that includes version information.
|
7 |
|
8 |
Attributes:
|
|
|
14 |
... warnings.warn(
|
15 |
... VersionedDeprecationWarning(
|
16 |
... "Function XYZ is deprecated.",
|
17 |
+
... remove_version="2.0.0"
|
18 |
... )
|
19 |
... )
|
20 |
...
|
|
|
23 |
"""
|
24 |
|
25 |
def __init__(self, message: str, remove_version: str) -> None:
|
26 |
+
super().__init__(message + f' It will be removed in version {remove_version}.')
|
27 |
+
|
28 |
+
class ExperimentalWarning(Warning):
|
29 |
+
"""A warning for experimental features.
|
30 |
+
|
31 |
+
Attributes:
|
32 |
+
feature_name (str): The name of the experimental feature.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, feature_name: str) -> None:
|
36 |
+
super().__init__(f'{feature_name} is experimental and may change with future versions.')
|
37 |
+
F = TypeVar('F', bound=Callable[..., Any])
|
38 |
+
|
39 |
+
def experimental_function(feature_name: str) -> Callable[[F], F]:
|
40 |
+
"""Decorator to mark a function as experimental.
|
41 |
+
|
42 |
+
The message displayed will be {feature_name} is experimental and may change with future versions.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
feature_name (str): The name of the experimental feature.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
The decorated function.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def decorator(func: Callable):
|
52 |
+
|
53 |
+
@functools.wraps(func)
|
54 |
+
def wrapper(*args: Any, **kwargs: Any):
|
55 |
+
warnings.warn(ExperimentalWarning(feature_name))
|
56 |
+
return func(*args, **kwargs)
|
57 |
+
return cast(F, wrapper)
|
58 |
+
return decorator
|
59 |
+
|
60 |
+
def experimental_class(feature_name: str) -> Callable[[Type], Type]:
|
61 |
+
"""Class decorator to mark a class as experimental."""
|
62 |
+
|
63 |
+
def class_decorator(cls: Type):
|
64 |
+
original_init = cls.__init__
|
65 |
+
|
66 |
+
def new_init(self: Any, *args: Any, **kwargs: Any):
|
67 |
+
warnings.warn(ExperimentalWarning(feature_name))
|
68 |
+
original_init(self, *args, **kwargs)
|
69 |
+
cls.__init__ = new_init
|
70 |
+
return cls
|
71 |
+
return class_decorator
|