mpt-7b-8k-chat / act_ckpt.py
irenedea's picture
LLM-foundry update March 26, 2024 23:50:31
fdb2891 verified
raw
history blame
5.81 kB
from typing import Any
import torch
from .attention import ATTN_CLASS_REGISTRY
from .blocks import MPTBlock
from .ffn import FFN_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY
def pass_on_block_idx(parent: torch.nn.Module) -> None:
    """Propagate ``block_idx``/``max_block_idx`` from ``parent`` to all descendants.

    If ``parent`` lacks either attribute this is a no-op. Otherwise every
    submodule reachable through ``children()`` (recursively) has both
    attributes set to the parent's values, overwriting any existing ones.

    Args:
        parent: Module whose ``block_idx`` and ``max_block_idx`` are copied down.
    """
    if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
        return
    for child in parent.children():
        child.block_idx = parent.block_idx
        child.max_block_idx = parent.max_block_idx
        # Recurse unconditionally: ``children()`` returns a generator (always
        # truthy), and for a leaf module it simply yields nothing.
        pass_on_block_idx(child)
def get_act_ckpt_module(mod_name: str) -> Any:
    """Resolve an activation_checkpointing_target name to its module class.

    ``'mptblock'`` is matched case-insensitively and maps to ``MPTBlock``.
    Any other name is looked up, case-sensitively, in the attention, FFN and
    norm registries, in that order; the first registry containing it wins.

    Raises:
        ValueError: If ``mod_name`` is not found anywhere. The message lists
            every accepted name.
    """
    if mod_name.lower() == 'mptblock':
        return MPTBlock
    registries = (ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY)
    for registry in registries:
        if mod_name in registry:
            return registry[mod_name]
    known_names = [name for registry in registries for name in registry.keys()]
    known_names.append('MPTBlock')
    msg = ', '.join(known_names)
    raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
def parse_ele_str(ele: str, max_block_idx: int) -> list:
    """Parse one target_blocks string element into the block ids it selects.

    Supported formats:
      * ``first-n``  -> the first n blocks;
      * ``last-k``   -> the last k blocks;
      * ``middle-m`` -> m consecutive blocks starting at
        ``max_block_idx // 2 - m // 2``;
      * ``range-i-j`` -> the half-open range [i, j).
    Results are always clipped to ``[0, max_block_idx]``.

    Args:
        ele: A single element such as ``'first-2'`` or ``'range-3-7'``.
        max_block_idx: Index of the last block.

    Returns:
        A list of block ids in ascending order (possibly empty).

    Raises:
        ValueError: If ``ele`` does not match any supported format.
    """
    invalid = f'Invalid target_blocks element {ele}'

    def _count(prefix: str) -> int:
        # Explicit raise (not ``assert``) so validation survives ``python -O``;
        # ``isdecimal`` accepts exactly the digit characters ``int`` accepts.
        digits = ele[len(prefix):]
        if not digits.isdecimal():
            raise ValueError(invalid)
        return int(digits)

    num_blocks = max_block_idx + 1
    if ele.startswith('first-'):
        return list(range(min(_count('first-'), num_blocks)))
    if ele.startswith('last-'):
        return list(range(max(num_blocks - _count('last-'), 0), num_blocks))
    if ele.startswith('middle-'):
        num = _count('middle-')
        start = max(max_block_idx // 2 - num // 2, 0)
        return list(range(start, min(start + num, num_blocks)))
    if ele.startswith('range-'):
        bounds = ele[len('range-'):].split('-')
        if len(bounds) != 2:
            raise ValueError(invalid)
        try:
            start, end = int(bounds[0]), int(bounds[1])
        except ValueError as err:
            raise ValueError(invalid) from err
        return list(range(max(start, 0), min(end, num_blocks)))
    raise ValueError(invalid)
def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
    """Parse the user's target_blocks spec into a sorted list of unique block ids.

    Accepted forms:
      * an int ``n`` -> ``range(n)`` (not clipped to ``max_block_idx``);
      * a list mixing ints (used as-is) and strings accepted by
        ``parse_ele_str`` (``first-n``, ``middle-m``, ``last-k``,
        ``range-i-j``);
      * a comma-separated string of such string elements; spaces are removed.

    Args:
        target_blocks: The spec to parse.
        max_block_idx: Index of the last block, used to clip string elements.

    Returns:
        Deduplicated block ids in ascending order.

    Raises:
        ValueError: If the spec or any of its elements is malformed.
    """
    candidate_block_ids = set()
    if isinstance(target_blocks, int):
        candidate_block_ids.update(range(target_blocks))
    elif isinstance(target_blocks, list):
        for ele in target_blocks:
            if isinstance(ele, int):
                candidate_block_ids.add(ele)
            elif isinstance(ele, str):
                candidate_block_ids.update(parse_ele_str(ele, max_block_idx))
            else:
                raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
    elif isinstance(target_blocks, str):
        for ele in target_blocks.replace(' ', '').split(','):
            candidate_block_ids.update(parse_ele_str(ele, max_block_idx))
    else:
        raise ValueError(f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}')
    # Sort so the result is deterministic rather than in set-iteration order.
    return sorted(candidate_block_ids)
def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
    """Ensure no block id is claimed by more than one mapping key.

    Each value in ``mapping`` is either ``-1``, meaning every block
    ``0..max_block_idx``, or an iterable of block ids. Ids outside
    ``[0, max_block_idx]`` are ignored.

    Raises:
        ValueError: On the first block id already claimed by an earlier key.
    """
    num_blocks = max_block_idx + 1
    owner_of = [None] * num_blocks
    for mod, block_ids in mapping.items():
        ids = range(num_blocks) if block_ids == -1 else block_ids
        for block_id in ids:
            if block_id < 0 or block_id > max_block_idx:
                continue
            previous = owner_of[block_id]
            if previous is not None:
                raise ValueError(f'Block {block_id} is assigned to both {mod} and {previous}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')
            owner_of[block_id] = mod
def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
    """Build a mapping from module type to the block ids to checkpoint.

    A value of ``-1`` means "all blocks".
      * ``None`` or ``[]`` -> ``{top_module: -1}``;
      * a string or list of strings -> each resolved module type maps to ``-1``;
      * a dict -> each resolved module type maps to the block list parsed from
        its value.

    Raises:
        ValueError: If ``act_ckpt_target`` is of any other type, or a name or
            block spec cannot be resolved.
    """
    if act_ckpt_target is None or act_ckpt_target == []:
        # No explicit target: checkpoint the top-level module for every block.
        return {top_module: -1}
    if isinstance(act_ckpt_target, str):
        return {get_act_ckpt_module(act_ckpt_target): -1}
    if isinstance(act_ckpt_target, list):
        return {get_act_ckpt_module(name): -1 for name in act_ckpt_target}
    if isinstance(act_ckpt_target, dict):
        mapping = {}
        for name, blocks in act_ckpt_target.items():
            mod_type = get_act_ckpt_module(name)
            mapping[mod_type] = get_target_block_list(blocks, max_block_idx)
        return mapping
    raise ValueError(f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}')