mpt-7b-8k-chat / builders.py
irenedea's picture
LLM-foundry update March 26, 2024 23:50:31
fdb2891 verified
raw
history blame
18.7 kB
import contextlib
import functools
import logging
import os
import re
from collections import OrderedDict
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from .llmfoundry import registry
from .callbacks import EvalGauntlet
from .dataloader import build_dataloader
from .tiktoken import TiktokenTokenizerWrapper
from .registry_utils import construct_from_registry
log = logging.getLogger(__name__)
def build_evaluators(eval_loader_config: Optional[Union[DictConfig, ListConfig]], icl_tasks_config: Optional[Union[str, ListConfig]], eval_gauntlet_config: Optional[Union[str, DictConfig]], *, tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
evaluators = []
if eval_loader_config is not None:
evaluators = build_eval_loaders(eval_loader_config, tokenizer, device_eval_batch_size)
logger_keys = []
eval_gauntlet_callback = None
if icl_tasks_config is not None:
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(icl_tasks_config, eval_gauntlet_config, tokenizer, device_eval_batch_size, icl_seq_len, icl_subset_num_batches)
evaluators.extend(icl_evaluators)
return (evaluators, logger_keys, eval_gauntlet_callback)
def build_eval_loaders(eval_loader_config: Union[DictConfig, ListConfig], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int) -> List[Evaluator]:
evaluators: List[Evaluator] = []
if isinstance(eval_loader_config, ListConfig):
eval_configs: ListConfig = eval_loader_config
is_multi_eval = True
else:
eval_configs = ListConfig([eval_loader_config])
is_multi_eval = False
for eval_config in eval_configs:
eval_dataloader = build_dataloader(eval_config, tokenizer, device_eval_batch_size)
eval_loader: Evaluator = Evaluator(label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, metric_names=[])
evaluators.append(eval_loader)
return evaluators
def add_metrics_to_eval_loaders(evaluators: List[Evaluator], metric_names: List[str]) -> List[Evaluator]:
eval_loaders, other_evaluators = ([], [])
for evaluator in evaluators:
if evaluator.metric_names == []:
evaluator.metric_names = metric_names
eval_loaders.append(evaluator)
else:
other_evaluators.append(evaluator)
return eval_loaders + other_evaluators
def build_icl_data_and_gauntlet(icl_tasks_config: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
icl_evaluators, logger_keys = build_icl_evaluators(icl_tasks_config, tokenizer, icl_seq_len, device_eval_batch_size, icl_subset_num_batches=icl_subset_num_batches)
eval_gauntlet_cb = None
if eval_gauntlet_config is not None:
if isinstance(eval_gauntlet_config, str):
with open(eval_gauntlet_config, 'r') as icl_f:
eval_gauntlet_cfg = om.load(icl_f)
eval_gauntlet = eval_gauntlet_cfg.eval_gauntlet
elif isinstance(eval_gauntlet_config, DictConfig):
eval_gauntlet = eval_gauntlet_config
else:
raise ValueError(f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}')
eval_gauntlet.logger_keys = logger_keys
eval_gauntlet.benchmark_sizes = {e.label: e.dataloader.num_samples for e in icl_evaluators}
eval_gauntlet_cb = EvalGauntlet(**eval_gauntlet)
return (icl_evaluators, logger_keys, eval_gauntlet_cb)
def build_composer_model(name: str, cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, init_context: Optional[ContextManager]=None, master_weights_dtype: Optional[str]=None) -> ComposerModel:
"""Builds a ComposerModel from the registry.
Args:
name (str): Name of the model to build.
cfg (DictConfig): Configuration for the model.
tokenizer (PreTrainedTokenizerBase): Tokenizer to use.
init_context (Optional[ContextManager], optional): Context manager to use for initialization. Defaults to None.
master_weights_dtype (Optional[str], optional): Master weights dtype. Defaults to None.
Returns:
ComposerModel: _description_
"""
if init_context is None:
init_context = contextlib.nullcontext()
with init_context:
model = construct_from_registry(name=name, registry=registry.models, pre_validation_function=ComposerModel, post_validation_function=None, kwargs={'om_model_config': cfg, 'tokenizer': tokenizer})
str_dtype_to_torch_dtype = {'f16': torch.float16, 'float16': torch.float16, 'bf16': torch.bfloat16, 'bfloat16': torch.bfloat16}
if master_weights_dtype is not None:
if master_weights_dtype not in str_dtype_to_torch_dtype:
raise ValueError(f'Invalid master_weights_dtype: {master_weights_dtype}. ' + f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.')
dtype = str_dtype_to_torch_dtype[master_weights_dtype]
model = model.to(dtype=dtype)
return model
def build_callback(name: str, kwargs: Optional[Dict[str, Any]]=None, config: Any=None) -> Callback:
"""Builds a callback from the registry."""
registry_to_use = registry.callbacks
if name in registry.callbacks_with_config:
if kwargs is None:
kwargs = {}
if 'config' in kwargs:
raise ValueError(f'`config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.')
kwargs['config'] = config
registry_to_use = registry.callbacks_with_config
return construct_from_registry(name=name, registry=registry_to_use, partial_function=True, pre_validation_function=Callback, post_validation_function=None, kwargs=kwargs)
def build_logger(name: str, kwargs: Optional[Dict[str, Any]]=None) -> LoggerDestination:
"""Builds a logger from the registry."""
return construct_from_registry(name=name, registry=registry.loggers, partial_function=True, pre_validation_function=LoggerDestination, post_validation_function=None, kwargs=kwargs)
def build_algorithm(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Algorithm:
"""Builds an algorithm from the registry."""
return construct_from_registry(name=name, registry=registry.algorithms, partial_function=True, pre_validation_function=Algorithm, post_validation_function=None, kwargs=kwargs)
def build_metric(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Metric:
"""Builds a metric from the registry."""
return construct_from_registry(name=name, registry=registry.metrics, partial_function=True, pre_validation_function=Metric, post_validation_function=None, kwargs=kwargs)
def _extract_param_groups(model: torch.nn.Module, optimizer_config: Optional[Dict[str, Any]]=None) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
"""Extracts parameter groups defined in the optimizer config.
The optimizer_config defines the optimizer args. It can additionally have key
`disable_grad` which is a string or list of strings. If a string matches a
parameter name, then that parameter will have `requires_grad=False`. This is
useful for freezing parameters. It can additionally have a key
`param_groups` which is a list of dicts. In this dict, key `param_str_match`
defines a string; if a parameter name contains this string, then it will be
in this parameter group. This is useful for grouping parameters together.
The dict can also contain any other key that is a valid optimizer arg.
Note: to handle name overlap conflicts, params are assigned to parameter
groups and added to `param_groups` in the order that `param_str_match` appear
in `param_groups`.
Usage
To disable gradient for all parameters that contain the string "norm" or "bias":
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"disable_grad": ["norm", "bias"]
}
```
To create and modify the optimizer parameters for all parameters that contain
the string "norm" and "bias" separately:
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"param_groups": [
{
"param_str_match": "norm",
"lr": 1e-4,
"weight_decay": 0.0,
},
{
"param_str_match": "bias",
"lr": 5e-4,
"weight_decay": 0.0,
},
],
}
```
Args:
model (torch.nn.Module): model to extract parameters from
optimizer_config (Dict[str, Any]): optimizer config
Returns:
Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
torch.Tensor's or dict's. Specifies what Tensors should be optimized
and their param groupings.
"""
if optimizer_config is None:
return model.parameters()
if 'disable_grad' in optimizer_config.keys():
str_matches = optimizer_config.pop('disable_grad')
if isinstance(str_matches, str):
str_matches = [str_matches]
for str_match in str_matches:
for n, p in model.named_parameters():
if re.search(str_match, n):
p.requires_grad = False
log.debug(f'Setting `{n}.requires_grad = False`.')
param_groups_config = optimizer_config.pop('param_groups', None)
if param_groups_config is not None:
params = []
param_dict = OrderedDict(((n, p) for n, p in model.named_parameters()))
log.debug(f'Default optimizer settings: {optimizer_config}.')
for param_group_config in param_groups_config:
str_match = param_group_config.pop('param_str_match')
filter_fn = functools.partial(re.search, str_match)
param_names = [n for n in param_dict.keys() if filter_fn(n)]
group_params = {'params': [param_dict.pop(n) for n in param_names]}
group_params.update(param_group_config)
log.debug(f'Creating optimizer param_group with parameters: {param_names} ' + f'(extracted using str_match={str_match!r}). The param_group optimizer ' + f'setting overrides are: {param_group_config}.')
params.append(group_params)
params.insert(0, {'params': param_dict.values()})
return params
return model.parameters()
def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Optional[Dict[str, Any]]=None) -> Optimizer:
params = _extract_param_groups(model, optimizer_config)
kwargs = optimizer_config
if kwargs is None:
kwargs = {}
if 'params' in kwargs:
raise ValueError('The `params` will be automatically extracted from the model and ' + 'optimizer config. Please remove it from the optimizer config kwargs.')
kwargs['params'] = params
return construct_from_registry(name=name, registry=registry.optimizers, partial_function=True, pre_validation_function=Optimizer, post_validation_function=None, kwargs=kwargs)
def build_scheduler(name: str, scheduler_config: Optional[Dict[str, Any]]=None) -> ComposerScheduler:
return construct_from_registry(name=name, registry=registry.schedulers, partial_function=True, pre_validation_function=ComposerScheduler, post_validation_function=None, kwargs=scheduler_config)
def build_tokenizer(tokenizer_name: str, tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase:
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass
if tokenizer_name.startswith('tiktoken'):
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
tokenizer.model_max_length = tokenizer_kwargs.get('model_max_length', int(1e+30))
if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
raise ValueError(f'The tokenizer {tokenizer_name} must have an eos_token.')
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')
dist.barrier()
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
return tokenizer
def build_icl_evaluators(icl_tasks: Union[str, ListConfig], tokenizer: PreTrainedTokenizerBase, default_max_seq_len: int, default_batch_size: int, destination_dir: Optional[str]=None, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str]]:
if destination_dir is None:
destination_dir = os.getcwd()
evaluators = []
logger_keys = []
icl_tasks_list = None
if isinstance(icl_tasks, str):
log.info(f'Extracting ICL task config from path: {icl_tasks}')
with open(icl_tasks, 'r') as icl_f:
icl_task_cfg = om.load(icl_f)
icl_tasks_list = icl_task_cfg.icl_tasks
else:
icl_tasks_list = icl_tasks
def _validate_cfg(icl_cfg: DictConfig):
assert 'label' in icl_cfg
assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None
assert 'icl_task_type' in icl_cfg
assert 'num_fewshot' in icl_cfg
if 'metric_names' not in icl_cfg:
if icl_cfg.icl_task_type == 'language_modeling':
icl_cfg.metric_names = ['InContextLearningLMAccuracy']
elif icl_cfg.icl_task_type == 'multiple_choice':
icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
elif icl_cfg.icl_task_type == 'schema':
icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
elif icl_cfg.icl_task_type == 'question_answering':
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
elif icl_cfg.icl_task_type == 'code_evaluation':
icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
else:
raise ValueError(f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.')
if 'prompt_string' not in icl_cfg:
icl_cfg.prompt_string = ''
if 'example_delimiter' not in icl_cfg:
icl_cfg.example_delimiter = '\n'
if 'continuation_delimiter' not in icl_cfg:
icl_cfg.continuation_delimiter = ' '
if 'max_seq_len' not in icl_cfg:
icl_cfg.max_seq_len = default_max_seq_len
if 'batch_size' not in icl_cfg:
icl_cfg.batch_size = default_batch_size
if 'pass_at_k' not in icl_cfg:
icl_cfg.pass_at_k = 1
if 'fewshot_random_seed' not in icl_cfg:
icl_cfg.fewshot_random_seed = 1234
if 'generations_per_sample' not in icl_cfg:
icl_cfg.generations_per_sample = 1
if 'num_beams' in icl_cfg:
raise ValueError('num_beams is no longer supported as a top level icl_task parameter.' + 'Please use generation_kwargs.num_beams instead.')
for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
_validate_cfg(icl_cfg)
for num_fewshot in list(icl_cfg.num_fewshot):
if tokenizer.pad_token_id is None:
pad_tok_id = tokenizer.eos_token_id
else:
pad_tok_id = tokenizer.pad_token_id
label = f'{icl_cfg.label}/{num_fewshot}-shot'
metric_names = list(icl_cfg.metric_names)
destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
if dist.get_local_rank() == 0 and os.path.exists(destination_path):
os.remove(destination_path)
dist.barrier()
hf_parsing_map = icl_cfg.get('hf_parsing_map', {})
hf_loading_vars = icl_cfg.get('hf_loading_vars', {})
early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None)
if isinstance(early_stopping_criteria, ListConfig):
early_stopping_criteria = om.to_container(early_stopping_criteria)
assert early_stopping_criteria is None or isinstance(early_stopping_criteria, list)
dataloaders = get_icl_task_dataloader(icl_cfg.icl_task_type, icl_cfg.dataset_uri, tokenizer, batch_size=icl_cfg.batch_size, max_seq_len=icl_cfg.max_seq_len, pad_tok_id=pad_tok_id, num_fewshot=num_fewshot, prompt_string=icl_cfg.prompt_string, example_delimiter=icl_cfg.example_delimiter, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, continuation_delimiter=icl_cfg.continuation_delimiter, question_prelimiter=icl_cfg.get('question_prelimiter', ''), destination_path=destination_path, fewshot_random_seed=icl_cfg.fewshot_random_seed, pass_at_k=icl_cfg.pass_at_k, generations_per_sample=icl_cfg.generations_per_sample, has_categories=icl_cfg.get('has_categories', False), cot_delimiter=icl_cfg.get('cot_delimiter', ''), generation_kwargs=icl_cfg.get('generation_kwargs', {}), early_stopping_criteria=early_stopping_criteria, do_normalization=icl_cfg.get('do_normalization', True))
if hasattr(icl_cfg, 'has_categories') and icl_cfg.has_categories and isinstance(dataloaders, dict):
for category in dataloaders.keys():
logger_keys.extend([f'metrics/{label}/{category}/{m}' for m in metric_names])
evaluators.append(Evaluator(label=f'{label}/{category}', dataloader=dataloaders[category], metric_names=metric_names))
else:
logger_keys.extend([f'metrics/{label}/{m}' for m in metric_names])
evaluators.append(Evaluator(label=label, dataloader=dataloaders, metric_names=metric_names, subset_num_batches=icl_subset_num_batches))
return (evaluators, logger_keys)