uploading stuff lol
Browse files- app.py +48 -0
- cli/__init__.py +0 -0
- cli/config.json +20 -0
- cli/convert_model.py +86 -0
- cli/deploy_server.sh +87 -0
- cli/inference_one_block.py +53 -0
- cli/local_server_config_example.cfg +5 -0
- cli/remote_server_config_example.cfg +6 -0
- cli/run_local_servers.sh +111 -0
- cli/run_remote_servers.sh +112 -0
- cli/run_server.py +85 -0
- src/__init__.py +5 -0
- src/bloom/__init__.py +2 -0
- src/bloom/block.py +248 -0
- src/bloom/from_pretrained.py +80 -0
- src/bloom/model.py +408 -0
- src/bloom/ops.py +246 -0
- src/client/__init__.py +4 -0
- src/client/remote_block.py +135 -0
- src/client/remote_model.py +58 -0
- src/client/remote_sequence_info.py +94 -0
- src/client/remote_sequential.py +135 -0
- src/data_structures.py +8 -0
- src/dht_utils.py +132 -0
- src/server/__init__.py +0 -0
- src/server/backend.py +58 -0
- src/server/cache.py +127 -0
- src/server/handler.py +229 -0
- src/server/server.py +254 -0
app.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import transformers
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from src.client import DistributedBloomForCausalLM
|
7 |
+
|
8 |
+
INITIAL_PEERS = ['/ip4/193.106.95.184/tcp/443/p2p/QmSXDXLeSMXjS4YerDrdn1zpGQaNzkZ9ogN2SoAEyAdDhs']
|
9 |
+
|
10 |
+
import hivemind # test that DHT instances work on localhost
|
11 |
+
dht1 = hivemind.DHT(start=True)
|
12 |
+
dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
|
13 |
+
|
14 |
+
|
15 |
+
tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
|
16 |
+
model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32)
|
17 |
+
|
18 |
+
def inference(text, seq_length=1):
|
19 |
+
input_ids = tokenizer(text, return_tensors='pt')['input_ids']
|
20 |
+
final_tokens = input_ids
|
21 |
+
with torch.inference_mode(), model.transformer.h.inference_session() as remote_transformer:
|
22 |
+
for i in range(seq_length):
|
23 |
+
h = model.transformer.word_embeddings(input_ids)
|
24 |
+
h = model.transformer.word_embeddings_layernorm(h)
|
25 |
+
h = remote_transformer.step(h)
|
26 |
+
h = model.transformer.ln_f(h)
|
27 |
+
h = F.linear(h, weight=model.transformer.word_embeddings.weight) # note: this line takes a while, will also be fixed
|
28 |
+
next_token_ix = torch.multinomial((h[0, -1] / 0.8).softmax(-1), 1)
|
29 |
+
|
30 |
+
final_tokens = torch.cat([final_tokens, next_token_ix.view(1, 1)], dim=-1)
|
31 |
+
input_ids = next_token_ix.view(1, 1)
|
32 |
+
return tokenizer.decode(final_tokens[0], skip_special_tokens=False)
|
33 |
+
|
34 |
+
iface = gr.Interface(
|
35 |
+
fn=inference,
|
36 |
+
inputs=[
|
37 |
+
gr.Textbox(lines=10, label="Input text"),
|
38 |
+
gr.inputs.Slider(
|
39 |
+
minimum=0,
|
40 |
+
maximum=1000,
|
41 |
+
step=1,
|
42 |
+
default=42,
|
43 |
+
label="Sequence length for generation"
|
44 |
+
)
|
45 |
+
],
|
46 |
+
outputs="text"
|
47 |
+
)
|
48 |
+
iface.launch()
|
cli/__init__.py
ADDED
File without changes
|
cli/config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"apply_residual_connection_post_layernorm": false,
|
3 |
+
"attention_dropout": 0.0,
|
4 |
+
"attention_softmax_in_fp32": true,
|
5 |
+
"bos_token_id": 1,
|
6 |
+
"eos_token_id": 2,
|
7 |
+
"hidden_dropout": 0.0,
|
8 |
+
"initializer_range": 0.02,
|
9 |
+
"layer_norm_epsilon": 1e-05,
|
10 |
+
"masked_softmax_fusion": true,
|
11 |
+
"model_type": "bloom",
|
12 |
+
"n_embed": 14336,
|
13 |
+
"n_layer": 70,
|
14 |
+
"num_attention_heads": 112,
|
15 |
+
"pretraining_tp": 4,
|
16 |
+
"slow_but_exact": false,
|
17 |
+
"transformers_version": "4.20.0.dev0",
|
18 |
+
"use_cache": true,
|
19 |
+
"vocab_size": 250880
|
20 |
+
}
|
cli/convert_model.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import psutil
|
5 |
+
import torch.backends.quantized
|
6 |
+
import torch.nn as nn
|
7 |
+
import transformers
|
8 |
+
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
9 |
+
from huggingface_hub import Repository
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
|
12 |
+
from src import BloomModel
|
13 |
+
from src.client import DistributedBloomConfig
|
14 |
+
from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
|
15 |
+
use_hivemind_log_handler("in_root_logger")
|
16 |
+
logger = get_logger(__file__)
|
17 |
+
|
18 |
+
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
19 |
+
|
20 |
+
|
21 |
+
if __name__ == "__main__":
|
22 |
+
parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
|
23 |
+
|
24 |
+
parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
|
25 |
+
parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
|
26 |
+
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
|
27 |
+
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
|
28 |
+
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
|
29 |
+
parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
|
30 |
+
parser.add_argument(
|
31 |
+
"--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
|
35 |
+
)
|
36 |
+
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
free_ram_gb = psutil.virtual_memory().available / 2**30
|
40 |
+
if args.model == "bigscience/bloom" and free_ram_gb < 400:
|
41 |
+
logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
|
42 |
+
|
43 |
+
assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
|
44 |
+
if os.path.exists(args.output_path) and (
|
45 |
+
len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
|
46 |
+
):
|
47 |
+
raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
|
48 |
+
|
49 |
+
logger.info(f"Loading source model {args.model} (this may take a few minutes)")
|
50 |
+
config = DistributedBloomConfig.from_pretrained(
|
51 |
+
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
52 |
+
)
|
53 |
+
config.dht_prefix = args.output_repo
|
54 |
+
|
55 |
+
model = BloomModel.from_pretrained(
|
56 |
+
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
57 |
+
)
|
58 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
59 |
+
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
60 |
+
)
|
61 |
+
os.makedirs(args.output_path, exist_ok=True)
|
62 |
+
|
63 |
+
repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
|
64 |
+
repo.git_pull()
|
65 |
+
|
66 |
+
transformer_blocks = model.h
|
67 |
+
logger.info(
|
68 |
+
f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
|
69 |
+
f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
|
70 |
+
)
|
71 |
+
for i, block in enumerate(tqdm(transformer_blocks)):
|
72 |
+
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
73 |
+
with repo.commit(
|
74 |
+
commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
|
75 |
+
):
|
76 |
+
torch.save(block.state_dict(), "./pytorch_model.bin")
|
77 |
+
|
78 |
+
logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
|
79 |
+
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
80 |
+
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
|
81 |
+
model.h = nn.ModuleList()
|
82 |
+
model.save_pretrained(".")
|
83 |
+
tokenizer.save_pretrained(".")
|
84 |
+
config.save_pretrained(".")
|
85 |
+
|
86 |
+
logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
|
cli/deploy_server.sh
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
#################
|
4 |
+
# Parse options #
|
5 |
+
#################
|
6 |
+
|
7 |
+
instructions() {
|
8 |
+
echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
|
9 |
+
echo " -i: initial peer"
|
10 |
+
echo " -d: device" >&2
|
11 |
+
echo " -p: server identity path" >&2
|
12 |
+
echo " -b: block_ids" >&2
|
13 |
+
echo " -a: host maddrs" >&2
|
14 |
+
echo " -t: whether to run local tests" >&2
|
15 |
+
exit 1
|
16 |
+
}
|
17 |
+
|
18 |
+
if [ ! $# -ge 8 ]; then
|
19 |
+
instructions
|
20 |
+
fi
|
21 |
+
|
22 |
+
while getopts ":i:d:p:b:a:t:" option; do
|
23 |
+
case $option in
|
24 |
+
i) INITIAL_PEER=${OPTARG}
|
25 |
+
;;
|
26 |
+
d) DEVICE=${OPTARG}
|
27 |
+
;;
|
28 |
+
p) SERVER_ID_PATH=${OPTARG}
|
29 |
+
;;
|
30 |
+
b) BLOCK_IDS=${OPTARG}
|
31 |
+
;;
|
32 |
+
a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
|
33 |
+
;;
|
34 |
+
t) RUN_LOCAL_TESTS=true
|
35 |
+
;;
|
36 |
+
\?) instructions
|
37 |
+
;;
|
38 |
+
esac
|
39 |
+
done
|
40 |
+
|
41 |
+
|
42 |
+
echo "=========="
|
43 |
+
echo "= Config ="
|
44 |
+
echo "=========="
|
45 |
+
echo "Initial peer: ${INITIAL_PEER}"
|
46 |
+
echo "Device: ${DEVICE}"
|
47 |
+
echo "Server name: ${SERVER_ID_PATH}"
|
48 |
+
echo "Server address: ${HOST_MADDR}"
|
49 |
+
echo "Bloom blocks: ${BLOCK_IDS}"
|
50 |
+
|
51 |
+
|
52 |
+
###########################
|
53 |
+
# Install or activate env #
|
54 |
+
###########################
|
55 |
+
|
56 |
+
# TODO fix bug with self calling
|
57 |
+
source ~/miniconda3/etc/profile.d/conda.sh
|
58 |
+
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
|
59 |
+
conda activate bloom-demo
|
60 |
+
else
|
61 |
+
conda create -y --name bloom-demo python=3.8.12 pip
|
62 |
+
conda activate bloom-demo
|
63 |
+
|
64 |
+
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
65 |
+
pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
66 |
+
pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
|
67 |
+
pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
|
68 |
+
pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
|
69 |
+
fi
|
70 |
+
|
71 |
+
|
72 |
+
##############
|
73 |
+
# Local test #
|
74 |
+
##############
|
75 |
+
|
76 |
+
if [ "$RUN_LOCAL_TESTS" = true ] ; then
|
77 |
+
echo "Run test on your local machine"
|
78 |
+
python -m cli.inference_one_block --config cli/config.json --device ${DEVICE} # see other args
|
79 |
+
fi
|
80 |
+
|
81 |
+
|
82 |
+
##############
|
83 |
+
# Run server #
|
84 |
+
##############
|
85 |
+
|
86 |
+
python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
|
87 |
+
--block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log
|
cli/inference_one_block.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
5 |
+
from tqdm.auto import trange
|
6 |
+
|
7 |
+
from src.bloom.block import BloomBlock
|
8 |
+
from src.bloom.model import BloomConfig
|
9 |
+
from src.bloom.ops import build_alibi_tensor
|
10 |
+
|
11 |
+
use_hivemind_log_handler("in_root_logger")
|
12 |
+
logger = get_logger(__file__)
|
13 |
+
|
14 |
+
logger.warning("inference_one_block will soon be deprecated in favour of tests!")
|
15 |
+
|
16 |
+
|
17 |
+
def print_device_info(device=None):
|
18 |
+
"""Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
|
19 |
+
device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
20 |
+
logger.info(f"Using device: {device}")
|
21 |
+
|
22 |
+
# Additional Info when using cuda
|
23 |
+
if device.type == "cuda":
|
24 |
+
logger.info(torch.cuda.get_device_name(0))
|
25 |
+
logger.info(f"Memory Usage:")
|
26 |
+
logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
|
27 |
+
logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
|
32 |
+
parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
|
33 |
+
parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
|
34 |
+
parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
|
35 |
+
parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
|
36 |
+
parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
if args.device is None:
|
40 |
+
args.device = "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
+
|
42 |
+
config = BloomConfig.from_json_file(args.config)
|
43 |
+
block = BloomBlock(config, args.layer_index).to(args.device)
|
44 |
+
|
45 |
+
cache = None
|
46 |
+
|
47 |
+
for i in trange(args.num_steps):
|
48 |
+
dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
|
49 |
+
alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
|
50 |
+
with torch.no_grad():
|
51 |
+
outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
|
52 |
+
|
53 |
+
print_device_info(args.device)
|
cli/local_server_config_example.cfg
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device=cpu
|
2 |
+
block_ids=2:3
|
3 |
+
id_path=./server.id
|
4 |
+
maddr=/ip4/127.0.0.1/tcp/30000
|
5 |
+
#
|
cli/remote_server_config_example.cfg
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name=bloom-peer-0.bloom.net
|
2 |
+
device=cpu
|
3 |
+
block_ids=1:3
|
4 |
+
id_path=./server.id
|
5 |
+
maddr=/ip4/0.0.0.0/tcp/30000
|
6 |
+
#
|
cli/run_local_servers.sh
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# !/usr/bin/env bash
|
2 |
+
|
3 |
+
#################
|
4 |
+
# Parse options #
|
5 |
+
#################
|
6 |
+
|
7 |
+
instructions() {
|
8 |
+
echo "Usage: $0 [-n] [-c]" >&2
|
9 |
+
echo " -n: number of servers to run" >&2
|
10 |
+
echo " -c: path to the server configs" >&2
|
11 |
+
exit 1
|
12 |
+
}
|
13 |
+
|
14 |
+
if [ $# != 4 ]; then
|
15 |
+
instructions
|
16 |
+
fi
|
17 |
+
|
18 |
+
while getopts ":n:c:t:" option; do
|
19 |
+
case $option in
|
20 |
+
n) NUM_SERVERS=${OPTARG}
|
21 |
+
;;
|
22 |
+
c) CONFIG_PATH=${OPTARG}
|
23 |
+
;;
|
24 |
+
\?) instructions
|
25 |
+
;;
|
26 |
+
esac
|
27 |
+
done
|
28 |
+
|
29 |
+
|
30 |
+
###########################
|
31 |
+
# Install or activate env #
|
32 |
+
###########################
|
33 |
+
|
34 |
+
source ~/miniconda3/etc/profile.d/conda.sh
|
35 |
+
if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
|
36 |
+
conda activate bloom-demo
|
37 |
+
else
|
38 |
+
conda create -y --name bloom-demo python=3.8.12 pip
|
39 |
+
conda activate bloom-demo
|
40 |
+
|
41 |
+
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
42 |
+
pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
43 |
+
pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
|
44 |
+
pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
|
45 |
+
pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
|
46 |
+
fi
|
47 |
+
|
48 |
+
|
49 |
+
#######################
|
50 |
+
# Create Initial peer #
|
51 |
+
#######################
|
52 |
+
|
53 |
+
hivemind-dht &> tmp.out &
|
54 |
+
sleep 3
|
55 |
+
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
|
56 |
+
echo "Initial peer: ${INITIAL_PEER}"
|
57 |
+
|
58 |
+
|
59 |
+
##############################
|
60 |
+
# Initialize the config file #
|
61 |
+
##############################
|
62 |
+
|
63 |
+
typeset -A cfg
|
64 |
+
cfg=( # set default values in config array
|
65 |
+
[device]="cpu"
|
66 |
+
[block_ids]="1:2"
|
67 |
+
[id_path]="server.id"
|
68 |
+
[maddr]="/ip4/127.0.0.1/tcp/30000"
|
69 |
+
)
|
70 |
+
|
71 |
+
###############
|
72 |
+
# Run servers #
|
73 |
+
###############
|
74 |
+
|
75 |
+
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
|
76 |
+
do
|
77 |
+
###############
|
78 |
+
# Read config #
|
79 |
+
###############
|
80 |
+
|
81 |
+
while read line
|
82 |
+
do
|
83 |
+
if echo $line | grep -F = &>/dev/null
|
84 |
+
then
|
85 |
+
varname=$(echo "$line" | cut -d '=' -f 1)
|
86 |
+
cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
|
87 |
+
fi
|
88 |
+
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
|
89 |
+
|
90 |
+
echo "=== Server #${SERVER_ID} ==="
|
91 |
+
echo "Server ID: ${id_path}"
|
92 |
+
echo "Device: ${cfg[device]}"
|
93 |
+
echo "Bloom block ids: ${cfg[block_ids]}"
|
94 |
+
echo "Host maddr: ${cfg[maddr]}"
|
95 |
+
echo ""
|
96 |
+
|
97 |
+
##############
|
98 |
+
# Run server #
|
99 |
+
##############
|
100 |
+
|
101 |
+
tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
|
102 |
+
done
|
103 |
+
|
104 |
+
|
105 |
+
#####################
|
106 |
+
# Kill initial peer #
|
107 |
+
#####################
|
108 |
+
|
109 |
+
sleep 10
|
110 |
+
pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
|
111 |
+
rm tmp.out
|
cli/run_remote_servers.sh
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# !/usr/bin/env bash
|
2 |
+
|
3 |
+
SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"
|
4 |
+
|
5 |
+
#################
|
6 |
+
# Parse options #
|
7 |
+
#################
|
8 |
+
|
9 |
+
instructions() {
|
10 |
+
echo "Usage: $0 [-u] [-n] [-c]" >&2
|
11 |
+
echo " -u: username" >&2
|
12 |
+
echo " -n: number of servers to run" >&2
|
13 |
+
echo " -c: path to the server configs" >&2
|
14 |
+
exit 1
|
15 |
+
}
|
16 |
+
|
17 |
+
if [ $# != 6 ]; then
|
18 |
+
instructions
|
19 |
+
fi
|
20 |
+
|
21 |
+
while getopts ":u:n:c:" option; do
|
22 |
+
case $option in
|
23 |
+
u) USERNAME=${OPTARG}
|
24 |
+
;;
|
25 |
+
n) NUM_SERVERS=${OPTARG}
|
26 |
+
;;
|
27 |
+
c) CONFIG_PATH=${OPTARG}
|
28 |
+
;;
|
29 |
+
\?) instructions
|
30 |
+
;;
|
31 |
+
esac
|
32 |
+
done
|
33 |
+
|
34 |
+
|
35 |
+
###########################
|
36 |
+
# Install or activate env #
|
37 |
+
###########################
|
38 |
+
|
39 |
+
source ~/miniconda3/etc/profile.d/conda.sh
|
40 |
+
if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
|
41 |
+
conda activate bloom-demo
|
42 |
+
else
|
43 |
+
conda create -y --name bloom-demo python=3.8.12 pip
|
44 |
+
conda activate bloom-demo
|
45 |
+
|
46 |
+
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
|
47 |
+
pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
48 |
+
pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
|
49 |
+
pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
|
50 |
+
pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
|
51 |
+
fi
|
52 |
+
|
53 |
+
|
54 |
+
#######################
|
55 |
+
# Create Initial peer #
|
56 |
+
#######################
|
57 |
+
|
58 |
+
hivemind-dht &> tmp.out &
|
59 |
+
|
60 |
+
sleep 3
|
61 |
+
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
|
62 |
+
rm tmp.out
|
63 |
+
echo "Initial peer: ${INITIAL_PEER}"
|
64 |
+
|
65 |
+
|
66 |
+
##############################
|
67 |
+
# Initialize the config file #
|
68 |
+
##############################
|
69 |
+
|
70 |
+
typeset -A cfg
|
71 |
+
cfg=( # set default values in config array
|
72 |
+
[name]=""
|
73 |
+
[device]="cpu"
|
74 |
+
[block_ids]="1:2"
|
75 |
+
[id_path]="server.id"
|
76 |
+
[maddr]="/ip4/0.0.0.0/tcp/30000"
|
77 |
+
)
|
78 |
+
|
79 |
+
###############
|
80 |
+
# Run servers #
|
81 |
+
###############
|
82 |
+
|
83 |
+
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
|
84 |
+
do
|
85 |
+
###############
|
86 |
+
# Read config #
|
87 |
+
###############
|
88 |
+
|
89 |
+
while read line
|
90 |
+
do
|
91 |
+
if echo $line | grep -F = &>/dev/null
|
92 |
+
then
|
93 |
+
varname=$(echo "$line" | cut -d '=' -f 1)
|
94 |
+
cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
|
95 |
+
fi
|
96 |
+
done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
|
97 |
+
|
98 |
+
SERVER_NAME="${USERNAME}@${cfg[name]}"
|
99 |
+
echo "=== Server #${SERVER_ID} ==="
|
100 |
+
echo "Server name ${SERVER_NAME}"
|
101 |
+
echo "Server ID: ${cfg[id_path]}"
|
102 |
+
echo "Device: ${cfg[device]}"
|
103 |
+
echo "Bloom block ids: ${cfg[block_ids]}"
|
104 |
+
echo "Host maddr: ${cfg[maddr]}"
|
105 |
+
echo "================="
|
106 |
+
|
107 |
+
##############
|
108 |
+
# Run server #
|
109 |
+
##############
|
110 |
+
|
111 |
+
ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
|
112 |
+
done
|
cli/run_server.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import configargparse
|
2 |
+
from hivemind.proto.runtime_pb2 import CompressionType
|
3 |
+
from hivemind.utils.limits import increase_file_limit
|
4 |
+
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
5 |
+
|
6 |
+
from src.server.server import Server
|
7 |
+
|
8 |
+
use_hivemind_log_handler("in_root_logger")
|
9 |
+
logger = get_logger(__file__)
|
10 |
+
|
11 |
+
|
12 |
+
def main():
|
13 |
+
# fmt:off
|
14 |
+
parser = configargparse.ArgParser(default_config_files=["config.yml"])
|
15 |
+
parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
|
16 |
+
|
17 |
+
parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
|
18 |
+
help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
|
19 |
+
parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
|
20 |
+
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
|
21 |
+
parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
|
22 |
+
"use the same name as in the converted model.")
|
23 |
+
parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
|
24 |
+
help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
|
25 |
+
parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
|
26 |
+
help='Visible multiaddrs the host announces for external connections from other p2p instances')
|
27 |
+
|
28 |
+
parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
|
29 |
+
|
30 |
+
parser.add_argument('--num_handlers', type=int, default=None, required=False,
|
31 |
+
help='server will use this many processes to handle incoming requests')
|
32 |
+
parser.add_argument('--min_batch_size', type=int, default=1,
|
33 |
+
help='Minimum required batch size for all expert operations')
|
34 |
+
parser.add_argument('--max_batch_size', type=int, default=16384,
|
35 |
+
help='The total number of examples in the same batch will not exceed this value')
|
36 |
+
parser.add_argument('--cache_size_bytes', type=int, default=None,
|
37 |
+
help='The size of memory cache for storing past attention keys/values between inference steps')
|
38 |
+
parser.add_argument('--device', type=str, default=None, required=False,
|
39 |
+
help='all experts will use this device in torch notation; default: cuda if available else cpu')
|
40 |
+
parser.add_argument("--torch_dtype", type=str, default="auto",
|
41 |
+
help="Use this dtype to store block weights and do computations. "
|
42 |
+
"By default, respect the dtypes in the pre-trained state dict.")
|
43 |
+
|
44 |
+
parser.add_argument('--update_period', type=float, required=False, default=30,
|
45 |
+
help='Server will report experts to DHT once in this many seconds')
|
46 |
+
parser.add_argument('--expiration', type=float, required=False, default=None,
|
47 |
+
help='DHT entries will expire after this many seconds')
|
48 |
+
parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
|
49 |
+
help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
|
50 |
+
parser.add_argument('--increase_file_limit', action='store_true',
|
51 |
+
help='On *nix, this will increase the max number of processes '
|
52 |
+
'a server can spawn before hitting "Too many open files"; Use at your own risk.')
|
53 |
+
parser.add_argument('--stats_report_interval', type=int, required=False,
|
54 |
+
help='Interval between two reports of batch processing performance statistics')
|
55 |
+
|
56 |
+
parser.add_argument('--custom_module_path', type=str, required=False,
|
57 |
+
help='Path of a file with custom nn.modules, wrapped into special decorator')
|
58 |
+
parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
|
59 |
+
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
60 |
+
|
61 |
+
# fmt:on
|
62 |
+
args = vars(parser.parse_args())
|
63 |
+
args.pop("config", None)
|
64 |
+
|
65 |
+
if args.pop("increase_file_limit"):
|
66 |
+
increase_file_limit()
|
67 |
+
|
68 |
+
compression_type = args.pop("compression")
|
69 |
+
compression = getattr(CompressionType, compression_type)
|
70 |
+
|
71 |
+
use_auth_token = args.pop("use_auth_token")
|
72 |
+
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
|
73 |
+
|
74 |
+
server = Server.create(**args, start=True, compression=compression)
|
75 |
+
|
76 |
+
try:
|
77 |
+
server.join()
|
78 |
+
except KeyboardInterrupt:
|
79 |
+
logger.info("Caught KeyboardInterrupt, shutting down")
|
80 |
+
finally:
|
81 |
+
server.shutdown()
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
main()
|
src/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.bloom import *
|
2 |
+
from src.client import *
|
3 |
+
from src.dht_utils import declare_active_modules, get_remote_module
|
4 |
+
|
5 |
+
__version__ = "0.2"
|
src/bloom/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from src.bloom.block import BloomBlock
|
2 |
+
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
|
src/bloom/block.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Bloom intermediate layer
|
3 |
+
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
4 |
+
See commit history for authorship.
|
5 |
+
"""
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.quantized.dynamic.modules.linear
|
11 |
+
|
12 |
+
from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
|
13 |
+
pre_process_alibi_for_pad, split_tensor_along_last_dim)
|
14 |
+
|
15 |
+
|
16 |
+
class BloomAttention(nn.Module):
|
17 |
+
def __init__(self, config, layer_number=None):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
self.hidden_size = config.hidden_size
|
21 |
+
self.num_heads = config.n_head
|
22 |
+
self.head_dim = self.hidden_size // self.num_heads
|
23 |
+
self.split_size = self.hidden_size
|
24 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
25 |
+
self.masked_softmax_fusion = config.masked_softmax_fusion
|
26 |
+
self.hidden_dropout = config.hidden_dropout
|
27 |
+
|
28 |
+
if self.head_dim * self.num_heads != self.hidden_size:
|
29 |
+
raise ValueError(
|
30 |
+
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
31 |
+
f" {self.num_heads})."
|
32 |
+
)
|
33 |
+
|
34 |
+
# Layer-wise attention scaling
|
35 |
+
self.layer_number = max(1, layer_number)
|
36 |
+
self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
|
37 |
+
|
38 |
+
# Scaled Softmax
|
39 |
+
self.scale_mask_softmax = BloomScaledSoftmax(
|
40 |
+
self.masked_softmax_fusion,
|
41 |
+
attention_mask_func,
|
42 |
+
self.attention_softmax_in_fp32,
|
43 |
+
self.layer_number,
|
44 |
+
)
|
45 |
+
|
46 |
+
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
|
47 |
+
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
48 |
+
|
49 |
+
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
50 |
+
|
51 |
+
def forward(
|
52 |
+
self,
|
53 |
+
hidden_states,
|
54 |
+
residual,
|
55 |
+
layer_past=None,
|
56 |
+
attention_mask=None,
|
57 |
+
alibi=None,
|
58 |
+
head_mask=None,
|
59 |
+
use_cache=False,
|
60 |
+
output_attentions=False,
|
61 |
+
):
|
62 |
+
if alibi is None:
|
63 |
+
current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
|
64 |
+
alibi = build_alibi_tensor(
|
65 |
+
current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
|
66 |
+
)
|
67 |
+
|
68 |
+
# hidden_states: [batch_size, seq_length, hidden_size]
|
69 |
+
# apply preprocessing if the input is padded
|
70 |
+
if attention_mask is not None:
|
71 |
+
alibi = pre_process_alibi_for_pad(alibi, attention_mask)
|
72 |
+
# otherwise repeat alibi tensor with the batch size
|
73 |
+
else:
|
74 |
+
alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
|
75 |
+
|
76 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
77 |
+
|
78 |
+
# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
|
79 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
|
80 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
81 |
+
|
82 |
+
# [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
|
83 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
84 |
+
|
85 |
+
if layer_past is not None:
|
86 |
+
past_key, past_value = layer_past
|
87 |
+
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
|
88 |
+
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
|
89 |
+
|
90 |
+
if use_cache is True:
|
91 |
+
present = (key_layer, value_layer)
|
92 |
+
else:
|
93 |
+
present = None
|
94 |
+
|
95 |
+
# [batch_size, head_dim, q_length, k_length]
|
96 |
+
output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
|
97 |
+
|
98 |
+
# [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
|
99 |
+
query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
|
100 |
+
|
101 |
+
# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
|
102 |
+
key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
|
103 |
+
|
104 |
+
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
|
105 |
+
beta = 1.0 / self.layer_number
|
106 |
+
|
107 |
+
matmul_result = torch.baddbmm(
|
108 |
+
alibi,
|
109 |
+
query_layer.transpose(1, 0),
|
110 |
+
key_layer.transpose(1, 0).transpose(1, 2),
|
111 |
+
beta=beta,
|
112 |
+
alpha=(1.0 / self.norm_factor),
|
113 |
+
)
|
114 |
+
|
115 |
+
# change view to [batch_size, num_heads, q_length, k_length]
|
116 |
+
attention_scores = matmul_result.view(*output_size)
|
117 |
+
|
118 |
+
# attention scores and attention mask [b, np, sq, sk]
|
119 |
+
max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
|
120 |
+
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
|
121 |
+
attention_probs = self.attention_dropout(attention_probs)
|
122 |
+
|
123 |
+
if head_mask is not None:
|
124 |
+
attention_probs = attention_probs * head_mask
|
125 |
+
|
126 |
+
# context layer shape: [batch_size, num_heads, q_length, head_dim]
|
127 |
+
output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
128 |
+
|
129 |
+
# change view [k_length, batch_size x num_heads, head_dim]
|
130 |
+
value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
|
131 |
+
|
132 |
+
# change view [batch_size x num_heads, q_length, k_length]
|
133 |
+
attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
134 |
+
|
135 |
+
# matmul: [batch_size * num_heads, q_length, head_dim]
|
136 |
+
context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
|
137 |
+
|
138 |
+
# change view [batch_size, num_heads, q_length, head_dim]
|
139 |
+
context_layer = context_layer.view(*output_size)
|
140 |
+
|
141 |
+
# [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
|
142 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
143 |
+
|
144 |
+
# [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
|
145 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
|
146 |
+
|
147 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
148 |
+
|
149 |
+
# Output. [q_length, batch_size, hidden_size]
|
150 |
+
|
151 |
+
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
|
152 |
+
output_tensor = self.dense(context_layer)
|
153 |
+
output = output_tensor.transpose(1, 0)
|
154 |
+
|
155 |
+
output = dropout_add(output, residual, self.hidden_dropout, self.training)
|
156 |
+
|
157 |
+
outputs = (output, present)
|
158 |
+
if output_attentions:
|
159 |
+
outputs += (attention_probs,)
|
160 |
+
|
161 |
+
return outputs
|
162 |
+
|
163 |
+
|
164 |
+
class BloomMLP(nn.Module):
|
165 |
+
def __init__(self, config):
|
166 |
+
super().__init__()
|
167 |
+
self.hidden_size = config.hidden_size
|
168 |
+
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
|
169 |
+
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
|
170 |
+
self.hidden_dropout = config.hidden_dropout
|
171 |
+
self.gelu_impl = BloomGelu()
|
172 |
+
|
173 |
+
def forward(self, hidden_states, residual):
|
174 |
+
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
|
175 |
+
intermediate_output = self.dense_4h_to_h(hidden_states)
|
176 |
+
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
|
177 |
+
return output
|
178 |
+
|
179 |
+
|
180 |
+
class BloomBlock(nn.Module):
|
181 |
+
def __init__(self, config, layer_number=None):
|
182 |
+
super().__init__()
|
183 |
+
self.hidden_size = config.hidden_size
|
184 |
+
|
185 |
+
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
|
186 |
+
self.n_head = config.n_head
|
187 |
+
self.self_attention = BloomAttention(config, layer_number=layer_number)
|
188 |
+
self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
|
189 |
+
|
190 |
+
self.mlp = BloomMLP(config)
|
191 |
+
|
192 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
193 |
+
self.hidden_dropout = config.hidden_dropout
|
194 |
+
|
195 |
+
def forward(
|
196 |
+
self,
|
197 |
+
hidden_states,
|
198 |
+
layer_past=None,
|
199 |
+
attention_mask=None,
|
200 |
+
head_mask=None,
|
201 |
+
use_cache=False,
|
202 |
+
output_attentions=False,
|
203 |
+
alibi=None,
|
204 |
+
):
|
205 |
+
# hidden_states: [batch_size, seq_length, hidden_size]
|
206 |
+
|
207 |
+
# Layer norm at the beginning of the transformer layer.
|
208 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
209 |
+
|
210 |
+
# Layer norm post the self attention.
|
211 |
+
if self.apply_residual_connection_post_layernorm:
|
212 |
+
residual = layernorm_output
|
213 |
+
else:
|
214 |
+
residual = hidden_states
|
215 |
+
|
216 |
+
# Self attention.
|
217 |
+
attn_outputs = self.self_attention(
|
218 |
+
layernorm_output,
|
219 |
+
residual,
|
220 |
+
layer_past=layer_past,
|
221 |
+
attention_mask=attention_mask,
|
222 |
+
alibi=alibi,
|
223 |
+
head_mask=head_mask,
|
224 |
+
use_cache=use_cache,
|
225 |
+
output_attentions=output_attentions,
|
226 |
+
)
|
227 |
+
|
228 |
+
attention_output = attn_outputs[0]
|
229 |
+
|
230 |
+
outputs = attn_outputs[1:]
|
231 |
+
|
232 |
+
layernorm_output = self.post_attention_layernorm(attention_output)
|
233 |
+
|
234 |
+
# Get residual
|
235 |
+
if self.apply_residual_connection_post_layernorm:
|
236 |
+
residual = layernorm_output
|
237 |
+
else:
|
238 |
+
residual = attention_output
|
239 |
+
|
240 |
+
# MLP.
|
241 |
+
output = self.mlp(layernorm_output, residual)
|
242 |
+
|
243 |
+
if use_cache:
|
244 |
+
outputs = (output,) + outputs
|
245 |
+
else:
|
246 |
+
outputs = (output,) + outputs[1:]
|
247 |
+
|
248 |
+
return outputs # hidden_states, present, attentions
|
src/bloom/from_pretrained.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
|
3 |
+
If necessary, one can rewrite this to implement a different behavior, such as:
|
4 |
+
- loading files from a local data source (e.g. S3)
|
5 |
+
- load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
|
6 |
+
- fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
|
7 |
+
|
8 |
+
"""
|
9 |
+
from __future__ import annotations
|
10 |
+
|
11 |
+
from typing import Optional, OrderedDict, Union
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
15 |
+
from transformers.modeling_utils import WEIGHTS_NAME
|
16 |
+
from transformers.utils.hub import cached_path, hf_bucket_url
|
17 |
+
|
18 |
+
from src.bloom import BloomBlock, BloomConfig
|
19 |
+
|
20 |
+
use_hivemind_log_handler("in_root_logger")
|
21 |
+
logger = get_logger(__file__)
|
22 |
+
|
23 |
+
CLIENT_BRANCH = "main"
|
24 |
+
BLOCK_BRANCH_PREFIX = "block_"
|
25 |
+
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
|
26 |
+
FORCE_DOWNLOAD = False
|
27 |
+
RESUME_DOWNLOAD = False
|
28 |
+
LOCAL_FILES_ONLY = False
|
29 |
+
|
30 |
+
|
31 |
+
def load_pretrained_block(
|
32 |
+
converted_model_name_or_path: str,
|
33 |
+
block_index: int,
|
34 |
+
config: Optional[BloomConfig] = None,
|
35 |
+
torch_dtype: Union[torch.dtype, str] = "auto",
|
36 |
+
use_auth_token: Optional[str] = None,
|
37 |
+
) -> BloomBlock:
|
38 |
+
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
|
39 |
+
if config is None:
|
40 |
+
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
|
41 |
+
block = BloomBlock(config, layer_number=block_index)
|
42 |
+
state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
|
43 |
+
block.load_state_dict(state_dict)
|
44 |
+
|
45 |
+
if torch_dtype == "auto":
|
46 |
+
with torch.no_grad():
|
47 |
+
for name, param in block.named_parameters():
|
48 |
+
assert name in state_dict, f"{name} not in state dict"
|
49 |
+
param.data = param.data.to(state_dict[name].dtype)
|
50 |
+
else:
|
51 |
+
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
52 |
+
block = block.to(dtype=torch_dtype)
|
53 |
+
|
54 |
+
report = block.load_state_dict(state_dict, strict=True)
|
55 |
+
logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
|
56 |
+
return block
|
57 |
+
|
58 |
+
|
59 |
+
def _load_state_dict(
|
60 |
+
pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
|
61 |
+
) -> OrderedDict[str, torch.Tensor]:
|
62 |
+
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
|
63 |
+
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
|
64 |
+
|
65 |
+
# Load from URL or cache if already cached
|
66 |
+
resolved_archive_file = cached_path(
|
67 |
+
archive_file,
|
68 |
+
cache_dir=None,
|
69 |
+
force_download=FORCE_DOWNLOAD,
|
70 |
+
proxies=None,
|
71 |
+
resume_download=RESUME_DOWNLOAD,
|
72 |
+
local_files_only=LOCAL_FILES_ONLY,
|
73 |
+
use_auth_token=use_auth_token,
|
74 |
+
user_agent=USER_AGENT,
|
75 |
+
)
|
76 |
+
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
77 |
+
return state_dict
|
78 |
+
|
79 |
+
|
80 |
+
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
src/bloom/model.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
PyTorch BLOOM model that implements several memory-efficient modes.
|
3 |
+
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
4 |
+
See commit history for authorship.
|
5 |
+
"""
|
6 |
+
from typing import Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
from hivemind import use_hivemind_log_handler
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
14 |
+
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
|
15 |
+
add_start_docstrings_to_model_forward)
|
16 |
+
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers.models.bloom.configuration_bloom import BloomConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
from src.bloom.block import BloomBlock
|
22 |
+
|
23 |
+
use_hivemind_log_handler("in_root_logger")
|
24 |
+
logger = logging.get_logger(__file__)
|
25 |
+
|
26 |
+
_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
|
27 |
+
_CONFIG_FOR_DOC = "BloomConfig"
|
28 |
+
_TOKENIZER_FOR_DOC = "BloomTokenizer"
|
29 |
+
|
30 |
+
|
31 |
+
class BloomPreTrainedModel(PreTrainedModel):
|
32 |
+
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
33 |
+
"""
|
34 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
35 |
+
models.
|
36 |
+
"""
|
37 |
+
|
38 |
+
config_class = BloomConfig
|
39 |
+
base_model_prefix = "transformer"
|
40 |
+
supports_gradient_checkpointing = True
|
41 |
+
_no_split_modules = ["BloomBlock"]
|
42 |
+
|
43 |
+
def __init__(self, *inputs, **kwargs):
|
44 |
+
super().__init__(*inputs, **kwargs)
|
45 |
+
|
46 |
+
def _init_weights(self, module):
|
47 |
+
"""Initialize the weights."""
|
48 |
+
if isinstance(module, (nn.Linear)):
|
49 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
50 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
51 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
52 |
+
if module.bias is not None:
|
53 |
+
module.bias.data.zero_()
|
54 |
+
elif isinstance(module, nn.Embedding):
|
55 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
56 |
+
if module.padding_idx is not None:
|
57 |
+
module.weight.data[module.padding_idx].zero_()
|
58 |
+
elif isinstance(module, LayerNorm):
|
59 |
+
module.bias.data.zero_()
|
60 |
+
module.weight.data.fill_(1.0)
|
61 |
+
|
62 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
63 |
+
if isinstance(module, BloomModel):
|
64 |
+
module.gradient_checkpointing = value
|
65 |
+
|
66 |
+
|
67 |
+
BLOOM_START_DOCSTRING = r"""
|
68 |
+
|
69 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
70 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
|
71 |
+
|
72 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
73 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
74 |
+
and behavior.
|
75 |
+
|
76 |
+
Parameters:
|
77 |
+
config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
|
78 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
79 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
80 |
+
"""
|
81 |
+
|
82 |
+
BLOOM_INPUTS_DOCSTRING = r"""
|
83 |
+
Args:
|
84 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
85 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
86 |
+
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
87 |
+
sequence tokens in the vocabulary.
|
88 |
+
|
89 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
90 |
+
`input_ids`.
|
91 |
+
|
92 |
+
Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
93 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
94 |
+
|
95 |
+
[What are input IDs?](../glossary#input-ids)
|
96 |
+
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
97 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
98 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
99 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
100 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
101 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
102 |
+
|
103 |
+
- 1 for tokens that are **not masked**,
|
104 |
+
- 0 for tokens that are **masked**.
|
105 |
+
|
106 |
+
[What are attention masks?](../glossary#attention-mask)
|
107 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
108 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
109 |
+
config.max_position_embeddings - 1]`.
|
110 |
+
|
111 |
+
[What are position IDs?](../glossary#position-ids)
|
112 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
113 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
114 |
+
|
115 |
+
- 1 indicates the head is **not masked**,
|
116 |
+
- 0 indicates the head is **masked**.
|
117 |
+
|
118 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
119 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
120 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
121 |
+
model's internal embedding lookup matrix.
|
122 |
+
|
123 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
124 |
+
`past_key_values`).
|
125 |
+
use_cache (`bool`, *optional*):
|
126 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
127 |
+
`past_key_values`).
|
128 |
+
output_attentions (`bool`, *optional*):
|
129 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
130 |
+
tensors for more detail.
|
131 |
+
output_hidden_states (`bool`, *optional*):
|
132 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
133 |
+
more detail.
|
134 |
+
return_dict (`bool`, *optional*):
|
135 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
136 |
+
"""
|
137 |
+
|
138 |
+
|
139 |
+
@add_start_docstrings(
|
140 |
+
"The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
|
141 |
+
BLOOM_START_DOCSTRING,
|
142 |
+
)
|
143 |
+
class BloomModel(BloomPreTrainedModel):
|
144 |
+
def __init__(self, config):
|
145 |
+
super().__init__(config)
|
146 |
+
assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
|
147 |
+
|
148 |
+
self.embed_dim = config.hidden_size
|
149 |
+
self.n_head = config.n_head
|
150 |
+
|
151 |
+
# Embedding + LN Embedding
|
152 |
+
|
153 |
+
# TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
|
154 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
|
155 |
+
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
156 |
+
|
157 |
+
# Transformer blocks
|
158 |
+
self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
|
159 |
+
|
160 |
+
# Final Layer Norm
|
161 |
+
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
162 |
+
|
163 |
+
self.gradient_checkpointing = False
|
164 |
+
|
165 |
+
# Initialize weights and apply final processing
|
166 |
+
self.post_init()
|
167 |
+
|
168 |
+
# Forbid accumulate grads for embeddings and layernorm
|
169 |
+
self.set_requires_grad(False)
|
170 |
+
|
171 |
+
def get_input_embeddings(self):
|
172 |
+
return self.word_embeddings
|
173 |
+
|
174 |
+
def set_input_embeddings(self, new_embeddings):
|
175 |
+
self.word_embeddings = new_embeddings
|
176 |
+
|
177 |
+
def set_requires_grad(self, value):
|
178 |
+
for p in self.parameters():
|
179 |
+
p.requires_grad = value
|
180 |
+
|
181 |
+
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
|
182 |
+
@add_code_sample_docstrings(
|
183 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
184 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
185 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
186 |
+
config_class=_CONFIG_FOR_DOC,
|
187 |
+
)
|
188 |
+
def forward(
|
189 |
+
self,
|
190 |
+
input_ids=None,
|
191 |
+
past_key_values=None,
|
192 |
+
attention_mask=None,
|
193 |
+
position_ids=None,
|
194 |
+
head_mask=None,
|
195 |
+
inputs_embeds=None,
|
196 |
+
use_cache=None,
|
197 |
+
output_attentions=None,
|
198 |
+
output_hidden_states=None,
|
199 |
+
return_dict=None,
|
200 |
+
):
|
201 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
202 |
+
output_hidden_states = (
|
203 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
204 |
+
)
|
205 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
206 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
207 |
+
|
208 |
+
if input_ids is not None and inputs_embeds is not None:
|
209 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
210 |
+
if position_ids is not None:
|
211 |
+
logger.warning("position_ids are ignored in this bloom implementation")
|
212 |
+
elif input_ids is not None:
|
213 |
+
input_shape = input_ids.size()
|
214 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
215 |
+
elif inputs_embeds is not None:
|
216 |
+
input_shape = inputs_embeds.size()[:-1]
|
217 |
+
else:
|
218 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
219 |
+
|
220 |
+
if past_key_values is None:
|
221 |
+
past_key_values = tuple([None] * len(self.h))
|
222 |
+
|
223 |
+
# Prepare head mask if needed
|
224 |
+
# 1.0 in head_mask indicate we keep the head
|
225 |
+
# attention_probs has shape bsz x n_head x N x N
|
226 |
+
# head_mask has shape n_layer x batch x n_head x N x N
|
227 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
228 |
+
|
229 |
+
if inputs_embeds is None:
|
230 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
231 |
+
|
232 |
+
hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
|
233 |
+
|
234 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
235 |
+
|
236 |
+
presents = () if use_cache else None
|
237 |
+
all_self_attentions = () if output_attentions else None
|
238 |
+
all_hidden_states = () if output_hidden_states else None
|
239 |
+
|
240 |
+
# Compute alibi tensor: check build_alibi_tensor documentation
|
241 |
+
current_sequence_length = hidden_states.shape[1]
|
242 |
+
if past_key_values and past_key_values[0]:
|
243 |
+
current_sequence_length += past_key_values[0][0].shape[1]
|
244 |
+
|
245 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
246 |
+
|
247 |
+
if output_hidden_states:
|
248 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
249 |
+
|
250 |
+
if self.gradient_checkpointing and self.training:
|
251 |
+
|
252 |
+
if use_cache:
|
253 |
+
logger.warning(
|
254 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
255 |
+
)
|
256 |
+
use_cache = False
|
257 |
+
|
258 |
+
def create_custom_forward(module):
|
259 |
+
def custom_forward(*inputs):
|
260 |
+
# None for past_key_value
|
261 |
+
return module(*inputs, use_cache, output_attentions, alibi=None)
|
262 |
+
|
263 |
+
return custom_forward
|
264 |
+
|
265 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
266 |
+
create_custom_forward(block),
|
267 |
+
hidden_states,
|
268 |
+
None,
|
269 |
+
attention_mask,
|
270 |
+
head_mask[i],
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
outputs = block(
|
274 |
+
hidden_states,
|
275 |
+
layer_past=layer_past,
|
276 |
+
attention_mask=attention_mask,
|
277 |
+
head_mask=head_mask[i],
|
278 |
+
use_cache=use_cache,
|
279 |
+
output_attentions=output_attentions,
|
280 |
+
alibi=None,
|
281 |
+
)
|
282 |
+
|
283 |
+
hidden_states = outputs[0]
|
284 |
+
if use_cache is True:
|
285 |
+
presents = presents + (outputs[1],)
|
286 |
+
|
287 |
+
if output_attentions:
|
288 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
289 |
+
|
290 |
+
# Add last hidden state
|
291 |
+
hidden_states = self.ln_f(hidden_states)
|
292 |
+
|
293 |
+
if output_hidden_states:
|
294 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
295 |
+
|
296 |
+
hidden_states = hidden_states.view(output_shape)
|
297 |
+
|
298 |
+
if not return_dict:
|
299 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
300 |
+
|
301 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
302 |
+
last_hidden_state=hidden_states,
|
303 |
+
past_key_values=presents,
|
304 |
+
hidden_states=all_hidden_states,
|
305 |
+
attentions=all_self_attentions,
|
306 |
+
)
|
307 |
+
|
308 |
+
|
309 |
+
@add_start_docstrings(
|
310 |
+
"""
|
311 |
+
The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
312 |
+
embeddings).
|
313 |
+
""",
|
314 |
+
BLOOM_START_DOCSTRING,
|
315 |
+
)
|
316 |
+
class BloomForCausalLM(BloomPreTrainedModel):
|
317 |
+
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
318 |
+
|
319 |
+
def __init__(self, config):
|
320 |
+
super().__init__(config)
|
321 |
+
self.transformer = BloomModel(config)
|
322 |
+
# Initialize weights and apply final processing
|
323 |
+
self.post_init()
|
324 |
+
|
325 |
+
def get_output_embeddings(self):
|
326 |
+
return self.transformer.word_embeddings
|
327 |
+
|
328 |
+
def set_output_embeddings(self, new_embeddings):
|
329 |
+
self.transformer.word_embeddings.weight = new_embeddings.weight
|
330 |
+
|
331 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
332 |
+
# only last token for inputs_ids if past is defined in kwargs
|
333 |
+
if past:
|
334 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
335 |
+
|
336 |
+
attention_mask = kwargs.get("attention_mask", None)
|
337 |
+
position_ids = kwargs.get("position_ids", None)
|
338 |
+
|
339 |
+
if attention_mask is not None and position_ids is None:
|
340 |
+
# create position_ids on the fly for batch generation
|
341 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
342 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
343 |
+
if past:
|
344 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
345 |
+
else:
|
346 |
+
position_ids = None
|
347 |
+
return {
|
348 |
+
"input_ids": input_ids,
|
349 |
+
"past_key_values": past,
|
350 |
+
"use_cache": kwargs.get("use_cache"),
|
351 |
+
"position_ids": position_ids,
|
352 |
+
"attention_mask": attention_mask,
|
353 |
+
}
|
354 |
+
|
355 |
+
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
|
356 |
+
@add_code_sample_docstrings(
|
357 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
358 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
359 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
360 |
+
config_class=_CONFIG_FOR_DOC,
|
361 |
+
)
|
362 |
+
def forward(self, input_ids=None, labels=None, return_dict=None, **kwargs):
|
363 |
+
r"""
|
364 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
365 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
366 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
367 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
368 |
+
"""
|
369 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
370 |
+
transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
|
371 |
+
word_embeddings = self.transformer.word_embeddings.weight
|
372 |
+
|
373 |
+
# Switch dtype in case word_embeddings are fp16/bf16
|
374 |
+
hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
|
375 |
+
lm_logits = F.linear(hidden_states, word_embeddings).float()
|
376 |
+
|
377 |
+
loss = None
|
378 |
+
if labels is not None:
|
379 |
+
# Shift so that tokens < n predict n
|
380 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
381 |
+
shift_labels = labels[..., 1:].contiguous()
|
382 |
+
# Flatten the tokens
|
383 |
+
loss_fct = CrossEntropyLoss()
|
384 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
385 |
+
|
386 |
+
if not return_dict:
|
387 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
388 |
+
return ((loss,) + output) if loss is not None else output
|
389 |
+
|
390 |
+
return CausalLMOutputWithCrossAttentions(
|
391 |
+
loss=loss,
|
392 |
+
logits=lm_logits,
|
393 |
+
past_key_values=transformer_outputs.past_key_values,
|
394 |
+
hidden_states=transformer_outputs.hidden_states,
|
395 |
+
attentions=transformer_outputs.attentions,
|
396 |
+
)
|
397 |
+
|
398 |
+
@staticmethod
|
399 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
400 |
+
"""
|
401 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
402 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
403 |
+
beam_idx at every generation step.
|
404 |
+
"""
|
405 |
+
return tuple(
|
406 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
407 |
+
for layer_past in past
|
408 |
+
)
|
src/bloom/ops.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utility operations used in the the BLOOM model
|
3 |
+
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
4 |
+
See commit history for authorship.
|
5 |
+
"""
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.autograd
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
|
15 |
+
"""Split a tensor along its last dimension.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
tensor: ([`torch.tensor`], *required*):
|
19 |
+
input tensor to split
|
20 |
+
num_partitions ([`int`], *required*):
|
21 |
+
number of partitions to split the tensor
|
22 |
+
contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
|
23 |
+
If True, make each chunk contiguous in memory.
|
24 |
+
"""
|
25 |
+
# Get the size and dimension.
|
26 |
+
last_dim = tensor.dim() - 1
|
27 |
+
numerator, denominator = tensor.size()[last_dim], num_partitions
|
28 |
+
if not (numerator % denominator == 0):
|
29 |
+
raise ValueError(f"{numerator} is not divisible by {denominator}")
|
30 |
+
last_dim_size = numerator // denominator
|
31 |
+
# Split.
|
32 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
33 |
+
# Note: torch.split does not create contiguous tensors by default.
|
34 |
+
if contiguous_split_chunks:
|
35 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
36 |
+
|
37 |
+
return tensor_list
|
38 |
+
|
39 |
+
|
40 |
+
def attention_mask_func(attention_scores, attention_mask, causal_mask):
|
41 |
+
if attention_mask.dtype == torch.bool:
|
42 |
+
attention_mask_bool = ~attention_mask
|
43 |
+
else:
|
44 |
+
attention_mask_bool = (1 - attention_mask).bool()
|
45 |
+
|
46 |
+
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
|
47 |
+
padded_causal_mask = (
|
48 |
+
attention_mask_bool[:, None, key_length - query_length : key_length, None]
|
49 |
+
+ ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
50 |
+
).bool()
|
51 |
+
padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
|
52 |
+
# Make use of floats
|
53 |
+
return (
|
54 |
+
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
|
55 |
+
padded_causal_mask,
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
def build_alibi_tensor(
|
60 |
+
max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
|
61 |
+
) -> torch.Tensor:
|
62 |
+
"""
|
63 |
+
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
64 |
+
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
65 |
+
`softmax(l+a) = softmax(l)`. Based on
|
66 |
+
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
|
67 |
+
Args:
|
68 |
+
Returns tensor shaped (n_head, 1, max_seq_len)
|
69 |
+
max_seq_len: (`int`, *required*):
|
70 |
+
max sequence length
|
71 |
+
n_head: (`int`, *required*):
|
72 |
+
number of heads
|
73 |
+
dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
|
74 |
+
dtype of the output tensor
|
75 |
+
device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
|
76 |
+
device of the output alibi tensor
|
77 |
+
"""
|
78 |
+
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
|
79 |
+
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
|
80 |
+
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
|
81 |
+
slopes = torch.pow(base, powers)
|
82 |
+
|
83 |
+
if closest_power_of_2 != n_head:
|
84 |
+
extra_base = torch.tensor(
|
85 |
+
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
|
86 |
+
)
|
87 |
+
num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
|
88 |
+
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
|
89 |
+
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
90 |
+
|
91 |
+
lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
|
92 |
+
return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
|
93 |
+
|
94 |
+
|
95 |
+
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
|
96 |
+
"""
|
97 |
+
Args:
|
98 |
+
Pre-process the alibi tensor for padding.
|
99 |
+
alibi: ([`torch.tensor`], *required*):
|
100 |
+
alibi tensor to pre-process
|
101 |
+
attention_mask: ([`torch.tensor`], *required*):
|
102 |
+
attention mask to pre-process
|
103 |
+
"""
|
104 |
+
assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
|
105 |
+
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
|
106 |
+
# ^-- [batch, max_len], values correspond to element indices after removing padding
|
107 |
+
# We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
|
108 |
+
alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
|
109 |
+
return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
|
110 |
+
|
111 |
+
|
112 |
+
def dropout_add(x, residual, prob, training):
|
113 |
+
"""
|
114 |
+
Dropout add function
|
115 |
+
|
116 |
+
Args:
|
117 |
+
x (`torch.tensor`, *required*):
|
118 |
+
input tensor
|
119 |
+
residual (`torch.tensor`, *rquired*):
|
120 |
+
esidual tensor
|
121 |
+
prob (`float`, *required*):
|
122 |
+
dropout probability
|
123 |
+
training (`bool`, *required*):
|
124 |
+
training mode
|
125 |
+
"""
|
126 |
+
out = nn.functional.dropout(x, p=prob, training=training)
|
127 |
+
out = residual + out
|
128 |
+
return out
|
129 |
+
|
130 |
+
|
131 |
+
def bloom_gelu_forward(x):
|
132 |
+
"""
|
133 |
+
Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
|
134 |
+
make the model jitable.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
x (`torch.tensor`, *required*):
|
138 |
+
input hidden states
|
139 |
+
"""
|
140 |
+
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
141 |
+
|
142 |
+
|
143 |
+
def bloom_gelu_back(g, x):
|
144 |
+
"""
|
145 |
+
gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
|
146 |
+
0.3989423 * x * torch.exp(-0.5 * x * x)
|
147 |
+
|
148 |
+
Args:
|
149 |
+
g (`torch.tensor`, *required*):
|
150 |
+
gradient output tensor
|
151 |
+
x (`torch.tensor`, *required*):
|
152 |
+
input tensor
|
153 |
+
"""
|
154 |
+
x = x[0] # x is a tuple of 1 element, needs to unpack it first
|
155 |
+
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
156 |
+
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
157 |
+
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
158 |
+
return ff * g
|
159 |
+
|
160 |
+
|
161 |
+
class GeLUFunction(torch.autograd.Function):
|
162 |
+
@staticmethod
|
163 |
+
def forward(ctx, input):
|
164 |
+
ctx.save_for_backward(input)
|
165 |
+
return bloom_gelu_forward(input)
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def backward(ctx, grad_output):
|
169 |
+
input = ctx.saved_tensors
|
170 |
+
tmp = bloom_gelu_back(grad_output, input)
|
171 |
+
return tmp
|
172 |
+
|
173 |
+
|
174 |
+
class BloomGelu(nn.Module):
|
175 |
+
"""
|
176 |
+
BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
|
177 |
+
torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
|
178 |
+
copied from Megatron-DeepSpeed code and adapted for our needs
|
179 |
+
|
180 |
+
See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
|
181 |
+
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(self):
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
if self.training:
|
189 |
+
return GeLUFunction.apply(x)
|
190 |
+
else:
|
191 |
+
return bloom_gelu_forward(x)
|
192 |
+
|
193 |
+
|
194 |
+
class BloomScaledSoftmax(nn.Module):
|
195 |
+
"""
|
196 |
+
fused operation: scaling + mask + softmax
|
197 |
+
|
198 |
+
Args:
|
199 |
+
input_in_fp16 (`bool`, *required*):
|
200 |
+
flag to indicate if input in fp16 data format.
|
201 |
+
input_in_bf16 (`bool`, *required*):
|
202 |
+
flag to indicate if input in bf16 data format.
|
203 |
+
scaled_masked_softmax_fusion (`bool`, *required*):
|
204 |
+
flag to indicate user want to use softmax fusion
|
205 |
+
mask_func (`function`, *required*):
|
206 |
+
mask function to be applied.
|
207 |
+
softmax_in_fp32 (`bool`, *required*):
|
208 |
+
if true, softmax in performed at fp32 precision.
|
209 |
+
scale (`float`, *required*):
|
210 |
+
scaling factor used in input tensor scaling.
|
211 |
+
"""
|
212 |
+
|
213 |
+
def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
|
214 |
+
super().__init__()
|
215 |
+
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
|
216 |
+
self.mask_func = mask_func
|
217 |
+
self.softmax_in_fp32 = softmax_in_fp32
|
218 |
+
self.scale = scale
|
219 |
+
|
220 |
+
if not (self.scale is None or softmax_in_fp32):
|
221 |
+
raise ValueError("softmax should be in fp32 when scaled")
|
222 |
+
|
223 |
+
def forward(self, input, mask, max_positions):
|
224 |
+
input_dtype = input.dtype
|
225 |
+
input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
|
226 |
+
softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
|
227 |
+
|
228 |
+
if self.scale is not None:
|
229 |
+
input = input * self.scale
|
230 |
+
|
231 |
+
if mask is None:
|
232 |
+
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
|
233 |
+
|
234 |
+
mask = mask.to(input.device)
|
235 |
+
causal_mask = (
|
236 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
|
237 |
+
.view(1, 1, max_positions, max_positions)
|
238 |
+
.to(input.device)
|
239 |
+
)
|
240 |
+
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
|
241 |
+
probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
|
242 |
+
|
243 |
+
if input_in_16bit and self.softmax_in_fp32:
|
244 |
+
probs = probs.to(dtype=input_dtype)
|
245 |
+
|
246 |
+
return probs
|
src/client/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
|
2 |
+
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
|
3 |
+
from src.client.remote_sequence_info import RemoteSequenceInfo
|
4 |
+
from src.client.remote_sequential import RemoteSequential
|
src/client/remote_block.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
|
2 |
+
from __future__ import annotations
|
3 |
+
|
4 |
+
import asyncio
|
5 |
+
import random
|
6 |
+
from typing import Any, AsyncIterator, Dict, Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
|
10 |
+
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
|
11 |
+
from hivemind.moe.expert_uid import ExpertInfo
|
12 |
+
from hivemind.p2p import P2P, StubBase
|
13 |
+
from hivemind.proto import runtime_pb2
|
14 |
+
from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
|
15 |
+
|
16 |
+
from src.data_structures import RemoteModuleInfo
|
17 |
+
from src.dht_utils import ModuleUID
|
18 |
+
from src.server.handler import TransformerConnectionHandler
|
19 |
+
|
20 |
+
use_hivemind_log_handler("in_root_logger")
|
21 |
+
logger = get_logger(__file__)
|
22 |
+
|
23 |
+
|
24 |
+
class RemoteTransformerBlock(RemoteExpert):
|
25 |
+
"""A class that interacts with a remote module on a specific server for forward/backward or inference"""
|
26 |
+
|
27 |
+
def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
|
28 |
+
peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids))) # TODO replace this
|
29 |
+
super().__init__(peer_info, p2p)
|
30 |
+
|
31 |
+
@property
|
32 |
+
def stub(self) -> StubBase:
|
33 |
+
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
|
34 |
+
|
35 |
+
def forward(self, inputs: torch.Tensor, **kwargs):
|
36 |
+
for k, v in kwargs.items():
|
37 |
+
assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
|
38 |
+
return super().forward(inputs)
|
39 |
+
|
40 |
+
def inference_session(self) -> RemoteTransformerBlockInferenceSession:
|
41 |
+
"""Initialize a new inference session with the specified remote server"""
|
42 |
+
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
|
43 |
+
return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
|
44 |
+
|
45 |
+
def begin_inference_session(self):
|
46 |
+
logger.warning("beging_inference_session was renamed to just inference_session")
|
47 |
+
return self.inference_session()
|
48 |
+
|
49 |
+
|
50 |
+
class RemoteTransformerBlockInferenceSession:
|
51 |
+
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
|
52 |
+
|
53 |
+
def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
|
54 |
+
self.uid, self.info = uid, info
|
55 |
+
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
|
56 |
+
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
|
57 |
+
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
58 |
+
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
59 |
+
self.stepped = False
|
60 |
+
self.closed = False
|
61 |
+
|
62 |
+
@classmethod
|
63 |
+
async def _create(
|
64 |
+
cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
|
65 |
+
) -> RemoteTransformerBlockInferenceSession:
|
66 |
+
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
|
67 |
+
inputs_queue = asyncio.Queue()
|
68 |
+
outputs_stream = await remote_module.stub.rpc_inference(
|
69 |
+
cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
|
70 |
+
)
|
71 |
+
return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
|
75 |
+
while True:
|
76 |
+
next_input_message = await asyncio.wait_for(queue.get(), timeout)
|
77 |
+
yield next_input_message
|
78 |
+
if not next_input_message.uid and not next_input_message.tensors:
|
79 |
+
break # this message means "done sending"
|
80 |
+
|
81 |
+
def step(self, new_hidden_states: torch.Tensor):
|
82 |
+
"""Inference step: send a chunk of input tensors and receive a chunk of outputs"""
|
83 |
+
if self.closed:
|
84 |
+
raise Exception("Session is closed, cannot perform step")
|
85 |
+
# serialize inputs and put them into the queue
|
86 |
+
inputs = (new_hidden_states,)
|
87 |
+
outputs_serialized = RemoteExpertWorker.run_coroutine(
|
88 |
+
self._step(
|
89 |
+
runtime_pb2.ExpertRequest(
|
90 |
+
uid=self.uid,
|
91 |
+
tensors=[
|
92 |
+
serialize_torch_tensor(tensor, proto.compression)
|
93 |
+
for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
|
94 |
+
],
|
95 |
+
)
|
96 |
+
)
|
97 |
+
)
|
98 |
+
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
|
99 |
+
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
|
100 |
+
return outputs[0]
|
101 |
+
|
102 |
+
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
|
103 |
+
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|
104 |
+
await self._inputs_queue.put(inputs_serialized)
|
105 |
+
self.stepped = True
|
106 |
+
return await anext(self._outputs_stream)
|
107 |
+
|
108 |
+
def close(self):
|
109 |
+
"""Finish a given inference session, close the underlying connection"""
|
110 |
+
if self._outputs_stream is None:
|
111 |
+
return # already closed
|
112 |
+
RemoteExpertWorker.run_coroutine(self._aclose_stream())
|
113 |
+
self._outputs_stream = self._inputs_queue = None
|
114 |
+
self.closed = True
|
115 |
+
|
116 |
+
async def _aclose_stream(self):
|
117 |
+
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
|
118 |
+
if self._outputs_stream is None:
|
119 |
+
return # already closed
|
120 |
+
if self.stepped:
|
121 |
+
await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
|
122 |
+
try:
|
123 |
+
await anext(self._outputs_stream)
|
124 |
+
except StopAsyncIteration:
|
125 |
+
pass
|
126 |
+
|
127 |
+
def __del__(self):
|
128 |
+
self.close()
|
129 |
+
|
130 |
+
def __enter__(self):
|
131 |
+
assert not self.closed
|
132 |
+
return self
|
133 |
+
|
134 |
+
def __exit__(self, *exc_details):
|
135 |
+
self.close()
|
src/client/remote_model.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this code is in active development, interfaces may change
|
2 |
+
import os
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
|
5 |
+
import hivemind
|
6 |
+
from hivemind import DHT, get_logger, use_hivemind_log_handler
|
7 |
+
|
8 |
+
from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
|
9 |
+
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
|
10 |
+
from src.client.remote_sequential import RemoteSequential
|
11 |
+
from src.data_structures import UID_DELIMITER
|
12 |
+
|
13 |
+
use_hivemind_log_handler("in_root_logger")
|
14 |
+
logger = get_logger(__file__)
|
15 |
+
|
16 |
+
|
17 |
+
class DistributedBloomConfig(BloomConfig):
|
18 |
+
"""
|
19 |
+
A bloom config that contains information about DHT peers.
|
20 |
+
To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
|
21 |
+
"""
|
22 |
+
|
23 |
+
initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
|
24 |
+
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
|
25 |
+
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
|
26 |
+
|
27 |
+
|
28 |
+
class DistributedBloomModel(BloomModel):
|
29 |
+
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
30 |
+
config_class = DistributedBloomConfig
|
31 |
+
|
32 |
+
def __init__(self, config: DistributedBloomConfig):
|
33 |
+
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
34 |
+
assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
|
35 |
+
|
36 |
+
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
|
37 |
+
super().__init__(config)
|
38 |
+
assert len(self.h) == 0
|
39 |
+
config.n_layer = n_layer
|
40 |
+
|
41 |
+
dht = (
|
42 |
+
config.dht
|
43 |
+
if config.dht is not None
|
44 |
+
else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
|
45 |
+
)
|
46 |
+
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
|
47 |
+
self.h = RemoteSequential(config, dht, config.dht_prefix)
|
48 |
+
|
49 |
+
|
50 |
+
class DistributedBloomForCausalLM(BloomForCausalLM):
|
51 |
+
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
52 |
+
config_class = DistributedBloomConfig
|
53 |
+
|
54 |
+
def __init__(self, config: DistributedBloomConfig):
|
55 |
+
BloomPreTrainedModel.__init__(self, config)
|
56 |
+
self.transformer = DistributedBloomModel(config)
|
57 |
+
# Initialize weights and apply final processing
|
58 |
+
self.post_init()
|
src/client/remote_sequence_info.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import dataclasses
|
4 |
+
import threading
|
5 |
+
from functools import partial
|
6 |
+
from typing import List, NamedTuple, Optional, Sequence, Tuple
|
7 |
+
|
8 |
+
from hivemind import DHT, PeerID
|
9 |
+
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
10 |
+
|
11 |
+
from src.data_structures import ModuleUID, RemoteModuleInfo
|
12 |
+
from src.dht_utils import _get_remote_module_infos
|
13 |
+
|
14 |
+
use_hivemind_log_handler("in_root_logger")
|
15 |
+
logger = get_logger(__file__)
|
16 |
+
|
17 |
+
|
18 |
+
Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
|
19 |
+
|
20 |
+
|
21 |
+
@dataclasses.dataclass(frozen=False, init=False) # TODO[borzunov@] eto ne dataclass
|
22 |
+
class RemoteSequenceInfo:
|
23 |
+
"""Keeps and updates the meta-information about which peers host which blocks"""
|
24 |
+
|
25 |
+
dht: DHT
|
26 |
+
block_uids: List[ModuleUID, ...]
|
27 |
+
block_infos: List[Optional[RemoteModuleInfo], ...]
|
28 |
+
spans_by_priority: List[Span] # sorted from best to worst
|
29 |
+
spans_containing_block: Tuple[List[Span], ...]
|
30 |
+
lock_changes: threading.Lock
|
31 |
+
|
32 |
+
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
|
33 |
+
self.dht = dht
|
34 |
+
self.block_uids = list(block_uids)
|
35 |
+
self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
|
36 |
+
self.spans_by_priority = []
|
37 |
+
self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
|
38 |
+
self.lock_changes = threading.Lock()
|
39 |
+
self.update_()
|
40 |
+
|
41 |
+
for uid, info in zip(self.block_uids, self.block_infos):
|
42 |
+
assert info is not None, f"Found no remote peers for block {uid}"
|
43 |
+
assert self.spans_by_priority and self.spans_containing_block
|
44 |
+
|
45 |
+
def update_(self):
|
46 |
+
with self.lock_changes:
|
47 |
+
self.update_block_infos_()
|
48 |
+
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
|
49 |
+
|
50 |
+
def update_block_infos_(self):
|
51 |
+
new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
|
52 |
+
partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
|
53 |
+
)
|
54 |
+
assert len(new_block_infos) == len(self.block_uids)
|
55 |
+
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
|
56 |
+
if info is None:
|
57 |
+
logger.warning(f"Found no block info for block {uid}")
|
58 |
+
if not isinstance(info, RemoteModuleInfo):
|
59 |
+
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
|
60 |
+
if not info.peer_ids:
|
61 |
+
logger.warning(f"Found no active peers for block {uid}")
|
62 |
+
if info.uid != uid:
|
63 |
+
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
|
64 |
+
if not isinstance(info.peer_ids, set):
|
65 |
+
logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
|
66 |
+
self.block_infos[block_index] = info
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
|
70 |
+
closed_spans = []
|
71 |
+
active_spans = {}
|
72 |
+
for block_index, info in enumerate(block_infos):
|
73 |
+
for peer_id in info.peer_ids:
|
74 |
+
if peer_id not in active_spans:
|
75 |
+
active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
|
76 |
+
else: # peer_id in active_spans
|
77 |
+
active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
|
78 |
+
|
79 |
+
for peer_id in list(active_spans.keys()):
|
80 |
+
if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
|
81 |
+
closed_spans.append(active_spans.pop(peer_id))
|
82 |
+
assert not active_spans
|
83 |
+
|
84 |
+
closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
|
85 |
+
|
86 |
+
spans_containing_block = tuple(list() for _ in range(len(block_infos)))
|
87 |
+
for span in closed_spans:
|
88 |
+
for block_index in range(span.start, span.end):
|
89 |
+
spans_containing_block[block_index].append(span)
|
90 |
+
|
91 |
+
return closed_spans, spans_containing_block
|
92 |
+
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.block_uids)
|
src/client/remote_sequential.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
import logging
|
5 |
+
import random
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
|
9 |
+
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
10 |
+
from hivemind.moe.expert_uid import ExpertInfo
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
import src
|
14 |
+
from src.client.remote_block import RemoteTransformerBlock
|
15 |
+
from src.client.remote_sequence_info import RemoteSequenceInfo
|
16 |
+
from src.data_structures import UID_DELIMITER
|
17 |
+
from src.dht_utils import _create_remote_modules_from_infos
|
18 |
+
|
19 |
+
use_hivemind_log_handler("in_root_logger")
|
20 |
+
logger = get_logger(__file__)
|
21 |
+
|
22 |
+
|
23 |
+
class RemoteSequential(nn.Module):
|
24 |
+
"""
|
25 |
+
A sequence of transformer blocks hosted by the swarm.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
|
29 |
+
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
|
30 |
+
if prefix.endswith(UID_DELIMITER):
|
31 |
+
logger.warning(
|
32 |
+
f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
|
33 |
+
f"This will cause {self.__class__.__name__} to look for modules under "
|
34 |
+
f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
|
35 |
+
)
|
36 |
+
|
37 |
+
super().__init__()
|
38 |
+
self.config = config
|
39 |
+
self.dht = dht
|
40 |
+
self.prefix = prefix
|
41 |
+
self.max_retries = max_retries
|
42 |
+
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
|
43 |
+
|
44 |
+
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
|
45 |
+
|
46 |
+
logger.debug(f"Remote block uids: {block_uids}")
|
47 |
+
self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
|
48 |
+
|
49 |
+
def forward(self, inputs: torch.Tensor):
|
50 |
+
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
51 |
+
for block_index in range(self.config.n_layer):
|
52 |
+
for retry_index in range(self.max_retries):
|
53 |
+
try:
|
54 |
+
block = self[block_index]
|
55 |
+
(outputs,) = block(inputs)
|
56 |
+
assert isinstance(outputs, torch.Tensor)
|
57 |
+
assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
|
58 |
+
inputs = outputs
|
59 |
+
break
|
60 |
+
except Exception as e:
|
61 |
+
if retry_index == self.max_retries - 1:
|
62 |
+
raise e
|
63 |
+
else:
|
64 |
+
logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
|
65 |
+
return inputs
|
66 |
+
|
67 |
+
def __getitem__(self, block_index: int):
|
68 |
+
assert 0 <= block_index < self.config.n_layer
|
69 |
+
(module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
|
70 |
+
return module
|
71 |
+
|
72 |
+
def __iter__(self):
|
73 |
+
for block_index in range(self.config.n_layer):
|
74 |
+
yield self[block_index]
|
75 |
+
|
76 |
+
def __len__(self):
|
77 |
+
return len(self.remote_sequence_info)
|
78 |
+
|
79 |
+
def inference_session(self) -> RemoteSequentialInferenceSession:
|
80 |
+
self.remote_sequence_info.update_()
|
81 |
+
return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
|
82 |
+
|
83 |
+
|
84 |
+
class RemoteSequentialInferenceSession:
|
85 |
+
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
86 |
+
|
87 |
+
def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
|
88 |
+
self.remote_sequence_info = remote_sequence_info
|
89 |
+
self.p2p = p2p
|
90 |
+
self.closed = False
|
91 |
+
self.stack = contextlib.ExitStack()
|
92 |
+
self.active_sessions = []
|
93 |
+
|
94 |
+
def __enter__(self):
|
95 |
+
assert not self.closed
|
96 |
+
self.stack.__enter__()
|
97 |
+
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
|
98 |
+
current_block = 0
|
99 |
+
while current_block != len(self.remote_sequence_info):
|
100 |
+
candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
|
101 |
+
chosen_span = random.choice(candidate_spans) # TODO this is a temporary code
|
102 |
+
assert chosen_span.start <= current_block < chosen_span.end
|
103 |
+
|
104 |
+
# TODO begin throwaway prototype code
|
105 |
+
remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
|
106 |
+
_ = remote.info # TODO fix
|
107 |
+
span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
|
108 |
+
remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
|
109 |
+
self.active_sessions.append(remote.inference_session())
|
110 |
+
self.stack.enter_context(self.active_sessions[-1])
|
111 |
+
current_block = chosen_span.end
|
112 |
+
# TODO end throwaway prototype code
|
113 |
+
|
114 |
+
return self
|
115 |
+
|
116 |
+
def step(self, inputs: torch.Tensor):
|
117 |
+
assert not self.closed
|
118 |
+
for session in self.active_sessions:
|
119 |
+
outputs = session.step(inputs)
|
120 |
+
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
|
121 |
+
inputs = outputs
|
122 |
+
return inputs
|
123 |
+
|
124 |
+
def close(self, *exc_details):
|
125 |
+
"""Finish a given inference session, close the underlying connection"""
|
126 |
+
if not self.closed:
|
127 |
+
self.stack.__exit__(*exc_details or (None, None, None))
|
128 |
+
self.active_sessions.clear()
|
129 |
+
self.closed = True
|
130 |
+
|
131 |
+
def __exit__(self, *exc_details):
|
132 |
+
self.close(*exc_details)
|
133 |
+
|
134 |
+
def __del__(self):
|
135 |
+
self.close()
|
src/data_structures.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Collection, NamedTuple
|
2 |
+
|
3 |
+
from hivemind import PeerID
|
4 |
+
|
5 |
+
ModuleUID = str
|
6 |
+
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
|
7 |
+
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
|
8 |
+
RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
|
src/dht_utils.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for declaring and retrieving active model layers using a shared DHT.
|
3 |
+
"""
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
from typing import Dict, List, Optional, Sequence, Union
|
8 |
+
|
9 |
+
from hivemind.dht import DHT, DHTNode, DHTValue
|
10 |
+
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
11 |
+
from hivemind.p2p import P2P, PeerID
|
12 |
+
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
|
13 |
+
|
14 |
+
import src
|
15 |
+
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
|
16 |
+
|
17 |
+
use_hivemind_log_handler("in_root_logger")
|
18 |
+
logger = get_logger(__file__)
|
19 |
+
|
20 |
+
|
21 |
+
def declare_active_modules(
|
22 |
+
dht: DHT,
|
23 |
+
uids: Sequence[ModuleUID],
|
24 |
+
expiration_time: DHTExpiration,
|
25 |
+
throughput: Optional[float] = None,
|
26 |
+
wait: bool = True,
|
27 |
+
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
|
28 |
+
"""
|
29 |
+
Declare that your node serves the specified modules; update timestamps if declared previously
|
30 |
+
|
31 |
+
:param uids: a list of module ids to declare
|
32 |
+
:param wait: if True, awaits for declaration to finish, otherwise runs in background
|
33 |
+
:param throughput: optionally specify your performance in terms of compute throughput
|
34 |
+
:param expiration_time: declated modules will be visible for this many seconds
|
35 |
+
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
|
36 |
+
"""
|
37 |
+
if isinstance(uids, str):
|
38 |
+
uids = [uids]
|
39 |
+
if not isinstance(uids, list):
|
40 |
+
uids = list(uids)
|
41 |
+
for uid in uids:
|
42 |
+
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
|
43 |
+
return dht.run_coroutine(
|
44 |
+
partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
|
45 |
+
return_future=not wait,
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
async def _declare_active_modules(
|
50 |
+
dht: DHT,
|
51 |
+
node: DHTNode,
|
52 |
+
uids: List[ModuleUID],
|
53 |
+
expiration_time: DHTExpiration,
|
54 |
+
throughput: Optional[float] = None,
|
55 |
+
) -> Dict[ModuleUID, bool]:
|
56 |
+
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
57 |
+
return await node.store_many(
|
58 |
+
keys=uids,
|
59 |
+
subkeys=[dht.peer_id.to_base58()] * len(uids),
|
60 |
+
values=[throughput] * len(uids),
|
61 |
+
expiration_time=expiration_time,
|
62 |
+
num_workers=num_workers,
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def get_remote_module(
|
67 |
+
dht: DHT,
|
68 |
+
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
69 |
+
expiration_time: Optional[DHTExpiration] = None,
|
70 |
+
return_future: bool = False,
|
71 |
+
) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
|
72 |
+
"""
|
73 |
+
:param uid_or_uids: find one or more modules with these ids from across the DHT
|
74 |
+
:param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
|
75 |
+
:param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
|
76 |
+
:returns: a list of [RemoteTransformerBlock if found else None]
|
77 |
+
"""
|
78 |
+
single_uid = isinstance(uid_or_uids, ModuleUID)
|
79 |
+
uids = [uid_or_uids] if single_uid else uid_or_uids
|
80 |
+
infos = dht.run_coroutine(
|
81 |
+
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
|
82 |
+
)
|
83 |
+
|
84 |
+
if return_future:
|
85 |
+
|
86 |
+
async def _unpack(infos_future: MPFuture, dht: DHT):
|
87 |
+
p2p = await dht.replicate_p2p()
|
88 |
+
modules = _create_remote_modules_from_infos(await infos_future, p2p)
|
89 |
+
return modules[0] if single_uid else modules
|
90 |
+
|
91 |
+
return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
|
92 |
+
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
|
93 |
+
modules = _create_remote_modules_from_infos(infos, p2p)
|
94 |
+
return modules[0] if single_uid else modules
|
95 |
+
|
96 |
+
|
97 |
+
async def _get_remote_module_infos(
|
98 |
+
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
|
99 |
+
) -> List[Optional[RemoteModuleInfo]]:
|
100 |
+
if expiration_time is None:
|
101 |
+
expiration_time = get_dht_time()
|
102 |
+
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
|
103 |
+
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
|
104 |
+
|
105 |
+
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
|
106 |
+
for i, uid in enumerate(uids):
|
107 |
+
metadata = found[uid]
|
108 |
+
if metadata is None or not isinstance(metadata.value, dict):
|
109 |
+
if metadata is not None:
|
110 |
+
logger.error(f"Incorrect metadata for {uid}: {metadata}")
|
111 |
+
continue
|
112 |
+
valid_entries = set()
|
113 |
+
for maybe_peer_id, _unused_value in metadata.value.items():
|
114 |
+
try:
|
115 |
+
valid_entries.add(PeerID.from_base58(maybe_peer_id))
|
116 |
+
except:
|
117 |
+
logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
|
118 |
+
if valid_entries:
|
119 |
+
modules[i] = RemoteModuleInfo(uid, valid_entries)
|
120 |
+
return modules
|
121 |
+
|
122 |
+
|
123 |
+
def _create_remote_modules_from_infos(
|
124 |
+
infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
|
125 |
+
) -> List[Optional[src.RemoteTransformerBlock]]:
|
126 |
+
modules: List[Optional[src.RemoteTransformerBlock]] = []
|
127 |
+
for info in infos:
|
128 |
+
if info is not None:
|
129 |
+
modules.append(src.RemoteTransformerBlock(info, p2p))
|
130 |
+
else:
|
131 |
+
modules.append(None)
|
132 |
+
return modules
|
src/server/__init__.py
ADDED
File without changes
|
src/server/backend.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Code for serving bloom blocks via hivemind-server"""
|
2 |
+
from typing import Sequence, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from hivemind.moe.server.module_backend import ModuleBackend
|
6 |
+
from hivemind.moe.server.task_pool import TaskPool
|
7 |
+
|
8 |
+
from src.bloom.from_pretrained import BloomBlock
|
9 |
+
from src.server.cache import MemoryCache
|
10 |
+
|
11 |
+
MAX_LENGTH = 2048
|
12 |
+
|
13 |
+
|
14 |
+
class TransformerBackend(ModuleBackend):
|
15 |
+
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
|
16 |
+
|
17 |
+
def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
|
18 |
+
super().__init__(*args, **kwargs)
|
19 |
+
assert isinstance(self.module, BloomBlock)
|
20 |
+
self.memory_cache = memory_cache
|
21 |
+
for name, param in self.module.named_parameters():
|
22 |
+
assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
23 |
+
for name, buf in self.module.named_buffers():
|
24 |
+
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
25 |
+
|
26 |
+
self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
|
27 |
+
|
28 |
+
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
29 |
+
with torch.inference_mode():
|
30 |
+
attention_cache_handle = int(cache_metadata[0, 0].item())
|
31 |
+
prefix_length = int(cache_metadata[0, 1].item())
|
32 |
+
hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here
|
33 |
+
assert (
|
34 |
+
hidden_states.ndim == 3
|
35 |
+
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
|
36 |
+
|
37 |
+
with self.memory_cache.use_cache(attention_cache_handle) as cache:
|
38 |
+
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
|
39 |
+
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
|
40 |
+
print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
|
41 |
+
hidden_states, (new_k, new_v) = self.module.forward(
|
42 |
+
hidden_states, layer_past=layer_past, use_cache=True
|
43 |
+
)
|
44 |
+
|
45 |
+
# todo remove these asserts once we pass all tests
|
46 |
+
new_length = new_v.shape[1]
|
47 |
+
assert new_length > prefix_length
|
48 |
+
assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
|
49 |
+
assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
|
50 |
+
assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
|
51 |
+
assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
|
52 |
+
assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
|
53 |
+
cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
|
54 |
+
cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
|
55 |
+
return (hidden_states,)
|
56 |
+
|
57 |
+
def get_pools(self) -> Sequence[TaskPool]:
|
58 |
+
return self.forward_pool, self.backward_pool, self.inference_pool
|
src/server/cache.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
|
3 |
+
|
4 |
+
For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
|
5 |
+
|
6 |
+
"""
|
7 |
+
import contextlib
|
8 |
+
import ctypes
|
9 |
+
import multiprocessing as mp
|
10 |
+
import os
|
11 |
+
from typing import AsyncContextManager, Dict, Optional, Union
|
12 |
+
|
13 |
+
import hivemind
|
14 |
+
import torch
|
15 |
+
from hivemind import use_hivemind_log_handler
|
16 |
+
from hivemind.utils import TensorDescriptor, get_logger
|
17 |
+
|
18 |
+
use_hivemind_log_handler("in_root_logger")
|
19 |
+
logger = get_logger(__file__)
|
20 |
+
|
21 |
+
Handle = int
|
22 |
+
|
23 |
+
|
24 |
+
class MemoryCache:
|
25 |
+
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
|
26 |
+
|
27 |
+
def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
|
28 |
+
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
|
29 |
+
self.device = device
|
30 |
+
self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
|
31 |
+
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
32 |
+
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
|
33 |
+
self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
|
34 |
+
self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
|
35 |
+
self.runtime_pid = os.getpid()
|
36 |
+
|
37 |
+
self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False) # any ConnectionHandler -> runtime
|
38 |
+
self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
|
39 |
+
|
40 |
+
@property
|
41 |
+
def current_size_bytes(self) -> int:
|
42 |
+
return self._current_size.value
|
43 |
+
|
44 |
+
@current_size_bytes.setter
|
45 |
+
def current_size_bytes(self, value: int):
|
46 |
+
self._current_size.value = value
|
47 |
+
|
48 |
+
@property
|
49 |
+
def handle_counter(self) -> int:
|
50 |
+
return self._handle_counter.value
|
51 |
+
|
52 |
+
@handle_counter.setter
|
53 |
+
def handle_counter(self, value: int):
|
54 |
+
self._handle_counter.value = value
|
55 |
+
|
56 |
+
@contextlib.asynccontextmanager
|
57 |
+
async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
|
58 |
+
"""
|
59 |
+
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
|
60 |
+
|
61 |
+
:param descr: allocate a tensor of this size, dtype, etc
|
62 |
+
|
63 |
+
:note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
|
64 |
+
Furthermore, it can be called concurrently with at most one use_cache call in runtime.
|
65 |
+
"""
|
66 |
+
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
|
67 |
+
assert descr.device is None and descr
|
68 |
+
allocated_handle = None
|
69 |
+
allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
|
70 |
+
try:
|
71 |
+
async with hivemind.utils.enter_asynchronously(self.lock_metadata):
|
72 |
+
if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
|
73 |
+
raise AllocationFailed(
|
74 |
+
f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
|
75 |
+
f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
|
76 |
+
)
|
77 |
+
|
78 |
+
allocated_handle = int(self.handle_counter)
|
79 |
+
self.current_size_bytes += allocated_size_bytes
|
80 |
+
self.handle_counter += 1 # note: this will eventually overflow and it is okay
|
81 |
+
self._pending_messages.value += 1
|
82 |
+
self._pipe_send.send((allocated_handle, descr))
|
83 |
+
|
84 |
+
yield allocated_handle
|
85 |
+
finally:
|
86 |
+
if allocated_handle is not None:
|
87 |
+
async with hivemind.utils.enter_asynchronously(self.lock_metadata):
|
88 |
+
self._pending_messages.value += 1
|
89 |
+
self._pipe_send.send((allocated_handle, None)) # signal runtime to free that handle
|
90 |
+
self.current_size_bytes -= allocated_size_bytes
|
91 |
+
|
92 |
+
@contextlib.contextmanager
|
93 |
+
def use_cache(self, handle: Handle) -> torch.Tensor:
|
94 |
+
"""
|
95 |
+
Return a tensor that was previously allocated with try_allocate_cache,
|
96 |
+
|
97 |
+
:note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
|
98 |
+
However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
|
99 |
+
"""
|
100 |
+
assert os.getpid() == self.runtime_pid
|
101 |
+
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
|
102 |
+
|
103 |
+
with self.lock_metadata:
|
104 |
+
if self._allocated_tensors is None:
|
105 |
+
self._allocated_tensors = {}
|
106 |
+
|
107 |
+
# read creation/deletion requests from connection handlers
|
108 |
+
for i in range(int(self._pending_messages.value)):
|
109 |
+
recv_handle, recv_data = self._pipe_recv.recv()
|
110 |
+
self._pending_messages.value -= 1
|
111 |
+
if isinstance(recv_data, TensorDescriptor):
|
112 |
+
self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
|
113 |
+
elif recv_data is None:
|
114 |
+
if recv_handle not in self._allocated_tensors:
|
115 |
+
logger.warning(
|
116 |
+
f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
|
117 |
+
)
|
118 |
+
self._allocated_tensors.pop(recv_handle, None)
|
119 |
+
else:
|
120 |
+
logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
|
121 |
+
|
122 |
+
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
|
123 |
+
yield self._allocated_tensors[handle]
|
124 |
+
|
125 |
+
|
126 |
+
class AllocationFailed(Exception):
|
127 |
+
pass
|
src/server/handler.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
|
2 |
+
import contextlib
|
3 |
+
from typing import AsyncIterator, Dict, Sequence
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
|
7 |
+
from hivemind.moe.server.connection_handler import ConnectionHandler
|
8 |
+
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
|
9 |
+
from hivemind.proto import runtime_pb2
|
10 |
+
from hivemind.utils import as_aiter
|
11 |
+
from hivemind.utils.asyncio import anext
|
12 |
+
from hivemind.utils.streaming import split_for_streaming
|
13 |
+
|
14 |
+
from src.data_structures import CHAIN_DELIMITER, ModuleUID
|
15 |
+
from src.server.backend import MAX_LENGTH, TransformerBackend
|
16 |
+
|
17 |
+
|
18 |
+
class TransformerConnectionHandler(ConnectionHandler):
|
19 |
+
"""Handles three request types: forward, backward and forward-incremental (inference)"""
|
20 |
+
|
21 |
+
module_backends: Dict[ModuleUID, TransformerBackend]
|
22 |
+
|
23 |
+
def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
|
24 |
+
super().__init__(dht, module_backends)
|
25 |
+
for module_backend in self.module_backends.values():
|
26 |
+
assert isinstance(module_backend, TransformerBackend)
|
27 |
+
|
28 |
+
async def rpc_inference(
|
29 |
+
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
30 |
+
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
31 |
+
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
|
32 |
+
try:
|
33 |
+
print("OPENED RPC_INFERENCE")
|
34 |
+
request = await anext(requests)
|
35 |
+
requested_uids = self._check_header(request)
|
36 |
+
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
37 |
+
|
38 |
+
cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64) # [cache_handle, prefix_length]
|
39 |
+
prefix_length = 0
|
40 |
+
|
41 |
+
async with self._allocate_caches(requested_backends) as cache_handles:
|
42 |
+
assert len(cache_handles) == len(requested_backends)
|
43 |
+
while request.tensors: # iterate while user is willing to supply tensors
|
44 |
+
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
45 |
+
|
46 |
+
# run request tensors through all requested modules, update caches
|
47 |
+
for backend, cache_handle in zip(requested_backends, cache_handles):
|
48 |
+
cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
|
49 |
+
assert (
|
50 |
+
len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
51 |
+
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
52 |
+
|
53 |
+
hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
|
54 |
+
assert isinstance(hidden_states, (list, tuple))
|
55 |
+
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
56 |
+
|
57 |
+
# serialize and send last layer outputs
|
58 |
+
yield runtime_pb2.ExpertResponse(
|
59 |
+
tensors=[
|
60 |
+
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
|
61 |
+
for result, proto in zip(
|
62 |
+
hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
|
63 |
+
)
|
64 |
+
]
|
65 |
+
)
|
66 |
+
|
67 |
+
# prepare for next step
|
68 |
+
prefix_length += hidden_states[0].shape[1]
|
69 |
+
request = await (anext(requests))
|
70 |
+
finally:
|
71 |
+
print("CLOSED RPC_INFERENCE")
|
72 |
+
|
73 |
+
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
74 |
+
# Parse request and prepare backends
|
75 |
+
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
76 |
+
requested_uids = self._check_header(request)
|
77 |
+
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
78 |
+
|
79 |
+
# Run a chain of requested backends
|
80 |
+
for backend in requested_backends:
|
81 |
+
assert isinstance(hidden_states, (list, tuple))
|
82 |
+
assert (
|
83 |
+
len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
84 |
+
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
85 |
+
hidden_states = await backend.forward_pool.submit_task(*hidden_states)
|
86 |
+
|
87 |
+
# Serialize the overall output and respond
|
88 |
+
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
89 |
+
return runtime_pb2.ExpertResponse(
|
90 |
+
tensors=[
|
91 |
+
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
|
92 |
+
for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
|
93 |
+
]
|
94 |
+
)
|
95 |
+
|
96 |
+
async def rpc_forward_stream(
|
97 |
+
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
98 |
+
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
99 |
+
# Parse requests and prepare backends
|
100 |
+
uids_header, hidden_states = await self._gather_inputs(requests, context)
|
101 |
+
requested_uids = self._check_header_str(uids_header)
|
102 |
+
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
103 |
+
|
104 |
+
# Run a chain of requested backends
|
105 |
+
for backend in requested_backends:
|
106 |
+
assert isinstance(hidden_states, (list, tuple))
|
107 |
+
assert (
|
108 |
+
len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
109 |
+
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
110 |
+
hidden_states = await backend.forward_pool.submit_task(*hidden_states)
|
111 |
+
|
112 |
+
# Serialize the overall output
|
113 |
+
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
114 |
+
serialized_output = [
|
115 |
+
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
|
116 |
+
for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
|
117 |
+
]
|
118 |
+
|
119 |
+
# Split the serialized_output for streaming and respond
|
120 |
+
output_split = [
|
121 |
+
part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
|
122 |
+
]
|
123 |
+
async for part in as_aiter(*output_split):
|
124 |
+
yield runtime_pb2.ExpertResponse(tensors=[part])
|
125 |
+
|
126 |
+
async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
127 |
+
# Parse requests and prepare backends
|
128 |
+
inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
129 |
+
requested_uids = self._check_header(request)
|
130 |
+
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
131 |
+
|
132 |
+
# Run a forward chain to collect intermediate inputs
|
133 |
+
# Note that we do not forward for the last module since we do not need its output
|
134 |
+
inter_inputs = [inputs]
|
135 |
+
for backend in requested_backends[:-1]:
|
136 |
+
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
|
137 |
+
inputs = await backend.forward_pool.submit_task(inputs)
|
138 |
+
assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
|
139 |
+
inputs = inputs[0]
|
140 |
+
inter_inputs.append(inputs)
|
141 |
+
|
142 |
+
# Run a chain of requested backends
|
143 |
+
for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
|
144 |
+
inputs_and_grads = [inp, grads]
|
145 |
+
grads = await backend.backward_pool.submit_task(*inputs_and_grads)
|
146 |
+
assert isinstance(grads, (list, tuple)) and len(grads) == 1
|
147 |
+
grads = grads[0]
|
148 |
+
|
149 |
+
# Serialize the overall grad_input and respond
|
150 |
+
return runtime_pb2.ExpertResponse(
|
151 |
+
tensors=[
|
152 |
+
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
|
153 |
+
for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
|
154 |
+
]
|
155 |
+
)
|
156 |
+
|
157 |
+
async def rpc_backward_stream(
|
158 |
+
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
159 |
+
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
|
160 |
+
uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
|
161 |
+
inputs, grads = inputs_and_grads
|
162 |
+
requested_uids = self._check_header_str(uids_header)
|
163 |
+
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
164 |
+
|
165 |
+
# Run a forward chain to collect intermediate inputs
|
166 |
+
# Note that we do not forward for the last module since we do not need its outputs
|
167 |
+
inter_inputs = [inputs]
|
168 |
+
for backend in requested_backends[:-1]:
|
169 |
+
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
|
170 |
+
inputs = await backend.forward_pool.submit_task(inputs)
|
171 |
+
assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
|
172 |
+
inputs = inputs[0]
|
173 |
+
inter_inputs.append(inputs)
|
174 |
+
|
175 |
+
# Run a backward chain for requested backends
|
176 |
+
for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
|
177 |
+
inputs_and_grads = [inp, grads]
|
178 |
+
grads = await backend.backward_pool.submit_task(*inputs_and_grads)
|
179 |
+
assert isinstance(grads, (list, tuple)) and len(grads) == 1
|
180 |
+
grads = grads[0]
|
181 |
+
|
182 |
+
# Serialize the overall grad_inputs
|
183 |
+
serialized_grad_inputs = [
|
184 |
+
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
|
185 |
+
for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
|
186 |
+
]
|
187 |
+
# Split the serialized_grad_inputs for streaming and respond
|
188 |
+
output_split = [
|
189 |
+
part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
|
190 |
+
]
|
191 |
+
|
192 |
+
async for part in as_aiter(*output_split):
|
193 |
+
yield runtime_pb2.ExpertResponse(tensors=[part])
|
194 |
+
|
195 |
+
def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
|
196 |
+
"""Check that the first request to rpc_inference is valid"""
|
197 |
+
uids = (request.uid or "").split(CHAIN_DELIMITER)
|
198 |
+
if not uids:
|
199 |
+
raise RuntimeError("User did not provide any uids")
|
200 |
+
for uid in uids:
|
201 |
+
if uid not in self.module_backends:
|
202 |
+
raise RuntimeError(f"Remote peer does not serve {uid}")
|
203 |
+
return tuple(uids)
|
204 |
+
|
205 |
+
def _check_header_str(self, header) -> Sequence[ModuleUID]:
|
206 |
+
"""Check that the first request to rpc_inference is valid"""
|
207 |
+
uids = (header or "").split(CHAIN_DELIMITER)
|
208 |
+
if not uids:
|
209 |
+
raise RuntimeError("User did not provide any uids")
|
210 |
+
for uid in uids:
|
211 |
+
if uid not in self.module_backends:
|
212 |
+
raise RuntimeError(f"Remote peer does not serve {uid}")
|
213 |
+
return tuple(uids)
|
214 |
+
|
215 |
+
@contextlib.asynccontextmanager
|
216 |
+
async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
|
217 |
+
"""Allocate memory caches for each transformer block, return cache handles"""
|
218 |
+
async with contextlib.AsyncExitStack() as stack:
|
219 |
+
handles = []
|
220 |
+
for backend in backends:
|
221 |
+
num_heads = backend.module.self_attention.num_heads
|
222 |
+
head_dim = backend.module.self_attention.head_dim
|
223 |
+
|
224 |
+
cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
|
225 |
+
# [key_or_value, batch_size, max_length, num_heads, head_dim]
|
226 |
+
|
227 |
+
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
|
228 |
+
|
229 |
+
yield handles
|
src/server/server.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import multiprocessing as mp
|
4 |
+
import threading
|
5 |
+
from typing import Dict, Optional, Sequence, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
|
9 |
+
from hivemind.moe.server.layers import add_custom_models_from_file
|
10 |
+
from hivemind.moe.server.runtime import Runtime
|
11 |
+
from hivemind.proto.runtime_pb2 import CompressionType
|
12 |
+
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
13 |
+
|
14 |
+
from src import declare_active_modules, BloomConfig
|
15 |
+
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
|
16 |
+
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
|
17 |
+
from src.server.backend import TransformerBackend
|
18 |
+
from src.server.cache import MemoryCache
|
19 |
+
from src.server.handler import TransformerConnectionHandler
|
20 |
+
|
21 |
+
use_hivemind_log_handler("in_root_logger")
|
22 |
+
logger = get_logger(__file__)
|
23 |
+
|
24 |
+
|
25 |
+
class Server(threading.Thread):
|
26 |
+
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
dht: DHT,
|
31 |
+
module_backends: Dict[str, TransformerBackend],
|
32 |
+
*,
|
33 |
+
device: torch.device,
|
34 |
+
num_connection_handlers: int = 8,
|
35 |
+
update_period: float = 30,
|
36 |
+
expiration: Optional[float] = None,
|
37 |
+
start: bool,
|
38 |
+
**kwargs,
|
39 |
+
):
|
40 |
+
threading.Thread.__init__(self)
|
41 |
+
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
|
42 |
+
self.conn_handlers = [
|
43 |
+
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
|
44 |
+
]
|
45 |
+
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
46 |
+
self.dht_handler_thread = ModuleAnnouncerThread(
|
47 |
+
self.module_backends, dht, update_period, expiration, daemon=True
|
48 |
+
)
|
49 |
+
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
50 |
+
|
51 |
+
if start:
|
52 |
+
self.run_in_background(await_ready=True)
|
53 |
+
|
54 |
+
def run(self):
|
55 |
+
"""
|
56 |
+
Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
|
57 |
+
runs Runtime (self.runtime) to process incoming requests.
|
58 |
+
"""
|
59 |
+
logger.info(f"Serving {len(self.module_backends)} blocks:")
|
60 |
+
for expert_name, backend in self.module_backends.items():
|
61 |
+
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
|
62 |
+
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
|
63 |
+
|
64 |
+
if not self.dht.is_alive():
|
65 |
+
self.dht.run_in_background(await_ready=True)
|
66 |
+
|
67 |
+
if self.module_backends:
|
68 |
+
self.dht_handler_thread.start()
|
69 |
+
|
70 |
+
if self.checkpoint_saver is not None:
|
71 |
+
self.checkpoint_saver.start()
|
72 |
+
|
73 |
+
for process in self.conn_handlers:
|
74 |
+
if not process.is_alive():
|
75 |
+
process.start()
|
76 |
+
process.ready.result()
|
77 |
+
|
78 |
+
try:
|
79 |
+
self.runtime.run()
|
80 |
+
finally:
|
81 |
+
self.shutdown()
|
82 |
+
|
83 |
+
# noinspection PyMethodOverriding
|
84 |
+
@classmethod
|
85 |
+
def create(
|
86 |
+
cls,
|
87 |
+
prefix: Optional[str],
|
88 |
+
converted_model_name_or_path: str,
|
89 |
+
num_blocks: Optional[int] = None,
|
90 |
+
block_indices: Optional[str] = None,
|
91 |
+
num_handlers: Optional[int] = None,
|
92 |
+
min_batch_size: int = 1,
|
93 |
+
max_batch_size: int = 4096,
|
94 |
+
torch_dtype: str = "auto",
|
95 |
+
cache_size_bytes: Optional[int] = None,
|
96 |
+
device: Union[str, torch.device] = None,
|
97 |
+
initial_peers: Sequence[str] = (),
|
98 |
+
compression=CompressionType.NONE,
|
99 |
+
stats_report_interval: Optional[int] = None,
|
100 |
+
custom_module_path=None,
|
101 |
+
update_period: float = 30,
|
102 |
+
expiration: Optional[float] = None,
|
103 |
+
use_auth_token: Optional[str] = None,
|
104 |
+
*,
|
105 |
+
start: bool,
|
106 |
+
**kwargs,
|
107 |
+
) -> Server:
|
108 |
+
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
109 |
+
if custom_module_path is not None:
|
110 |
+
add_custom_models_from_file(custom_module_path)
|
111 |
+
if prefix is None:
|
112 |
+
prefix = converted_model_name_or_path
|
113 |
+
assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
|
114 |
+
f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
|
115 |
+
f"Please specify --prefix manually when starting a server"
|
116 |
+
)
|
117 |
+
logger.info(f"Automatic dht prefix: {prefix}")
|
118 |
+
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
|
119 |
+
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
|
120 |
+
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
|
121 |
+
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
122 |
+
|
123 |
+
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
124 |
+
memory_cache = MemoryCache(device, cache_size_bytes)
|
125 |
+
|
126 |
+
if isinstance(torch_dtype, str):
|
127 |
+
torch_dtype = DTYPE_MAP[torch_dtype]
|
128 |
+
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
129 |
+
|
130 |
+
if block_indices is not None:
|
131 |
+
try:
|
132 |
+
first_block_index, last_block_index = block_indices.split(":")
|
133 |
+
first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
|
134 |
+
except Exception as e:
|
135 |
+
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
|
136 |
+
raise
|
137 |
+
block_indices = range(first_block_index, last_block_index)
|
138 |
+
else:
|
139 |
+
assert num_blocks is not None
|
140 |
+
block_indices = range(num_blocks) # TODO replace with proper load balancing
|
141 |
+
|
142 |
+
block_config = BloomConfig.from_pretrained(
|
143 |
+
converted_model_name_or_path, use_auth_token=use_auth_token
|
144 |
+
)
|
145 |
+
|
146 |
+
# initialize modules
|
147 |
+
blocks = {}
|
148 |
+
for block_index in block_indices:
|
149 |
+
module_uid = f"{prefix}.{block_index}"
|
150 |
+
block = load_pretrained_block(
|
151 |
+
converted_model_name_or_path,
|
152 |
+
block_index,
|
153 |
+
block_config,
|
154 |
+
torch_dtype=torch_dtype,
|
155 |
+
use_auth_token=use_auth_token,
|
156 |
+
)
|
157 |
+
for param in block.parameters():
|
158 |
+
param.requires_grad = False
|
159 |
+
|
160 |
+
blocks[module_uid] = TransformerBackend(
|
161 |
+
module_uid,
|
162 |
+
block,
|
163 |
+
memory_cache=memory_cache,
|
164 |
+
args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
|
165 |
+
kwargs_schema={},
|
166 |
+
outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
|
167 |
+
min_batch_size=min_batch_size,
|
168 |
+
max_batch_size=max_batch_size,
|
169 |
+
)
|
170 |
+
|
171 |
+
num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
|
172 |
+
|
173 |
+
return cls(
|
174 |
+
dht,
|
175 |
+
blocks,
|
176 |
+
num_connection_handlers=num_handlers,
|
177 |
+
device=device,
|
178 |
+
stats_report_interval=stats_report_interval,
|
179 |
+
update_period=update_period,
|
180 |
+
expiration=expiration,
|
181 |
+
start=start,
|
182 |
+
)
|
183 |
+
|
184 |
+
def run_in_background(self, await_ready=True, timeout=None):
|
185 |
+
"""
|
186 |
+
Starts Server in a background thread. if await_ready, this method will wait until background server
|
187 |
+
is ready to process incoming requests or for :timeout: seconds max.
|
188 |
+
"""
|
189 |
+
self.start()
|
190 |
+
if await_ready and not self.ready.wait(timeout=timeout):
|
191 |
+
raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
|
192 |
+
|
193 |
+
@property
|
194 |
+
def ready(self) -> mp.synchronize.Event:
|
195 |
+
"""
|
196 |
+
An event (multiprocessing.Event) that is set when the server is ready to process requests.
|
197 |
+
|
198 |
+
Example
|
199 |
+
=======
|
200 |
+
>>> server.start()
|
201 |
+
>>> server.ready.wait(timeout=10)
|
202 |
+
>>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
|
203 |
+
"""
|
204 |
+
return self.runtime.ready # mp.Event that is true if self is ready to process batches
|
205 |
+
|
206 |
+
def shutdown(self):
|
207 |
+
"""
|
208 |
+
Gracefully terminate the server, process-safe.
|
209 |
+
Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
|
210 |
+
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
|
211 |
+
"""
|
212 |
+
self.ready.clear()
|
213 |
+
|
214 |
+
for process in self.conn_handlers:
|
215 |
+
process.terminate()
|
216 |
+
process.join()
|
217 |
+
logger.debug("Connection handlers terminated")
|
218 |
+
|
219 |
+
if self.module_backends:
|
220 |
+
self.dht_handler_thread.stop.set()
|
221 |
+
self.dht_handler_thread.join()
|
222 |
+
|
223 |
+
if self.checkpoint_saver is not None:
|
224 |
+
self.checkpoint_saver.stop.set()
|
225 |
+
self.checkpoint_saver.join()
|
226 |
+
|
227 |
+
self.dht.shutdown()
|
228 |
+
self.dht.join()
|
229 |
+
|
230 |
+
logger.debug(f"Shutting down runtime")
|
231 |
+
|
232 |
+
self.runtime.shutdown()
|
233 |
+
logger.info("Server shutdown succesfully")
|
234 |
+
|
235 |
+
|
236 |
+
class ModuleAnnouncerThread(threading.Thread):
|
237 |
+
"""Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
|
238 |
+
|
239 |
+
def __init__(
|
240 |
+
self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
|
241 |
+
):
|
242 |
+
super().__init__(**kwargs)
|
243 |
+
if expiration is None:
|
244 |
+
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
|
245 |
+
self.module_backends = module_backends
|
246 |
+
self.dht = dht
|
247 |
+
self.update_period = update_period
|
248 |
+
self.expiration = expiration
|
249 |
+
self.stop = threading.Event()
|
250 |
+
|
251 |
+
def run(self) -> None:
|
252 |
+
declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
|
253 |
+
while not self.stop.wait(self.update_period):
|
254 |
+
declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
|