justheuristic commited on
Commit
23d0807
·
1 Parent(s): 661bd6a
README.md CHANGED
@@ -1,12 +1,9 @@
1
  ---
2
- title: Dbg4
3
- emoji: 📊
4
- colorFrom: purple
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.1.1
8
  app_file: app.py
9
  pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Distributd-Bloom API
3
+ colorFrom: pink
4
+ colorTo: grey
 
5
  sdk: gradio
6
+ sdk_version: 3.0.25
7
  app_file: app.py
8
  pinned: false
9
+ ---
 
 
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import transformers
4
+ import gradio as gr
5
+
6
+ from src.client import DistributedBloomForCausalLM
7
+
8
+ INITIAL_PEERS = ['/ip4/193.106.95.184/tcp/31337/p2p/QmUigSxrVz9x5FR9ZYr4iRfEX2vDxihL2YZtDd7sp2eKnM']
9
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
10
+ model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32)
11
+
12
+ def inference(text, seq_length=1):
13
+ input_ids = tokenizer(text, return_tensors='pt')['input_ids']
14
+ with torch.inference_mode(), model.transformer.h.inference_session() as remote_transformer:
15
+ for i in range(seq_length):
16
+ h = model.transformer.word_embeddings(input_ids)
17
+ h = model.transformer.word_embeddings_layernorm(h)
18
+
19
+ h = remote_transformer.step(h) # note [yozh]: this line currently freezes for 10 seconds first time only, its gonna be fixed in the nearest PR
20
+
21
+ h = model.transformer.ln_f(h)
22
+ h = F.linear(h, weight=model.transformer.word_embeddings.weight) # note: this line takes a while, will also be fixed
23
+ next_token_ix = torch.multinomial((h[0, -1] / 0.8).softmax(-1), 1)
24
+
25
+ # print(end=tokenizer.decode(next_token_ix.item()))
26
+ input_ids = next_token_ix.view(1, 1)
27
+ return tokenizer.decode(input_ids.item())
28
+
29
+ iface = gr.Interface(fn=inference, inputs="text", outputs="text")
30
+ iface.launch()
cli/__init__.py ADDED
File without changes
cli/config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "apply_residual_connection_post_layernorm": false,
3
+ "attention_dropout": 0.0,
4
+ "attention_softmax_in_fp32": true,
5
+ "bos_token_id": 1,
6
+ "eos_token_id": 2,
7
+ "hidden_dropout": 0.0,
8
+ "initializer_range": 0.02,
9
+ "layer_norm_epsilon": 1e-05,
10
+ "masked_softmax_fusion": true,
11
+ "model_type": "bloom",
12
+ "n_embed": 14336,
13
+ "n_layer": 70,
14
+ "num_attention_heads": 112,
15
+ "pretraining_tp": 4,
16
+ "slow_but_exact": false,
17
+ "transformers_version": "4.20.0.dev0",
18
+ "use_cache": true,
19
+ "vocab_size": 250880
20
+ }
cli/convert_model.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import psutil
5
+ import torch.backends.quantized
6
+ import torch.nn as nn
7
+ import transformers
8
+ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
9
+ from huggingface_hub import Repository
10
+ from tqdm.auto import tqdm
11
+
12
+ from src import BloomModel
13
+ from src.client import DistributedBloomConfig
14
+ from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
15
+ use_hivemind_log_handler("in_root_logger")
16
+ logger = get_logger(__file__)
17
+
18
+ DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
19
+
20
+
21
+ if __name__ == "__main__":
22
+ parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
23
+
24
+ parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
25
+ parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
26
+ parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
27
+ parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
28
+ parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
29
+ parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
30
+ parser.add_argument(
31
+ "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
32
+ )
33
+ parser.add_argument(
34
+ "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
35
+ )
36
+ parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
37
+ args = parser.parse_args()
38
+
39
+ free_ram_gb = psutil.virtual_memory().available / 2**30
40
+ if args.model == "bigscience/bloom" and free_ram_gb < 400:
41
+ logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
42
+
43
+ assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
44
+ if os.path.exists(args.output_path) and (
45
+ len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
46
+ ):
47
+ raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
48
+
49
+ logger.info(f"Loading source model {args.model} (this may take a few minutes)")
50
+ config = DistributedBloomConfig.from_pretrained(
51
+ args.model, use_auth_token=args.use_auth_token, revision=args.revision
52
+ )
53
+ config.dht_prefix = args.output_repo
54
+
55
+ model = BloomModel.from_pretrained(
56
+ args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
57
+ )
58
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
59
+ args.model, use_auth_token=args.use_auth_token, revision=args.revision
60
+ )
61
+ os.makedirs(args.output_path, exist_ok=True)
62
+
63
+ repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
64
+ repo.git_pull()
65
+
66
+ transformer_blocks = model.h
67
+ logger.info(
68
+ f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
69
+ f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
70
+ )
71
+ for i, block in enumerate(tqdm(transformer_blocks)):
72
+ repo.git_checkout(args.client_branch, create_branch_ok=True)
73
+ with repo.commit(
74
+ commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
75
+ ):
76
+ torch.save(block.state_dict(), "./pytorch_model.bin")
77
+
78
+ logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
79
+ repo.git_checkout(args.client_branch, create_branch_ok=True)
80
+ with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
81
+ model.h = nn.ModuleList()
82
+ model.save_pretrained(".")
83
+ tokenizer.save_pretrained(".")
84
+ config.save_pretrained(".")
85
+
86
+ logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
cli/deploy_server.sh ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ #################
4
+ # Parse options #
5
+ #################
6
+
7
+ instructions() {
8
+ echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
9
+ echo " -i: initial peer"
10
+ echo " -d: device" >&2
11
+ echo " -p: server identity path" >&2
12
+ echo " -b: block_ids" >&2
13
+ echo " -a: host maddrs" >&2
14
+ echo " -t: whether to run local tests" >&2
15
+ exit 1
16
+ }
17
+
18
+ if [ ! $# -ge 8 ]; then
19
+ instructions
20
+ fi
21
+
22
+ while getopts ":i:d:p:b:a:t:" option; do
23
+ case $option in
24
+ i) INITIAL_PEER=${OPTARG}
25
+ ;;
26
+ d) DEVICE=${OPTARG}
27
+ ;;
28
+ p) SERVER_ID_PATH=${OPTARG}
29
+ ;;
30
+ b) BLOCK_IDS=${OPTARG}
31
+ ;;
32
+ a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
33
+ ;;
34
+ t) RUN_LOCAL_TESTS=true
35
+ ;;
36
+ \?) instructions
37
+ ;;
38
+ esac
39
+ done
40
+
41
+
42
+ echo "=========="
43
+ echo "= Config ="
44
+ echo "=========="
45
+ echo "Initial peer: ${INITIAL_PEER}"
46
+ echo "Device: ${DEVICE}"
47
+ echo "Server name: ${SERVER_ID_PATH}"
48
+ echo "Server address: ${HOST_MADDR}"
49
+ echo "Bloom blocks: ${BLOCK_IDS}"
50
+
51
+
52
+ ###########################
53
+ # Install or activate env #
54
+ ###########################
55
+
56
+ # TODO fix bug with self calling
57
+ source ~/miniconda3/etc/profile.d/conda.sh
58
+ if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
59
+ conda activate bloom-demo
60
+ else
61
+ conda create -y --name bloom-demo python=3.8.12 pip
62
+ conda activate bloom-demo
63
+
64
+ conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
65
+ pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
66
+ pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
67
+ pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
68
+ pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
69
+ fi
70
+
71
+
72
+ ##############
73
+ # Local test #
74
+ ##############
75
+
76
+ if [ "$RUN_LOCAL_TESTS" = true ] ; then
77
+ echo "Run test on your local machine"
78
+ python -m cli.inference_one_block --config cli/config.json --device ${DEVICE} # see other args
79
+ fi
80
+
81
+
82
+ ##############
83
+ # Run server #
84
+ ##############
85
+
86
+ python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
87
+ --block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log
cli/inference_one_block.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
5
+ from tqdm.auto import trange
6
+
7
+ from src.bloom.block import BloomBlock
8
+ from src.bloom.model import BloomConfig
9
+ from src.bloom.ops import build_alibi_tensor
10
+
11
+ use_hivemind_log_handler("in_root_logger")
12
+ logger = get_logger(__file__)
13
+
14
+ logger.warning("inference_one_block will soon be deprecated in favour of tests!")
15
+
16
+
17
+ def print_device_info(device=None):
18
+ """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
19
+ device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
20
+ logger.info(f"Using device: {device}")
21
+
22
+ # Additional Info when using cuda
23
+ if device.type == "cuda":
24
+ logger.info(torch.cuda.get_device_name(0))
25
+ logger.info(f"Memory Usage:")
26
+ logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
27
+ logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
28
+
29
+
30
+ if __name__ == "__main__":
31
+ parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
32
+ parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
33
+ parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
34
+ parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
35
+ parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
36
+ parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
37
+ args = parser.parse_args()
38
+
39
+ if args.device is None:
40
+ args.device = "cuda" if torch.cuda.is_available() else "cpu"
41
+
42
+ config = BloomConfig.from_json_file(args.config)
43
+ block = BloomBlock(config, args.layer_index).to(args.device)
44
+
45
+ cache = None
46
+
47
+ for i in trange(args.num_steps):
48
+ dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
49
+ alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
50
+ with torch.no_grad():
51
+ outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
52
+
53
+ print_device_info(args.device)
cli/local_server_config_example.cfg ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ device=cpu
2
+ block_ids=2:3
3
+ id_path=./server.id
4
+ maddr=/ip4/127.0.0.1/tcp/30000
5
+ #
cli/remote_server_config_example.cfg ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name=bloom-peer-0.bloom.net
2
+ device=cpu
3
+ block_ids=1:3
4
+ id_path=./server.id
5
+ maddr=/ip4/0.0.0.0/tcp/30000
6
+ #
cli/run_local_servers.sh ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # !/usr/bin/env bash
2
+
3
+ #################
4
+ # Parse options #
5
+ #################
6
+
7
+ instructions() {
8
+ echo "Usage: $0 [-n] [-c]" >&2
9
+ echo " -n: number of servers to run" >&2
10
+ echo " -c: path to the server configs" >&2
11
+ exit 1
12
+ }
13
+
14
+ if [ $# != 4 ]; then
15
+ instructions
16
+ fi
17
+
18
+ while getopts ":n:c:t:" option; do
19
+ case $option in
20
+ n) NUM_SERVERS=${OPTARG}
21
+ ;;
22
+ c) CONFIG_PATH=${OPTARG}
23
+ ;;
24
+ \?) instructions
25
+ ;;
26
+ esac
27
+ done
28
+
29
+
30
+ ###########################
31
+ # Install or activate env #
32
+ ###########################
33
+
34
+ source ~/miniconda3/etc/profile.d/conda.sh
35
+ if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
36
+ conda activate bloom-demo
37
+ else
38
+ conda create -y --name bloom-demo python=3.8.12 pip
39
+ conda activate bloom-demo
40
+
41
+ conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
42
+ pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
43
+ pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
44
+ pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
45
+ pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
46
+ fi
47
+
48
+
49
+ #######################
50
+ # Create Initial peer #
51
+ #######################
52
+
53
+ hivemind-dht &> tmp.out &
54
+ sleep 3
55
+ INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
56
+ echo "Initial peer: ${INITIAL_PEER}"
57
+
58
+
59
+ ##############################
60
+ # Initialize the config file #
61
+ ##############################
62
+
63
+ typeset -A cfg
64
+ cfg=( # set default values in config array
65
+ [device]="cpu"
66
+ [block_ids]="1:2"
67
+ [id_path]="server.id"
68
+ [maddr]="/ip4/127.0.0.1/tcp/30000"
69
+ )
70
+
71
+ ###############
72
+ # Run servers #
73
+ ###############
74
+
75
+ for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
76
+ do
77
+ ###############
78
+ # Read config #
79
+ ###############
80
+
81
+ while read line
82
+ do
83
+ if echo $line | grep -F = &>/dev/null
84
+ then
85
+ varname=$(echo "$line" | cut -d '=' -f 1)
86
+ cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
87
+ fi
88
+ done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
89
+
90
+ echo "=== Server #${SERVER_ID} ==="
91
+ echo "Server ID: ${id_path}"
92
+ echo "Device: ${cfg[device]}"
93
+ echo "Bloom block ids: ${cfg[block_ids]}"
94
+ echo "Host maddr: ${cfg[maddr]}"
95
+ echo ""
96
+
97
+ ##############
98
+ # Run server #
99
+ ##############
100
+
101
+ tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
102
+ done
103
+
104
+
105
+ #####################
106
+ # Kill initial peer #
107
+ #####################
108
+
109
+ sleep 10
110
+ pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
111
+ rm tmp.out
cli/run_remote_servers.sh ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # !/usr/bin/env bash
2
+
3
+ SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"
4
+
5
+ #################
6
+ # Parse options #
7
+ #################
8
+
9
+ instructions() {
10
+ echo "Usage: $0 [-u] [-n] [-c]" >&2
11
+ echo " -u: username" >&2
12
+ echo " -n: number of servers to run" >&2
13
+ echo " -c: path to the server configs" >&2
14
+ exit 1
15
+ }
16
+
17
+ if [ $# != 6 ]; then
18
+ instructions
19
+ fi
20
+
21
+ while getopts ":u:n:c:" option; do
22
+ case $option in
23
+ u) USERNAME=${OPTARG}
24
+ ;;
25
+ n) NUM_SERVERS=${OPTARG}
26
+ ;;
27
+ c) CONFIG_PATH=${OPTARG}
28
+ ;;
29
+ \?) instructions
30
+ ;;
31
+ esac
32
+ done
33
+
34
+
35
+ ###########################
36
+ # Install or activate env #
37
+ ###########################
38
+
39
+ source ~/miniconda3/etc/profile.d/conda.sh
40
+ if conda env list | grep ".*bloom-demo.*" &>/dev/null; then
41
+ conda activate bloom-demo
42
+ else
43
+ conda create -y --name bloom-demo python=3.8.12 pip
44
+ conda activate bloom-demo
45
+
46
+ conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
47
+ pip install -i https://pypi.org/simple torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
48
+ pip install -i https://pypi.org/simple accelerate==0.10.0 huggingface-hub==0.7.0 hivemind==1.1.0
49
+ pip install -i https://pypi.org/simple bitsandbytes-cuda113==0.26.0
50
+ pip install -i https://pypi.org/simple https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
51
+ fi
52
+
53
+
54
+ #######################
55
+ # Create Initial peer #
56
+ #######################
57
+
58
+ hivemind-dht &> tmp.out &
59
+
60
+ sleep 3
61
+ INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
62
+ rm tmp.out
63
+ echo "Initial peer: ${INITIAL_PEER}"
64
+
65
+
66
+ ##############################
67
+ # Initialize the config file #
68
+ ##############################
69
+
70
+ typeset -A cfg
71
+ cfg=( # set default values in config array
72
+ [name]=""
73
+ [device]="cpu"
74
+ [block_ids]="1:2"
75
+ [id_path]="server.id"
76
+ [maddr]="/ip4/0.0.0.0/tcp/30000"
77
+ )
78
+
79
+ ###############
80
+ # Run servers #
81
+ ###############
82
+
83
+ for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
84
+ do
85
+ ###############
86
+ # Read config #
87
+ ###############
88
+
89
+ while read line
90
+ do
91
+ if echo $line | grep -F = &>/dev/null
92
+ then
93
+ varname=$(echo "$line" | cut -d '=' -f 1)
94
+ cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
95
+ fi
96
+ done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
97
+
98
+ SERVER_NAME="${USERNAME}@${cfg[name]}"
99
+ echo "=== Server #${SERVER_ID} ==="
100
+ echo "Server name ${SERVER_NAME}"
101
+ echo "Server ID: ${cfg[id_path]}"
102
+ echo "Device: ${cfg[device]}"
103
+ echo "Bloom block ids: ${cfg[block_ids]}"
104
+ echo "Host maddr: ${cfg[maddr]}"
105
+ echo "================="
106
+
107
+ ##############
108
+ # Run server #
109
+ ##############
110
+
111
+ ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
112
+ done
cli/run_server.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import configargparse
2
+ from hivemind.proto.runtime_pb2 import CompressionType
3
+ from hivemind.utils.limits import increase_file_limit
4
+ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
5
+
6
+ from src.server.server import Server
7
+
8
+ use_hivemind_log_handler("in_root_logger")
9
+ logger = get_logger(__file__)
10
+
11
+
12
+ def main():
13
+ # fmt:off
14
+ parser = configargparse.ArgParser(default_config_files=["config.yml"])
15
+ parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
16
+
17
+ parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
18
+ help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
19
+ parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
20
+ parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
21
+ parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
22
+ "use the same name as in the converted model.")
23
+ parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
24
+ help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
25
+ parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
26
+ help='Visible multiaddrs the host announces for external connections from other p2p instances')
27
+
28
+ parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
29
+
30
+ parser.add_argument('--num_handlers', type=int, default=None, required=False,
31
+ help='server will use this many processes to handle incoming requests')
32
+ parser.add_argument('--min_batch_size', type=int, default=1,
33
+ help='Minimum required batch size for all expert operations')
34
+ parser.add_argument('--max_batch_size', type=int, default=16384,
35
+ help='The total number of examples in the same batch will not exceed this value')
36
+ parser.add_argument('--cache_size_bytes', type=int, default=None,
37
+ help='The size of memory cache for storing past attention keys/values between inference steps')
38
+ parser.add_argument('--device', type=str, default=None, required=False,
39
+ help='all experts will use this device in torch notation; default: cuda if available else cpu')
40
+ parser.add_argument("--torch_dtype", type=str, default="auto",
41
+ help="Use this dtype to store block weights and do computations. "
42
+ "By default, respect the dtypes in the pre-trained state dict.")
43
+
44
+ parser.add_argument('--update_period', type=float, required=False, default=30,
45
+ help='Server will report experts to DHT once in this many seconds')
46
+ parser.add_argument('--expiration', type=float, required=False, default=None,
47
+ help='DHT entries will expire after this many seconds')
48
+ parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
49
+ help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
50
+ parser.add_argument('--increase_file_limit', action='store_true',
51
+ help='On *nix, this will increase the max number of processes '
52
+ 'a server can spawn before hitting "Too many open files"; Use at your own risk.')
53
+ parser.add_argument('--stats_report_interval', type=int, required=False,
54
+ help='Interval between two reports of batch processing performance statistics')
55
+
56
+ parser.add_argument('--custom_module_path', type=str, required=False,
57
+ help='Path of a file with custom nn.modules, wrapped into special decorator')
58
+ parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
59
+ parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
60
+
61
+ # fmt:on
62
+ args = vars(parser.parse_args())
63
+ args.pop("config", None)
64
+
65
+ if args.pop("increase_file_limit"):
66
+ increase_file_limit()
67
+
68
+ compression_type = args.pop("compression")
69
+ compression = getattr(CompressionType, compression_type)
70
+
71
+ use_auth_token = args.pop("use_auth_token")
72
+ args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
73
+
74
+ server = Server.create(**args, start=True, compression=compression)
75
+
76
+ try:
77
+ server.join()
78
+ except KeyboardInterrupt:
79
+ logger.info("Caught KeyboardInterrupt, shutting down")
80
+ finally:
81
+ server.shutdown()
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
2
+ https://github.com/learning-at-home/hivemind/archive/61e5e8c1f33dd2390e6d0d0221e2de6e75741a9c.zip
3
+ huggingface-hub==0.7.0
4
+ accelerate==0.10.0
src/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from src.bloom import *
2
+ from src.client import *
3
+ from src.dht_utils import declare_active_modules, get_remote_module
4
+
5
+ __version__ = "0.2"
src/bloom/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from src.bloom.block import BloomBlock
2
+ from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
src/bloom/block.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Bloom intermediate layer
3
+ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
4
+ See commit history for authorship.
5
+ """
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.quantized.dynamic.modules.linear
11
+
12
+ from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
13
+ pre_process_alibi_for_pad, split_tensor_along_last_dim)
14
+
15
+
16
+ class BloomAttention(nn.Module):
17
+ def __init__(self, config, layer_number=None):
18
+ super().__init__()
19
+
20
+ self.hidden_size = config.hidden_size
21
+ self.num_heads = config.n_head
22
+ self.head_dim = self.hidden_size // self.num_heads
23
+ self.split_size = self.hidden_size
24
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
25
+ self.masked_softmax_fusion = config.masked_softmax_fusion
26
+ self.hidden_dropout = config.hidden_dropout
27
+
28
+ if self.head_dim * self.num_heads != self.hidden_size:
29
+ raise ValueError(
30
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
31
+ f" {self.num_heads})."
32
+ )
33
+
34
+ # Layer-wise attention scaling
35
+ self.layer_number = max(1, layer_number)
36
+ self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
37
+
38
+ # Scaled Softmax
39
+ self.scale_mask_softmax = BloomScaledSoftmax(
40
+ self.masked_softmax_fusion,
41
+ attention_mask_func,
42
+ self.attention_softmax_in_fp32,
43
+ self.layer_number,
44
+ )
45
+
46
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
47
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
48
+
49
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
50
+
51
+ def forward(
52
+ self,
53
+ hidden_states,
54
+ residual,
55
+ layer_past=None,
56
+ attention_mask=None,
57
+ alibi=None,
58
+ head_mask=None,
59
+ use_cache=False,
60
+ output_attentions=False,
61
+ ):
62
+ if alibi is None:
63
+ current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
64
+ alibi = build_alibi_tensor(
65
+ current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
66
+ )
67
+
68
+ # hidden_states: [batch_size, seq_length, hidden_size]
69
+ # apply preprocessing if the input is padded
70
+ if attention_mask is not None:
71
+ alibi = pre_process_alibi_for_pad(alibi, attention_mask)
72
+ # otherwise repeat alibi tensor with the batch size
73
+ else:
74
+ alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
75
+
76
+ mixed_x_layer = self.query_key_value(hidden_states)
77
+
78
+ # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
79
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
80
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
81
+
82
+ # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
83
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
84
+
85
+ if layer_past is not None:
86
+ past_key, past_value = layer_past
87
+ key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
88
+ value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
89
+
90
+ if use_cache is True:
91
+ present = (key_layer, value_layer)
92
+ else:
93
+ present = None
94
+
95
+ # [batch_size, head_dim, q_length, k_length]
96
+ output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
97
+
98
+ # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
99
+ query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
100
+
101
+ # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
102
+ key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
103
+
104
+ # Raw attention scores. [batch_size * num_heads, q_length, k_length]
105
+ beta = 1.0 / self.layer_number
106
+
107
+ matmul_result = torch.baddbmm(
108
+ alibi,
109
+ query_layer.transpose(1, 0),
110
+ key_layer.transpose(1, 0).transpose(1, 2),
111
+ beta=beta,
112
+ alpha=(1.0 / self.norm_factor),
113
+ )
114
+
115
+ # change view to [batch_size, num_heads, q_length, k_length]
116
+ attention_scores = matmul_result.view(*output_size)
117
+
118
+ # attention scores and attention mask [b, np, sq, sk]
119
+ max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
120
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
121
+ attention_probs = self.attention_dropout(attention_probs)
122
+
123
+ if head_mask is not None:
124
+ attention_probs = attention_probs * head_mask
125
+
126
+ # context layer shape: [batch_size, num_heads, q_length, head_dim]
127
+ output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
128
+
129
+ # change view [k_length, batch_size x num_heads, head_dim]
130
+ value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
131
+
132
+ # change view [batch_size x num_heads, q_length, k_length]
133
+ attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
134
+
135
+ # matmul: [batch_size * num_heads, q_length, head_dim]
136
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
137
+
138
+ # change view [batch_size, num_heads, q_length, head_dim]
139
+ context_layer = context_layer.view(*output_size)
140
+
141
+ # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
142
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
143
+
144
+ # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
145
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
146
+
147
+ context_layer = context_layer.view(*new_context_layer_shape)
148
+
149
+ # Output. [q_length, batch_size, hidden_size]
150
+
151
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
152
+ output_tensor = self.dense(context_layer)
153
+ output = output_tensor.transpose(1, 0)
154
+
155
+ output = dropout_add(output, residual, self.hidden_dropout, self.training)
156
+
157
+ outputs = (output, present)
158
+ if output_attentions:
159
+ outputs += (attention_probs,)
160
+
161
+ return outputs
162
+
163
+
164
+ class BloomMLP(nn.Module):
165
+ def __init__(self, config):
166
+ super().__init__()
167
+ self.hidden_size = config.hidden_size
168
+ self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
169
+ self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
170
+ self.hidden_dropout = config.hidden_dropout
171
+ self.gelu_impl = BloomGelu()
172
+
173
+ def forward(self, hidden_states, residual):
174
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
175
+ intermediate_output = self.dense_4h_to_h(hidden_states)
176
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
177
+ return output
178
+
179
+
180
+ class BloomBlock(nn.Module):
181
+ def __init__(self, config, layer_number=None):
182
+ super().__init__()
183
+ self.hidden_size = config.hidden_size
184
+
185
+ self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
186
+ self.n_head = config.n_head
187
+ self.self_attention = BloomAttention(config, layer_number=layer_number)
188
+ self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
189
+
190
+ self.mlp = BloomMLP(config)
191
+
192
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
193
+ self.hidden_dropout = config.hidden_dropout
194
+
195
+ def forward(
196
+ self,
197
+ hidden_states,
198
+ layer_past=None,
199
+ attention_mask=None,
200
+ head_mask=None,
201
+ use_cache=False,
202
+ output_attentions=False,
203
+ alibi=None,
204
+ ):
205
+ # hidden_states: [batch_size, seq_length, hidden_size]
206
+
207
+ # Layer norm at the beginning of the transformer layer.
208
+ layernorm_output = self.input_layernorm(hidden_states)
209
+
210
+ # Layer norm post the self attention.
211
+ if self.apply_residual_connection_post_layernorm:
212
+ residual = layernorm_output
213
+ else:
214
+ residual = hidden_states
215
+
216
+ # Self attention.
217
+ attn_outputs = self.self_attention(
218
+ layernorm_output,
219
+ residual,
220
+ layer_past=layer_past,
221
+ attention_mask=attention_mask,
222
+ alibi=alibi,
223
+ head_mask=head_mask,
224
+ use_cache=use_cache,
225
+ output_attentions=output_attentions,
226
+ )
227
+
228
+ attention_output = attn_outputs[0]
229
+
230
+ outputs = attn_outputs[1:]
231
+
232
+ layernorm_output = self.post_attention_layernorm(attention_output)
233
+
234
+ # Get residual
235
+ if self.apply_residual_connection_post_layernorm:
236
+ residual = layernorm_output
237
+ else:
238
+ residual = attention_output
239
+
240
+ # MLP.
241
+ output = self.mlp(layernorm_output, residual)
242
+
243
+ if use_cache:
244
+ outputs = (output,) + outputs
245
+ else:
246
+ outputs = (output,) + outputs[1:]
247
+
248
+ return outputs # hidden_states, present, attentions
src/bloom/from_pretrained.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
3
+ If necessary, one can rewrite this to implement a different behavior, such as:
4
+ - loading files from a local data source (e.g. S3)
5
+ - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
6
+ - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
7
+
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from typing import Optional, OrderedDict, Union
12
+
13
+ import torch
14
+ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
15
+ from transformers.modeling_utils import WEIGHTS_NAME
16
+ from transformers.utils.hub import cached_path, hf_bucket_url
17
+
18
+ from src.bloom import BloomBlock, BloomConfig
19
+
20
+ use_hivemind_log_handler("in_root_logger")
21
+ logger = get_logger(__file__)
22
+
23
+ CLIENT_BRANCH = "main"
24
+ BLOCK_BRANCH_PREFIX = "block_"
25
+ USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
26
+ FORCE_DOWNLOAD = False
27
+ RESUME_DOWNLOAD = False
28
+ LOCAL_FILES_ONLY = False
29
+
30
+
31
+ def load_pretrained_block(
32
+ converted_model_name_or_path: str,
33
+ block_index: int,
34
+ config: Optional[BloomConfig] = None,
35
+ torch_dtype: Union[torch.dtype, str] = "auto",
36
+ use_auth_token: Optional[str] = None,
37
+ ) -> BloomBlock:
38
+ """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
39
+ if config is None:
40
+ config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
41
+ block = BloomBlock(config, layer_number=block_index)
42
+ state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
43
+ block.load_state_dict(state_dict)
44
+
45
+ if torch_dtype == "auto":
46
+ with torch.no_grad():
47
+ for name, param in block.named_parameters():
48
+ assert name in state_dict, f"{name} not in state dict"
49
+ param.data = param.data.to(state_dict[name].dtype)
50
+ else:
51
+ assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
52
+ block = block.to(dtype=torch_dtype)
53
+
54
+ report = block.load_state_dict(state_dict, strict=True)
55
+ logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
56
+ return block
57
+
58
+
59
+ def _load_state_dict(
60
+ pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
61
+ ) -> OrderedDict[str, torch.Tensor]:
62
+ revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
63
+ archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
64
+
65
+ # Load from URL or cache if already cached
66
+ resolved_archive_file = cached_path(
67
+ archive_file,
68
+ cache_dir=None,
69
+ force_download=FORCE_DOWNLOAD,
70
+ proxies=None,
71
+ resume_download=RESUME_DOWNLOAD,
72
+ local_files_only=LOCAL_FILES_ONLY,
73
+ use_auth_token=use_auth_token,
74
+ user_agent=USER_AGENT,
75
+ )
76
+ state_dict = torch.load(resolved_archive_file, map_location="cpu")
77
+ return state_dict
78
+
79
+
80
+ DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
src/bloom/model.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch BLOOM model that implements several memory-efficient modes.
3
+ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
4
+ See commit history for authorship.
5
+ """
6
+ from typing import Tuple
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from hivemind import use_hivemind_log_handler
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss, LayerNorm
14
+ from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
15
+ add_start_docstrings_to_model_forward)
16
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.models.bloom.configuration_bloom import BloomConfig
19
+ from transformers.utils import logging
20
+
21
+ from src.bloom.block import BloomBlock
22
+
23
+ use_hivemind_log_handler("in_root_logger")
24
+ logger = logging.get_logger(__file__)
25
+
26
+ _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
27
+ _CONFIG_FOR_DOC = "BloomConfig"
28
+ _TOKENIZER_FOR_DOC = "BloomTokenizer"
29
+
30
+
31
+ class BloomPreTrainedModel(PreTrainedModel):
32
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
33
+ """
34
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
35
+ models.
36
+ """
37
+
38
+ config_class = BloomConfig
39
+ base_model_prefix = "transformer"
40
+ supports_gradient_checkpointing = True
41
+ _no_split_modules = ["BloomBlock"]
42
+
43
+ def __init__(self, *inputs, **kwargs):
44
+ super().__init__(*inputs, **kwargs)
45
+
46
+ def _init_weights(self, module):
47
+ """Initialize the weights."""
48
+ if isinstance(module, (nn.Linear)):
49
+ # Slightly different from the TF version which uses truncated_normal for initialization
50
+ # cf https://github.com/pytorch/pytorch/pull/5617
51
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
52
+ if module.bias is not None:
53
+ module.bias.data.zero_()
54
+ elif isinstance(module, nn.Embedding):
55
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
56
+ if module.padding_idx is not None:
57
+ module.weight.data[module.padding_idx].zero_()
58
+ elif isinstance(module, LayerNorm):
59
+ module.bias.data.zero_()
60
+ module.weight.data.fill_(1.0)
61
+
62
+ def _set_gradient_checkpointing(self, module, value=False):
63
+ if isinstance(module, BloomModel):
64
+ module.gradient_checkpointing = value
65
+
66
+
67
+ BLOOM_START_DOCSTRING = r"""
68
+
69
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
70
+ library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
71
+
72
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
73
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
74
+ and behavior.
75
+
76
+ Parameters:
77
+ config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
78
+ Initializing with a config file does not load the weights associated with the model, only the
79
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
80
+ """
81
+
82
+ BLOOM_INPUTS_DOCSTRING = r"""
83
+ Args:
84
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
85
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
86
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
87
+ sequence tokens in the vocabulary.
88
+
89
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
90
+ `input_ids`.
91
+
92
+ Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
93
+ [`PreTrainedTokenizer.__call__`] for details.
94
+
95
+ [What are input IDs?](../glossary#input-ids)
96
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
97
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
98
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
99
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
100
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
101
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
102
+
103
+ - 1 for tokens that are **not masked**,
104
+ - 0 for tokens that are **masked**.
105
+
106
+ [What are attention masks?](../glossary#attention-mask)
107
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
108
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
109
+ config.max_position_embeddings - 1]`.
110
+
111
+ [What are position IDs?](../glossary#position-ids)
112
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
113
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
114
+
115
+ - 1 indicates the head is **not masked**,
116
+ - 0 indicates the head is **masked**.
117
+
118
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
119
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
120
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
121
+ model's internal embedding lookup matrix.
122
+
123
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
124
+ `past_key_values`).
125
+ use_cache (`bool`, *optional*):
126
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
127
+ `past_key_values`).
128
+ output_attentions (`bool`, *optional*):
129
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
130
+ tensors for more detail.
131
+ output_hidden_states (`bool`, *optional*):
132
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
133
+ more detail.
134
+ return_dict (`bool`, *optional*):
135
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
136
+ """
137
+
138
+
139
+ @add_start_docstrings(
140
+ "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
141
+ BLOOM_START_DOCSTRING,
142
+ )
143
+ class BloomModel(BloomPreTrainedModel):
144
+ def __init__(self, config):
145
+ super().__init__(config)
146
+ assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
147
+
148
+ self.embed_dim = config.hidden_size
149
+ self.n_head = config.n_head
150
+
151
+ # Embedding + LN Embedding
152
+
153
+ # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
154
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
155
+ self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
156
+
157
+ # Transformer blocks
158
+ self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
159
+
160
+ # Final Layer Norm
161
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
162
+
163
+ self.gradient_checkpointing = False
164
+
165
+ # Initialize weights and apply final processing
166
+ self.post_init()
167
+
168
+ # Forbid accumulate grads for embeddings and layernorm
169
+ self.set_requires_grad(False)
170
+
171
+ def get_input_embeddings(self):
172
+ return self.word_embeddings
173
+
174
+ def set_input_embeddings(self, new_embeddings):
175
+ self.word_embeddings = new_embeddings
176
+
177
+ def set_requires_grad(self, value):
178
+ for p in self.parameters():
179
+ p.requires_grad = value
180
+
181
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
182
+ @add_code_sample_docstrings(
183
+ processor_class=_TOKENIZER_FOR_DOC,
184
+ checkpoint=_CHECKPOINT_FOR_DOC,
185
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
186
+ config_class=_CONFIG_FOR_DOC,
187
+ )
188
+ def forward(
189
+ self,
190
+ input_ids=None,
191
+ past_key_values=None,
192
+ attention_mask=None,
193
+ position_ids=None,
194
+ head_mask=None,
195
+ inputs_embeds=None,
196
+ use_cache=None,
197
+ output_attentions=None,
198
+ output_hidden_states=None,
199
+ return_dict=None,
200
+ ):
201
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
202
+ output_hidden_states = (
203
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
204
+ )
205
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
206
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
207
+
208
+ if input_ids is not None and inputs_embeds is not None:
209
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
210
+ if position_ids is not None:
211
+ logger.warning("position_ids are ignored in this bloom implementation")
212
+ elif input_ids is not None:
213
+ input_shape = input_ids.size()
214
+ input_ids = input_ids.view(-1, input_shape[-1])
215
+ elif inputs_embeds is not None:
216
+ input_shape = inputs_embeds.size()[:-1]
217
+ else:
218
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
219
+
220
+ if past_key_values is None:
221
+ past_key_values = tuple([None] * len(self.h))
222
+
223
+ # Prepare head mask if needed
224
+ # 1.0 in head_mask indicate we keep the head
225
+ # attention_probs has shape bsz x n_head x N x N
226
+ # head_mask has shape n_layer x batch x n_head x N x N
227
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
228
+
229
+ if inputs_embeds is None:
230
+ inputs_embeds = self.word_embeddings(input_ids)
231
+
232
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
233
+
234
+ output_shape = input_shape + (hidden_states.size(-1),)
235
+
236
+ presents = () if use_cache else None
237
+ all_self_attentions = () if output_attentions else None
238
+ all_hidden_states = () if output_hidden_states else None
239
+
240
+ # Compute alibi tensor: check build_alibi_tensor documentation
241
+ current_sequence_length = hidden_states.shape[1]
242
+ if past_key_values and past_key_values[0]:
243
+ current_sequence_length += past_key_values[0][0].shape[1]
244
+
245
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
246
+
247
+ if output_hidden_states:
248
+ all_hidden_states = all_hidden_states + (hidden_states,)
249
+
250
+ if self.gradient_checkpointing and self.training:
251
+
252
+ if use_cache:
253
+ logger.warning(
254
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
255
+ )
256
+ use_cache = False
257
+
258
+ def create_custom_forward(module):
259
+ def custom_forward(*inputs):
260
+ # None for past_key_value
261
+ return module(*inputs, use_cache, output_attentions, alibi=None)
262
+
263
+ return custom_forward
264
+
265
+ outputs = torch.utils.checkpoint.checkpoint(
266
+ create_custom_forward(block),
267
+ hidden_states,
268
+ None,
269
+ attention_mask,
270
+ head_mask[i],
271
+ )
272
+ else:
273
+ outputs = block(
274
+ hidden_states,
275
+ layer_past=layer_past,
276
+ attention_mask=attention_mask,
277
+ head_mask=head_mask[i],
278
+ use_cache=use_cache,
279
+ output_attentions=output_attentions,
280
+ alibi=None,
281
+ )
282
+
283
+ hidden_states = outputs[0]
284
+ if use_cache is True:
285
+ presents = presents + (outputs[1],)
286
+
287
+ if output_attentions:
288
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
289
+
290
+ # Add last hidden state
291
+ hidden_states = self.ln_f(hidden_states)
292
+
293
+ if output_hidden_states:
294
+ all_hidden_states = all_hidden_states + (hidden_states,)
295
+
296
+ hidden_states = hidden_states.view(output_shape)
297
+
298
+ if not return_dict:
299
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
300
+
301
+ return BaseModelOutputWithPastAndCrossAttentions(
302
+ last_hidden_state=hidden_states,
303
+ past_key_values=presents,
304
+ hidden_states=all_hidden_states,
305
+ attentions=all_self_attentions,
306
+ )
307
+
308
+
309
+ @add_start_docstrings(
310
+ """
311
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
312
+ embeddings).
313
+ """,
314
+ BLOOM_START_DOCSTRING,
315
+ )
316
+ class BloomForCausalLM(BloomPreTrainedModel):
317
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
318
+
319
+ def __init__(self, config):
320
+ super().__init__(config)
321
+ self.transformer = BloomModel(config)
322
+ # Initialize weights and apply final processing
323
+ self.post_init()
324
+
325
+ def get_output_embeddings(self):
326
+ return self.transformer.word_embeddings
327
+
328
+ def set_output_embeddings(self, new_embeddings):
329
+ self.transformer.word_embeddings.weight = new_embeddings.weight
330
+
331
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
332
+ # only last token for inputs_ids if past is defined in kwargs
333
+ if past:
334
+ input_ids = input_ids[:, -1].unsqueeze(-1)
335
+
336
+ attention_mask = kwargs.get("attention_mask", None)
337
+ position_ids = kwargs.get("position_ids", None)
338
+
339
+ if attention_mask is not None and position_ids is None:
340
+ # create position_ids on the fly for batch generation
341
+ position_ids = attention_mask.long().cumsum(-1) - 1
342
+ position_ids.masked_fill_(attention_mask == 0, 1)
343
+ if past:
344
+ position_ids = position_ids[:, -1].unsqueeze(-1)
345
+ else:
346
+ position_ids = None
347
+ return {
348
+ "input_ids": input_ids,
349
+ "past_key_values": past,
350
+ "use_cache": kwargs.get("use_cache"),
351
+ "position_ids": position_ids,
352
+ "attention_mask": attention_mask,
353
+ }
354
+
355
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
356
+ @add_code_sample_docstrings(
357
+ processor_class=_TOKENIZER_FOR_DOC,
358
+ checkpoint=_CHECKPOINT_FOR_DOC,
359
+ output_type=CausalLMOutputWithCrossAttentions,
360
+ config_class=_CONFIG_FOR_DOC,
361
+ )
362
+ def forward(self, input_ids=None, labels=None, return_dict=None, **kwargs):
363
+ r"""
364
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
365
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
366
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
367
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
368
+ """
369
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
370
+ transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
371
+ word_embeddings = self.transformer.word_embeddings.weight
372
+
373
+ # Switch dtype in case word_embeddings are fp16/bf16
374
+ hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
375
+ lm_logits = F.linear(hidden_states, word_embeddings).float()
376
+
377
+ loss = None
378
+ if labels is not None:
379
+ # Shift so that tokens < n predict n
380
+ shift_logits = lm_logits[..., :-1, :].contiguous()
381
+ shift_labels = labels[..., 1:].contiguous()
382
+ # Flatten the tokens
383
+ loss_fct = CrossEntropyLoss()
384
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
385
+
386
+ if not return_dict:
387
+ output = (lm_logits,) + transformer_outputs[1:]
388
+ return ((loss,) + output) if loss is not None else output
389
+
390
+ return CausalLMOutputWithCrossAttentions(
391
+ loss=loss,
392
+ logits=lm_logits,
393
+ past_key_values=transformer_outputs.past_key_values,
394
+ hidden_states=transformer_outputs.hidden_states,
395
+ attentions=transformer_outputs.attentions,
396
+ )
397
+
398
+ @staticmethod
399
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
400
+ """
401
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
402
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
403
+ beam_idx at every generation step.
404
+ """
405
+ return tuple(
406
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
407
+ for layer_past in past
408
+ )
src/bloom/ops.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility operations used in the the BLOOM model
3
+ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
4
+ See commit history for authorship.
5
+ """
6
+ import math
7
+
8
+ import torch
9
+ import torch.autograd
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+
14
+ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
15
+ """Split a tensor along its last dimension.
16
+
17
+ Args:
18
+ tensor: ([`torch.tensor`], *required*):
19
+ input tensor to split
20
+ num_partitions ([`int`], *required*):
21
+ number of partitions to split the tensor
22
+ contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
23
+ If True, make each chunk contiguous in memory.
24
+ """
25
+ # Get the size and dimension.
26
+ last_dim = tensor.dim() - 1
27
+ numerator, denominator = tensor.size()[last_dim], num_partitions
28
+ if not (numerator % denominator == 0):
29
+ raise ValueError(f"{numerator} is not divisible by {denominator}")
30
+ last_dim_size = numerator // denominator
31
+ # Split.
32
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
33
+ # Note: torch.split does not create contiguous tensors by default.
34
+ if contiguous_split_chunks:
35
+ return tuple(chunk.contiguous() for chunk in tensor_list)
36
+
37
+ return tensor_list
38
+
39
+
40
+ def attention_mask_func(attention_scores, attention_mask, causal_mask):
41
+ if attention_mask.dtype == torch.bool:
42
+ attention_mask_bool = ~attention_mask
43
+ else:
44
+ attention_mask_bool = (1 - attention_mask).bool()
45
+
46
+ query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
47
+ padded_causal_mask = (
48
+ attention_mask_bool[:, None, key_length - query_length : key_length, None]
49
+ + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
50
+ ).bool()
51
+ padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
52
+ # Make use of floats
53
+ return (
54
+ attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
55
+ padded_causal_mask,
56
+ )
57
+
58
+
59
+ def build_alibi_tensor(
60
+ max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
61
+ ) -> torch.Tensor:
62
+ """
63
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
64
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
65
+ `softmax(l+a) = softmax(l)`. Based on
66
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
67
+ Args:
68
+ Returns tensor shaped (n_head, 1, max_seq_len)
69
+ max_seq_len: (`int`, *required*):
70
+ max sequence length
71
+ n_head: (`int`, *required*):
72
+ number of heads
73
+ dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
74
+ dtype of the output tensor
75
+ device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
76
+ device of the output alibi tensor
77
+ """
78
+ closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
79
+ base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
80
+ powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
81
+ slopes = torch.pow(base, powers)
82
+
83
+ if closest_power_of_2 != n_head:
84
+ extra_base = torch.tensor(
85
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
86
+ )
87
+ num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
88
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
89
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
90
+
91
+ lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
92
+ return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
93
+
94
+
95
+ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
96
+ """
97
+ Args:
98
+ Pre-process the alibi tensor for padding.
99
+ alibi: ([`torch.tensor`], *required*):
100
+ alibi tensor to pre-process
101
+ attention_mask: ([`torch.tensor`], *required*):
102
+ attention mask to pre-process
103
+ """
104
+ assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
105
+ unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
106
+ # ^-- [batch, max_len], values correspond to element indices after removing padding
107
+ # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
108
+ alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
109
+ return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
110
+
111
+
112
+ def dropout_add(x, residual, prob, training):
113
+ """
114
+ Dropout add function
115
+
116
+ Args:
117
+ x (`torch.tensor`, *required*):
118
+ input tensor
119
+ residual (`torch.tensor`, *rquired*):
120
+ esidual tensor
121
+ prob (`float`, *required*):
122
+ dropout probability
123
+ training (`bool`, *required*):
124
+ training mode
125
+ """
126
+ out = nn.functional.dropout(x, p=prob, training=training)
127
+ out = residual + out
128
+ return out
129
+
130
+
131
+ def bloom_gelu_forward(x):
132
+ """
133
+ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
134
+ make the model jitable.
135
+
136
+ Args:
137
+ x (`torch.tensor`, *required*):
138
+ input hidden states
139
+ """
140
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
141
+
142
+
143
+ def bloom_gelu_back(g, x):
144
+ """
145
+ gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
146
+ 0.3989423 * x * torch.exp(-0.5 * x * x)
147
+
148
+ Args:
149
+ g (`torch.tensor`, *required*):
150
+ gradient output tensor
151
+ x (`torch.tensor`, *required*):
152
+ input tensor
153
+ """
154
+ x = x[0] # x is a tuple of 1 element, needs to unpack it first
155
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
156
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
157
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
158
+ return ff * g
159
+
160
+
161
+ class GeLUFunction(torch.autograd.Function):
162
+ @staticmethod
163
+ def forward(ctx, input):
164
+ ctx.save_for_backward(input)
165
+ return bloom_gelu_forward(input)
166
+
167
+ @staticmethod
168
+ def backward(ctx, grad_output):
169
+ input = ctx.saved_tensors
170
+ tmp = bloom_gelu_back(grad_output, input)
171
+ return tmp
172
+
173
+
174
+ class BloomGelu(nn.Module):
175
+ """
176
+ BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
177
+ torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
178
+ copied from Megatron-DeepSpeed code and adapted for our needs
179
+
180
+ See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
181
+
182
+ """
183
+
184
+ def __init__(self):
185
+ super().__init__()
186
+
187
+ def forward(self, x):
188
+ if self.training:
189
+ return GeLUFunction.apply(x)
190
+ else:
191
+ return bloom_gelu_forward(x)
192
+
193
+
194
+ class BloomScaledSoftmax(nn.Module):
195
+ """
196
+ fused operation: scaling + mask + softmax
197
+
198
+ Args:
199
+ input_in_fp16 (`bool`, *required*):
200
+ flag to indicate if input in fp16 data format.
201
+ input_in_bf16 (`bool`, *required*):
202
+ flag to indicate if input in bf16 data format.
203
+ scaled_masked_softmax_fusion (`bool`, *required*):
204
+ flag to indicate user want to use softmax fusion
205
+ mask_func (`function`, *required*):
206
+ mask function to be applied.
207
+ softmax_in_fp32 (`bool`, *required*):
208
+ if true, softmax in performed at fp32 precision.
209
+ scale (`float`, *required*):
210
+ scaling factor used in input tensor scaling.
211
+ """
212
+
213
+ def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
214
+ super().__init__()
215
+ self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
216
+ self.mask_func = mask_func
217
+ self.softmax_in_fp32 = softmax_in_fp32
218
+ self.scale = scale
219
+
220
+ if not (self.scale is None or softmax_in_fp32):
221
+ raise ValueError("softmax should be in fp32 when scaled")
222
+
223
+ def forward(self, input, mask, max_positions):
224
+ input_dtype = input.dtype
225
+ input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
226
+ softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
227
+
228
+ if self.scale is not None:
229
+ input = input * self.scale
230
+
231
+ if mask is None:
232
+ mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
233
+
234
+ mask = mask.to(input.device)
235
+ causal_mask = (
236
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
237
+ .view(1, 1, max_positions, max_positions)
238
+ .to(input.device)
239
+ )
240
+ mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
241
+ probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
242
+
243
+ if input_in_16bit and self.softmax_in_fp32:
244
+ probs = probs.to(dtype=input_dtype)
245
+
246
+ return probs
src/client/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
2
+ from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
3
+ from src.client.remote_sequence_info import RemoteSequenceInfo
4
+ from src.client.remote_sequential import RemoteSequential
src/client/remote_block.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ import random
6
+ from typing import Any, AsyncIterator, Dict, Optional
7
+
8
+ import torch
9
+ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
10
+ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
11
+ from hivemind.moe.expert_uid import ExpertInfo
12
+ from hivemind.p2p import P2P, StubBase
13
+ from hivemind.proto import runtime_pb2
14
+ from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
15
+
16
+ from src.data_structures import RemoteModuleInfo
17
+ from src.dht_utils import ModuleUID
18
+ from src.server.handler import TransformerConnectionHandler
19
+
20
+ use_hivemind_log_handler("in_root_logger")
21
+ logger = get_logger(__file__)
22
+
23
+
24
+ class RemoteTransformerBlock(RemoteExpert):
25
+ """A class that interacts with a remote module on a specific server for forward/backward or inference"""
26
+
27
+ def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
28
+ peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids))) # TODO replace this
29
+ super().__init__(peer_info, p2p)
30
+
31
+ @property
32
+ def stub(self) -> StubBase:
33
+ return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
34
+
35
+ def forward(self, inputs: torch.Tensor, **kwargs):
36
+ for k, v in kwargs.items():
37
+ assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
38
+ return super().forward(inputs)
39
+
40
+ def inference_session(self) -> RemoteTransformerBlockInferenceSession:
41
+ """Initialize a new inference session with the specified remote server"""
42
+ _ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
43
+ return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
44
+
45
+ def begin_inference_session(self):
46
+ logger.warning("beging_inference_session was renamed to just inference_session")
47
+ return self.inference_session()
48
+
49
+
50
+ class RemoteTransformerBlockInferenceSession:
51
+ """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
52
+
53
+ def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
54
+ self.uid, self.info = uid, info
55
+ # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
56
+ # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
57
+ self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
58
+ self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
59
+ self.stepped = False
60
+ self.closed = False
61
+
62
+ @classmethod
63
+ async def _create(
64
+ cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
65
+ ) -> RemoteTransformerBlockInferenceSession:
66
+ """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
67
+ inputs_queue = asyncio.Queue()
68
+ outputs_stream = await remote_module.stub.rpc_inference(
69
+ cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
70
+ )
71
+ return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
72
+
73
+ @staticmethod
74
+ async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
75
+ while True:
76
+ next_input_message = await asyncio.wait_for(queue.get(), timeout)
77
+ yield next_input_message
78
+ if not next_input_message.uid and not next_input_message.tensors:
79
+ break # this message means "done sending"
80
+
81
+ def step(self, new_hidden_states: torch.Tensor):
82
+ """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
83
+ if self.closed:
84
+ raise Exception("Session is closed, cannot perform step")
85
+ # serialize inputs and put them into the queue
86
+ inputs = (new_hidden_states,)
87
+ outputs_serialized = RemoteExpertWorker.run_coroutine(
88
+ self._step(
89
+ runtime_pb2.ExpertRequest(
90
+ uid=self.uid,
91
+ tensors=[
92
+ serialize_torch_tensor(tensor, proto.compression)
93
+ for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
94
+ ],
95
+ )
96
+ )
97
+ )
98
+ outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
99
+ assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
100
+ return outputs[0]
101
+
102
+ async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
103
+ """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
104
+ await self._inputs_queue.put(inputs_serialized)
105
+ self.stepped = True
106
+ return await anext(self._outputs_stream)
107
+
108
+ def close(self):
109
+ """Finish a given inference session, close the underlying connection"""
110
+ if self._outputs_stream is None:
111
+ return # already closed
112
+ RemoteExpertWorker.run_coroutine(self._aclose_stream())
113
+ self._outputs_stream = self._inputs_queue = None
114
+ self.closed = True
115
+
116
+ async def _aclose_stream(self):
117
+ """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
118
+ if self._outputs_stream is None:
119
+ return # already closed
120
+ if self.stepped:
121
+ await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
122
+ try:
123
+ await anext(self._outputs_stream)
124
+ except StopAsyncIteration:
125
+ pass
126
+
127
+ def __del__(self):
128
+ self.close()
129
+
130
+ def __enter__(self):
131
+ assert not self.closed
132
+ return self
133
+
134
+ def __exit__(self, *exc_details):
135
+ self.close()
src/client/remote_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this code is in active development, interfaces may change
2
+ import os
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import hivemind
6
+ from hivemind import DHT, get_logger, use_hivemind_log_handler
7
+
8
+ from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
9
+ from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
10
+ from src.client.remote_sequential import RemoteSequential
11
+ from src.data_structures import UID_DELIMITER
12
+
13
+ use_hivemind_log_handler("in_root_logger")
14
+ logger = get_logger(__file__)
15
+
16
+
17
+ class DistributedBloomConfig(BloomConfig):
18
+ """
19
+ A bloom config that contains information about DHT peers.
20
+ To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
21
+ """
22
+
23
+ initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
24
+ dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
25
+ dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
26
+
27
+
28
+ class DistributedBloomModel(BloomModel):
29
+ """BloomModel, but all transformer layers are hosted by the swarm"""
30
+ config_class = DistributedBloomConfig
31
+
32
+ def __init__(self, config: DistributedBloomConfig):
33
+ assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
34
+ assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
35
+
36
+ n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
37
+ super().__init__(config)
38
+ assert len(self.h) == 0
39
+ config.n_layer = n_layer
40
+
41
+ dht = (
42
+ config.dht
43
+ if config.dht is not None
44
+ else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
45
+ )
46
+ assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
47
+ self.h = RemoteSequential(config, dht, config.dht_prefix)
48
+
49
+
50
+ class DistributedBloomForCausalLM(BloomForCausalLM):
51
+ """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
52
+ config_class = DistributedBloomConfig
53
+
54
+ def __init__(self, config: DistributedBloomConfig):
55
+ BloomPreTrainedModel.__init__(self, config)
56
+ self.transformer = DistributedBloomModel(config)
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
src/client/remote_sequence_info.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import threading
5
+ from functools import partial
6
+ from typing import List, NamedTuple, Optional, Sequence, Tuple
7
+
8
+ from hivemind import DHT, PeerID
9
+ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
10
+
11
+ from src.data_structures import ModuleUID, RemoteModuleInfo
12
+ from src.dht_utils import _get_remote_module_infos
13
+
14
+ use_hivemind_log_handler("in_root_logger")
15
+ logger = get_logger(__file__)
16
+
17
+
18
+ Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
19
+
20
+
21
+ @dataclasses.dataclass(frozen=False, init=False) # TODO[borzunov@] eto ne dataclass
22
+ class RemoteSequenceInfo:
23
+ """Keeps and updates the meta-information about which peers host which blocks"""
24
+
25
+ dht: DHT
26
+ block_uids: List[ModuleUID, ...]
27
+ block_infos: List[Optional[RemoteModuleInfo], ...]
28
+ spans_by_priority: List[Span] # sorted from best to worst
29
+ spans_containing_block: Tuple[List[Span], ...]
30
+ lock_changes: threading.Lock
31
+
32
+ def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
33
+ self.dht = dht
34
+ self.block_uids = list(block_uids)
35
+ self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
36
+ self.spans_by_priority = []
37
+ self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
38
+ self.lock_changes = threading.Lock()
39
+ self.update_()
40
+
41
+ for uid, info in zip(self.block_uids, self.block_infos):
42
+ assert info is not None, f"Found no remote peers for block {uid}"
43
+ assert self.spans_by_priority and self.spans_containing_block
44
+
45
+ def update_(self):
46
+ with self.lock_changes:
47
+ self.update_block_infos_()
48
+ self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
49
+
50
+ def update_block_infos_(self):
51
+ new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
52
+ partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
53
+ )
54
+ assert len(new_block_infos) == len(self.block_uids)
55
+ for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
56
+ if info is None:
57
+ logger.warning(f"Found no block info for block {uid}")
58
+ if not isinstance(info, RemoteModuleInfo):
59
+ logger.warning(f"Unexpected dht entry type for {uid}: {info}")
60
+ if not info.peer_ids:
61
+ logger.warning(f"Found no active peers for block {uid}")
62
+ if info.uid != uid:
63
+ logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
64
+ if not isinstance(info.peer_ids, set):
65
+ logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
66
+ self.block_infos[block_index] = info
67
+
68
+ @staticmethod
69
+ def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
70
+ closed_spans = []
71
+ active_spans = {}
72
+ for block_index, info in enumerate(block_infos):
73
+ for peer_id in info.peer_ids:
74
+ if peer_id not in active_spans:
75
+ active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
76
+ else: # peer_id in active_spans
77
+ active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
78
+
79
+ for peer_id in list(active_spans.keys()):
80
+ if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
81
+ closed_spans.append(active_spans.pop(peer_id))
82
+ assert not active_spans
83
+
84
+ closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
85
+
86
+ spans_containing_block = tuple(list() for _ in range(len(block_infos)))
87
+ for span in closed_spans:
88
+ for block_index in range(span.start, span.end):
89
+ spans_containing_block[block_index].append(span)
90
+
91
+ return closed_spans, spans_containing_block
92
+
93
+ def __len__(self):
94
+ return len(self.block_uids)
src/client/remote_sequential.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import logging
5
+ import random
6
+
7
+ import torch
8
+ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
9
+ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
10
+ from hivemind.moe.expert_uid import ExpertInfo
11
+ from torch import nn
12
+
13
+ import src
14
+ from src.client.remote_block import RemoteTransformerBlock
15
+ from src.client.remote_sequence_info import RemoteSequenceInfo
16
+ from src.data_structures import UID_DELIMITER
17
+ from src.dht_utils import _create_remote_modules_from_infos
18
+
19
+ use_hivemind_log_handler("in_root_logger")
20
+ logger = get_logger(__file__)
21
+
22
+
23
+ class RemoteSequential(nn.Module):
24
+ """
25
+ A sequence of transformer blocks hosted by the swarm.
26
+ """
27
+
28
+ def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
29
+ logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
30
+ if prefix.endswith(UID_DELIMITER):
31
+ logger.warning(
32
+ f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
33
+ f"This will cause {self.__class__.__name__} to look for modules under "
34
+ f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
35
+ )
36
+
37
+ super().__init__()
38
+ self.config = config
39
+ self.dht = dht
40
+ self.prefix = prefix
41
+ self.max_retries = max_retries
42
+ self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
43
+
44
+ block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
45
+
46
+ logger.debug(f"Remote block uids: {block_uids}")
47
+ self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
48
+
49
+ def forward(self, inputs: torch.Tensor):
50
+ assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
51
+ for block_index in range(self.config.n_layer):
52
+ for retry_index in range(self.max_retries):
53
+ try:
54
+ block = self[block_index]
55
+ (outputs,) = block(inputs)
56
+ assert isinstance(outputs, torch.Tensor)
57
+ assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
58
+ inputs = outputs
59
+ break
60
+ except Exception as e:
61
+ if retry_index == self.max_retries - 1:
62
+ raise e
63
+ else:
64
+ logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
65
+ return inputs
66
+
67
+ def __getitem__(self, block_index: int):
68
+ assert 0 <= block_index < self.config.n_layer
69
+ (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
70
+ return module
71
+
72
+ def __iter__(self):
73
+ for block_index in range(self.config.n_layer):
74
+ yield self[block_index]
75
+
76
+ def __len__(self):
77
+ return len(self.remote_sequence_info)
78
+
79
+ def inference_session(self) -> RemoteSequentialInferenceSession:
80
+ self.remote_sequence_info.update_()
81
+ return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
82
+
83
+
84
+ class RemoteSequentialInferenceSession:
85
+ """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
86
+
87
+ def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
88
+ self.remote_sequence_info = remote_sequence_info
89
+ self.p2p = p2p
90
+ self.closed = False
91
+ self.stack = contextlib.ExitStack()
92
+ self.active_sessions = []
93
+
94
+ def __enter__(self):
95
+ assert not self.closed
96
+ self.stack.__enter__()
97
+ # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
98
+ current_block = 0
99
+ while current_block != len(self.remote_sequence_info):
100
+ candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
101
+ chosen_span = random.choice(candidate_spans) # TODO this is a temporary code
102
+ assert chosen_span.start <= current_block < chosen_span.end
103
+
104
+ # TODO begin throwaway prototype code
105
+ remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
106
+ _ = remote.info # TODO fix
107
+ span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
108
+ remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
109
+ self.active_sessions.append(remote.inference_session())
110
+ self.stack.enter_context(self.active_sessions[-1])
111
+ current_block = chosen_span.end
112
+ # TODO end throwaway prototype code
113
+
114
+ return self
115
+
116
+ def step(self, inputs: torch.Tensor):
117
+ assert not self.closed
118
+ for session in self.active_sessions:
119
+ outputs = session.step(inputs)
120
+ assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
121
+ inputs = outputs
122
+ return inputs
123
+
124
+ def close(self, *exc_details):
125
+ """Finish a given inference session, close the underlying connection"""
126
+ if not self.closed:
127
+ self.stack.__exit__(*exc_details or (None, None, None))
128
+ self.active_sessions.clear()
129
+ self.closed = True
130
+
131
+ def __exit__(self, *exc_details):
132
+ self.close(*exc_details)
133
+
134
+ def __del__(self):
135
+ self.close()
src/data_structures.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from typing import Collection, NamedTuple
2
+
3
+ from hivemind import PeerID
4
+
5
+ ModuleUID = str
6
+ UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
7
+ CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
8
+ RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
src/dht_utils.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for declaring and retrieving active model layers using a shared DHT.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from functools import partial
7
+ from typing import Dict, List, Optional, Sequence, Union
8
+
9
+ from hivemind.dht import DHT, DHTNode, DHTValue
10
+ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
11
+ from hivemind.p2p import P2P, PeerID
12
+ from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
13
+
14
+ import src
15
+ from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
16
+
17
+ use_hivemind_log_handler("in_root_logger")
18
+ logger = get_logger(__file__)
19
+
20
+
21
+ def declare_active_modules(
22
+ dht: DHT,
23
+ uids: Sequence[ModuleUID],
24
+ expiration_time: DHTExpiration,
25
+ throughput: Optional[float] = None,
26
+ wait: bool = True,
27
+ ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
28
+ """
29
+ Declare that your node serves the specified modules; update timestamps if declared previously
30
+
31
+ :param uids: a list of module ids to declare
32
+ :param wait: if True, awaits for declaration to finish, otherwise runs in background
33
+ :param throughput: optionally specify your performance in terms of compute throughput
34
+ :param expiration_time: declated modules will be visible for this many seconds
35
+ :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
36
+ """
37
+ if isinstance(uids, str):
38
+ uids = [uids]
39
+ if not isinstance(uids, list):
40
+ uids = list(uids)
41
+ for uid in uids:
42
+ assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
43
+ return dht.run_coroutine(
44
+ partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
45
+ return_future=not wait,
46
+ )
47
+
48
+
49
+ async def _declare_active_modules(
50
+ dht: DHT,
51
+ node: DHTNode,
52
+ uids: List[ModuleUID],
53
+ expiration_time: DHTExpiration,
54
+ throughput: Optional[float] = None,
55
+ ) -> Dict[ModuleUID, bool]:
56
+ num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
57
+ return await node.store_many(
58
+ keys=uids,
59
+ subkeys=[dht.peer_id.to_base58()] * len(uids),
60
+ values=[throughput] * len(uids),
61
+ expiration_time=expiration_time,
62
+ num_workers=num_workers,
63
+ )
64
+
65
+
66
+ def get_remote_module(
67
+ dht: DHT,
68
+ uid_or_uids: Union[ModuleUID, List[ModuleUID]],
69
+ expiration_time: Optional[DHTExpiration] = None,
70
+ return_future: bool = False,
71
+ ) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
72
+ """
73
+ :param uid_or_uids: find one or more modules with these ids from across the DHT
74
+ :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
75
+ :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
76
+ :returns: a list of [RemoteTransformerBlock if found else None]
77
+ """
78
+ single_uid = isinstance(uid_or_uids, ModuleUID)
79
+ uids = [uid_or_uids] if single_uid else uid_or_uids
80
+ infos = dht.run_coroutine(
81
+ partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
82
+ )
83
+
84
+ if return_future:
85
+
86
+ async def _unpack(infos_future: MPFuture, dht: DHT):
87
+ p2p = await dht.replicate_p2p()
88
+ modules = _create_remote_modules_from_infos(await infos_future, p2p)
89
+ return modules[0] if single_uid else modules
90
+
91
+ return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
92
+ p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
93
+ modules = _create_remote_modules_from_infos(infos, p2p)
94
+ return modules[0] if single_uid else modules
95
+
96
+
97
+ async def _get_remote_module_infos(
98
+ dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
99
+ ) -> List[Optional[RemoteModuleInfo]]:
100
+ if expiration_time is None:
101
+ expiration_time = get_dht_time()
102
+ num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
103
+ found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
104
+
105
+ modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
106
+ for i, uid in enumerate(uids):
107
+ metadata = found[uid]
108
+ if metadata is None or not isinstance(metadata.value, dict):
109
+ if metadata is not None:
110
+ logger.error(f"Incorrect metadata for {uid}: {metadata}")
111
+ continue
112
+ valid_entries = set()
113
+ for maybe_peer_id, _unused_value in metadata.value.items():
114
+ try:
115
+ valid_entries.add(PeerID.from_base58(maybe_peer_id))
116
+ except:
117
+ logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
118
+ if valid_entries:
119
+ modules[i] = RemoteModuleInfo(uid, valid_entries)
120
+ return modules
121
+
122
+
123
+ def _create_remote_modules_from_infos(
124
+ infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
125
+ ) -> List[Optional[src.RemoteTransformerBlock]]:
126
+ modules: List[Optional[src.RemoteTransformerBlock]] = []
127
+ for info in infos:
128
+ if info is not None:
129
+ modules.append(src.RemoteTransformerBlock(info, p2p))
130
+ else:
131
+ modules.append(None)
132
+ return modules
src/server/__init__.py ADDED
File without changes
src/server/backend.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Code for serving bloom blocks via hivemind-server"""
2
+ from typing import Sequence, Tuple
3
+
4
+ import torch
5
+ from hivemind.moe.server.module_backend import ModuleBackend
6
+ from hivemind.moe.server.task_pool import TaskPool
7
+
8
+ from src.bloom.from_pretrained import BloomBlock
9
+ from src.server.cache import MemoryCache
10
+
11
+ MAX_LENGTH = 2048
12
+
13
+
14
+ class TransformerBackend(ModuleBackend):
15
+ """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
16
+
17
+ def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+ assert isinstance(self.module, BloomBlock)
20
+ self.memory_cache = memory_cache
21
+ for name, param in self.module.named_parameters():
22
+ assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
23
+ for name, buf in self.module.named_buffers():
24
+ assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
25
+
26
+ self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
27
+
28
+ def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
29
+ with torch.inference_mode():
30
+ attention_cache_handle = int(cache_metadata[0, 0].item())
31
+ prefix_length = int(cache_metadata[0, 1].item())
32
+ hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here
33
+ assert (
34
+ hidden_states.ndim == 3
35
+ ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
36
+
37
+ with self.memory_cache.use_cache(attention_cache_handle) as cache:
38
+ assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
39
+ layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
40
+ print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
41
+ hidden_states, (new_k, new_v) = self.module.forward(
42
+ hidden_states, layer_past=layer_past, use_cache=True
43
+ )
44
+
45
+ # todo remove these asserts once we pass all tests
46
+ new_length = new_v.shape[1]
47
+ assert new_length > prefix_length
48
+ assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
49
+ assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
50
+ assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
51
+ assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
52
+ assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
53
+ cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
54
+ cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
55
+ return (hidden_states,)
56
+
57
+ def get_pools(self) -> Sequence[TaskPool]:
58
+ return self.forward_pool, self.backward_pool, self.inference_pool
src/server/cache.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
3
+
4
+ For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
5
+
6
+ """
7
+ import contextlib
8
+ import ctypes
9
+ import multiprocessing as mp
10
+ import os
11
+ from typing import AsyncContextManager, Dict, Optional, Union
12
+
13
+ import hivemind
14
+ import torch
15
+ from hivemind import use_hivemind_log_handler
16
+ from hivemind.utils import TensorDescriptor, get_logger
17
+
18
+ use_hivemind_log_handler("in_root_logger")
19
+ logger = get_logger(__file__)
20
+
21
+ Handle = int
22
+
23
+
24
+ class MemoryCache:
25
+ """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
26
+
27
+ def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
28
+ self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
29
+ self.device = device
30
+ self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
31
+ self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
32
+ self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
33
+ self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
34
+ self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
35
+ self.runtime_pid = os.getpid()
36
+
37
+ self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False) # any ConnectionHandler -> runtime
38
+ self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
39
+
40
+ @property
41
+ def current_size_bytes(self) -> int:
42
+ return self._current_size.value
43
+
44
+ @current_size_bytes.setter
45
+ def current_size_bytes(self, value: int):
46
+ self._current_size.value = value
47
+
48
+ @property
49
+ def handle_counter(self) -> int:
50
+ return self._handle_counter.value
51
+
52
+ @handle_counter.setter
53
+ def handle_counter(self, value: int):
54
+ self._handle_counter.value = value
55
+
56
+ @contextlib.asynccontextmanager
57
+ async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
58
+ """
59
+ Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
60
+
61
+ :param descr: allocate a tensor of this size, dtype, etc
62
+
63
+ :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
64
+ Furthermore, it can be called concurrently with at most one use_cache call in runtime.
65
+ """
66
+ assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
67
+ assert descr.device is None and descr
68
+ allocated_handle = None
69
+ allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
70
+ try:
71
+ async with hivemind.utils.enter_asynchronously(self.lock_metadata):
72
+ if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
73
+ raise AllocationFailed(
74
+ f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
75
+ f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
76
+ )
77
+
78
+ allocated_handle = int(self.handle_counter)
79
+ self.current_size_bytes += allocated_size_bytes
80
+ self.handle_counter += 1 # note: this will eventually overflow and it is okay
81
+ self._pending_messages.value += 1
82
+ self._pipe_send.send((allocated_handle, descr))
83
+
84
+ yield allocated_handle
85
+ finally:
86
+ if allocated_handle is not None:
87
+ async with hivemind.utils.enter_asynchronously(self.lock_metadata):
88
+ self._pending_messages.value += 1
89
+ self._pipe_send.send((allocated_handle, None)) # signal runtime to free that handle
90
+ self.current_size_bytes -= allocated_size_bytes
91
+
92
+ @contextlib.contextmanager
93
+ def use_cache(self, handle: Handle) -> torch.Tensor:
94
+ """
95
+ Return a tensor that was previously allocated with try_allocate_cache,
96
+
97
+ :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
98
+ However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
99
+ """
100
+ assert os.getpid() == self.runtime_pid
101
+ # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
102
+
103
+ with self.lock_metadata:
104
+ if self._allocated_tensors is None:
105
+ self._allocated_tensors = {}
106
+
107
+ # read creation/deletion requests from connection handlers
108
+ for i in range(int(self._pending_messages.value)):
109
+ recv_handle, recv_data = self._pipe_recv.recv()
110
+ self._pending_messages.value -= 1
111
+ if isinstance(recv_data, TensorDescriptor):
112
+ self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
113
+ elif recv_data is None:
114
+ if recv_handle not in self._allocated_tensors:
115
+ logger.warning(
116
+ f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
117
+ )
118
+ self._allocated_tensors.pop(recv_handle, None)
119
+ else:
120
+ logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
121
+
122
+ assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
123
+ yield self._allocated_tensors[handle]
124
+
125
+
126
+ class AllocationFailed(Exception):
127
+ pass
src/server/handler.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
2
+ import contextlib
3
+ from typing import AsyncIterator, Dict, Sequence
4
+
5
+ import torch
6
+ from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
7
+ from hivemind.moe.server.connection_handler import ConnectionHandler
8
+ from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
9
+ from hivemind.proto import runtime_pb2
10
+ from hivemind.utils import as_aiter
11
+ from hivemind.utils.asyncio import anext
12
+ from hivemind.utils.streaming import split_for_streaming
13
+
14
+ from src.data_structures import CHAIN_DELIMITER, ModuleUID
15
+ from src.server.backend import MAX_LENGTH, TransformerBackend
16
+
17
+
18
+ class TransformerConnectionHandler(ConnectionHandler):
19
+ """Handles three request types: forward, backward and forward-incremental (inference)"""
20
+
21
+ module_backends: Dict[ModuleUID, TransformerBackend]
22
+
23
+ def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
24
+ super().__init__(dht, module_backends)
25
+ for module_backend in self.module_backends.values():
26
+ assert isinstance(module_backend, TransformerBackend)
27
+
28
+ async def rpc_inference(
29
+ self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
30
+ ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
31
+ """Compute a single step of inference using attention cache; update attention cache accordingly."""
32
+ try:
33
+ print("OPENED RPC_INFERENCE")
34
+ request = await anext(requests)
35
+ requested_uids = self._check_header(request)
36
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
37
+
38
+ cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64) # [cache_handle, prefix_length]
39
+ prefix_length = 0
40
+
41
+ async with self._allocate_caches(requested_backends) as cache_handles:
42
+ assert len(cache_handles) == len(requested_backends)
43
+ while request.tensors: # iterate while user is willing to supply tensors
44
+ hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
45
+
46
+ # run request tensors through all requested modules, update caches
47
+ for backend, cache_handle in zip(requested_backends, cache_handles):
48
+ cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
49
+ assert (
50
+ len(hidden_states) == 1 and hidden_states[0].ndim == 3
51
+ ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
52
+
53
+ hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
54
+ assert isinstance(hidden_states, (list, tuple))
55
+ assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
56
+
57
+ # serialize and send last layer outputs
58
+ yield runtime_pb2.ExpertResponse(
59
+ tensors=[
60
+ serialize_torch_tensor(result, proto.compression, allow_inplace=True)
61
+ for result, proto in zip(
62
+ hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
63
+ )
64
+ ]
65
+ )
66
+
67
+ # prepare for next step
68
+ prefix_length += hidden_states[0].shape[1]
69
+ request = await (anext(requests))
70
+ finally:
71
+ print("CLOSED RPC_INFERENCE")
72
+
73
+ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
74
+ # Parse request and prepare backends
75
+ hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
76
+ requested_uids = self._check_header(request)
77
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
78
+
79
+ # Run a chain of requested backends
80
+ for backend in requested_backends:
81
+ assert isinstance(hidden_states, (list, tuple))
82
+ assert (
83
+ len(hidden_states) == 1 and hidden_states[0].ndim == 3
84
+ ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
85
+ hidden_states = await backend.forward_pool.submit_task(*hidden_states)
86
+
87
+ # Serialize the overall output and respond
88
+ assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
89
+ return runtime_pb2.ExpertResponse(
90
+ tensors=[
91
+ serialize_torch_tensor(result, proto.compression, allow_inplace=True)
92
+ for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
93
+ ]
94
+ )
95
+
96
+ async def rpc_forward_stream(
97
+ self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
98
+ ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
99
+ # Parse requests and prepare backends
100
+ uids_header, hidden_states = await self._gather_inputs(requests, context)
101
+ requested_uids = self._check_header_str(uids_header)
102
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
103
+
104
+ # Run a chain of requested backends
105
+ for backend in requested_backends:
106
+ assert isinstance(hidden_states, (list, tuple))
107
+ assert (
108
+ len(hidden_states) == 1 and hidden_states[0].ndim == 3
109
+ ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
110
+ hidden_states = await backend.forward_pool.submit_task(*hidden_states)
111
+
112
+ # Serialize the overall output
113
+ assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
114
+ serialized_output = [
115
+ serialize_torch_tensor(result, proto.compression, allow_inplace=True)
116
+ for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
117
+ ]
118
+
119
+ # Split the serialized_output for streaming and respond
120
+ output_split = [
121
+ part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
122
+ ]
123
+ async for part in as_aiter(*output_split):
124
+ yield runtime_pb2.ExpertResponse(tensors=[part])
125
+
126
+ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
127
+ # Parse requests and prepare backends
128
+ inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
129
+ requested_uids = self._check_header(request)
130
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
131
+
132
+ # Run a forward chain to collect intermediate inputs
133
+ # Note that we do not forward for the last module since we do not need its output
134
+ inter_inputs = [inputs]
135
+ for backend in requested_backends[:-1]:
136
+ assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
137
+ inputs = await backend.forward_pool.submit_task(inputs)
138
+ assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
139
+ inputs = inputs[0]
140
+ inter_inputs.append(inputs)
141
+
142
+ # Run a chain of requested backends
143
+ for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
144
+ inputs_and_grads = [inp, grads]
145
+ grads = await backend.backward_pool.submit_task(*inputs_and_grads)
146
+ assert isinstance(grads, (list, tuple)) and len(grads) == 1
147
+ grads = grads[0]
148
+
149
+ # Serialize the overall grad_input and respond
150
+ return runtime_pb2.ExpertResponse(
151
+ tensors=[
152
+ serialize_torch_tensor(result, proto.compression, allow_inplace=True)
153
+ for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
154
+ ]
155
+ )
156
+
157
+ async def rpc_backward_stream(
158
+ self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
159
+ ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
160
+ uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
161
+ inputs, grads = inputs_and_grads
162
+ requested_uids = self._check_header_str(uids_header)
163
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
164
+
165
+ # Run a forward chain to collect intermediate inputs
166
+ # Note that we do not forward for the last module since we do not need its outputs
167
+ inter_inputs = [inputs]
168
+ for backend in requested_backends[:-1]:
169
+ assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
170
+ inputs = await backend.forward_pool.submit_task(inputs)
171
+ assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
172
+ inputs = inputs[0]
173
+ inter_inputs.append(inputs)
174
+
175
+ # Run a backward chain for requested backends
176
+ for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
177
+ inputs_and_grads = [inp, grads]
178
+ grads = await backend.backward_pool.submit_task(*inputs_and_grads)
179
+ assert isinstance(grads, (list, tuple)) and len(grads) == 1
180
+ grads = grads[0]
181
+
182
+ # Serialize the overall grad_inputs
183
+ serialized_grad_inputs = [
184
+ serialize_torch_tensor(result, proto.compression, allow_inplace=True)
185
+ for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
186
+ ]
187
+ # Split the serialized_grad_inputs for streaming and respond
188
+ output_split = [
189
+ part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
190
+ ]
191
+
192
+ async for part in as_aiter(*output_split):
193
+ yield runtime_pb2.ExpertResponse(tensors=[part])
194
+
195
+ def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
196
+ """Check that the first request to rpc_inference is valid"""
197
+ uids = (request.uid or "").split(CHAIN_DELIMITER)
198
+ if not uids:
199
+ raise RuntimeError("User did not provide any uids")
200
+ for uid in uids:
201
+ if uid not in self.module_backends:
202
+ raise RuntimeError(f"Remote peer does not serve {uid}")
203
+ return tuple(uids)
204
+
205
+ def _check_header_str(self, header) -> Sequence[ModuleUID]:
206
+ """Check that the first request to rpc_inference is valid"""
207
+ uids = (header or "").split(CHAIN_DELIMITER)
208
+ if not uids:
209
+ raise RuntimeError("User did not provide any uids")
210
+ for uid in uids:
211
+ if uid not in self.module_backends:
212
+ raise RuntimeError(f"Remote peer does not serve {uid}")
213
+ return tuple(uids)
214
+
215
+ @contextlib.asynccontextmanager
216
+ async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
217
+ """Allocate memory caches for each transformer block, return cache handles"""
218
+ async with contextlib.AsyncExitStack() as stack:
219
+ handles = []
220
+ for backend in backends:
221
+ num_heads = backend.module.self_attention.num_heads
222
+ head_dim = backend.module.self_attention.head_dim
223
+
224
+ cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
225
+ # [key_or_value, batch_size, max_length, num_heads, head_dim]
226
+
227
+ handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
228
+
229
+ yield handles
src/server/server.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import multiprocessing as mp
4
+ import threading
5
+ from typing import Dict, Optional, Sequence, Union
6
+
7
+ import torch
8
+ from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
9
+ from hivemind.moe.server.layers import add_custom_models_from_file
10
+ from hivemind.moe.server.runtime import Runtime
11
+ from hivemind.proto.runtime_pb2 import CompressionType
12
+ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
13
+
14
+ from src import declare_active_modules, BloomConfig
15
+ from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
16
+ from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
17
+ from src.server.backend import TransformerBackend
18
+ from src.server.cache import MemoryCache
19
+ from src.server.handler import TransformerConnectionHandler
20
+
21
+ use_hivemind_log_handler("in_root_logger")
22
+ logger = get_logger(__file__)
23
+
24
+
25
+ class Server(threading.Thread):
26
+ """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
27
+
28
+ def __init__(
29
+ self,
30
+ dht: DHT,
31
+ module_backends: Dict[str, TransformerBackend],
32
+ *,
33
+ device: torch.device,
34
+ num_connection_handlers: int = 8,
35
+ update_period: float = 30,
36
+ expiration: Optional[float] = None,
37
+ start: bool,
38
+ **kwargs,
39
+ ):
40
+ threading.Thread.__init__(self)
41
+ self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
42
+ self.conn_handlers = [
43
+ TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
44
+ ]
45
+ self.runtime = Runtime(self.module_backends, device=device, **kwargs)
46
+ self.dht_handler_thread = ModuleAnnouncerThread(
47
+ self.module_backends, dht, update_period, expiration, daemon=True
48
+ )
49
+ self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
50
+
51
+ if start:
52
+ self.run_in_background(await_ready=True)
53
+
54
+ def run(self):
55
+ """
56
+ Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
57
+ runs Runtime (self.runtime) to process incoming requests.
58
+ """
59
+ logger.info(f"Serving {len(self.module_backends)} blocks:")
60
+ for expert_name, backend in self.module_backends.items():
61
+ num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
62
+ logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
63
+
64
+ if not self.dht.is_alive():
65
+ self.dht.run_in_background(await_ready=True)
66
+
67
+ if self.module_backends:
68
+ self.dht_handler_thread.start()
69
+
70
+ if self.checkpoint_saver is not None:
71
+ self.checkpoint_saver.start()
72
+
73
+ for process in self.conn_handlers:
74
+ if not process.is_alive():
75
+ process.start()
76
+ process.ready.result()
77
+
78
+ try:
79
+ self.runtime.run()
80
+ finally:
81
+ self.shutdown()
82
+
83
+ # noinspection PyMethodOverriding
84
+ @classmethod
85
+ def create(
86
+ cls,
87
+ prefix: Optional[str],
88
+ converted_model_name_or_path: str,
89
+ num_blocks: Optional[int] = None,
90
+ block_indices: Optional[str] = None,
91
+ num_handlers: Optional[int] = None,
92
+ min_batch_size: int = 1,
93
+ max_batch_size: int = 4096,
94
+ torch_dtype: str = "auto",
95
+ cache_size_bytes: Optional[int] = None,
96
+ device: Union[str, torch.device] = None,
97
+ initial_peers: Sequence[str] = (),
98
+ compression=CompressionType.NONE,
99
+ stats_report_interval: Optional[int] = None,
100
+ custom_module_path=None,
101
+ update_period: float = 30,
102
+ expiration: Optional[float] = None,
103
+ use_auth_token: Optional[str] = None,
104
+ *,
105
+ start: bool,
106
+ **kwargs,
107
+ ) -> Server:
108
+ """Create a server with one or more bloom blocks. See run_server.py for documentation."""
109
+ if custom_module_path is not None:
110
+ add_custom_models_from_file(custom_module_path)
111
+ if prefix is None:
112
+ prefix = converted_model_name_or_path
113
+ assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
114
+ f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
115
+ f"Please specify --prefix manually when starting a server"
116
+ )
117
+ logger.info(f"Automatic dht prefix: {prefix}")
118
+ assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
119
+ dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
120
+ visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
121
+ logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
122
+
123
+ device = device or ("cuda" if torch.cuda.is_available() else "cpu")
124
+ memory_cache = MemoryCache(device, cache_size_bytes)
125
+
126
+ if isinstance(torch_dtype, str):
127
+ torch_dtype = DTYPE_MAP[torch_dtype]
128
+ assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
129
+
130
+ if block_indices is not None:
131
+ try:
132
+ first_block_index, last_block_index = block_indices.split(":")
133
+ first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
134
+ except Exception as e:
135
+ logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
136
+ raise
137
+ block_indices = range(first_block_index, last_block_index)
138
+ else:
139
+ assert num_blocks is not None
140
+ block_indices = range(num_blocks) # TODO replace with proper load balancing
141
+
142
+ block_config = BloomConfig.from_pretrained(
143
+ converted_model_name_or_path, use_auth_token=use_auth_token
144
+ )
145
+
146
+ # initialize modules
147
+ blocks = {}
148
+ for block_index in block_indices:
149
+ module_uid = f"{prefix}.{block_index}"
150
+ block = load_pretrained_block(
151
+ converted_model_name_or_path,
152
+ block_index,
153
+ block_config,
154
+ torch_dtype=torch_dtype,
155
+ use_auth_token=use_auth_token,
156
+ )
157
+ for param in block.parameters():
158
+ param.requires_grad = False
159
+
160
+ blocks[module_uid] = TransformerBackend(
161
+ module_uid,
162
+ block,
163
+ memory_cache=memory_cache,
164
+ args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
165
+ kwargs_schema={},
166
+ outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
167
+ min_batch_size=min_batch_size,
168
+ max_batch_size=max_batch_size,
169
+ )
170
+
171
+ num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
172
+
173
+ return cls(
174
+ dht,
175
+ blocks,
176
+ num_connection_handlers=num_handlers,
177
+ device=device,
178
+ stats_report_interval=stats_report_interval,
179
+ update_period=update_period,
180
+ expiration=expiration,
181
+ start=start,
182
+ )
183
+
184
+ def run_in_background(self, await_ready=True, timeout=None):
185
+ """
186
+ Starts Server in a background thread. if await_ready, this method will wait until background server
187
+ is ready to process incoming requests or for :timeout: seconds max.
188
+ """
189
+ self.start()
190
+ if await_ready and not self.ready.wait(timeout=timeout):
191
+ raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
192
+
193
+ @property
194
+ def ready(self) -> mp.synchronize.Event:
195
+ """
196
+ An event (multiprocessing.Event) that is set when the server is ready to process requests.
197
+
198
+ Example
199
+ =======
200
+ >>> server.start()
201
+ >>> server.ready.wait(timeout=10)
202
+ >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
203
+ """
204
+ return self.runtime.ready # mp.Event that is true if self is ready to process batches
205
+
206
+ def shutdown(self):
207
+ """
208
+ Gracefully terminate the server, process-safe.
209
+ Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
210
+ If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
211
+ """
212
+ self.ready.clear()
213
+
214
+ for process in self.conn_handlers:
215
+ process.terminate()
216
+ process.join()
217
+ logger.debug("Connection handlers terminated")
218
+
219
+ if self.module_backends:
220
+ self.dht_handler_thread.stop.set()
221
+ self.dht_handler_thread.join()
222
+
223
+ if self.checkpoint_saver is not None:
224
+ self.checkpoint_saver.stop.set()
225
+ self.checkpoint_saver.join()
226
+
227
+ self.dht.shutdown()
228
+ self.dht.join()
229
+
230
+ logger.debug(f"Shutting down runtime")
231
+
232
+ self.runtime.shutdown()
233
+ logger.info("Server shutdown succesfully")
234
+
235
+
236
+ class ModuleAnnouncerThread(threading.Thread):
237
+ """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
238
+
239
+ def __init__(
240
+ self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
241
+ ):
242
+ super().__init__(**kwargs)
243
+ if expiration is None:
244
+ expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
245
+ self.module_backends = module_backends
246
+ self.dht = dht
247
+ self.update_period = update_period
248
+ self.expiration = expiration
249
+ self.stop = threading.Event()
250
+
251
+ def run(self) -> None:
252
+ declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
253
+ while not self.stop.wait(self.update_period):
254
+ declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)