File size: 2,291 Bytes
e71a2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tqdm.auto import trange

from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig
from src.bloom.ops import build_alibi_tensor

# Install hivemind's log handler before creating the module logger so that the
# messages emitted below go through it. The "in_root_logger" mode presumably
# attaches the handler to the root logger (confirm against hivemind.utils.logging).
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

# Emitted at import time, i.e. every time this module is loaded or run.
logger.warning("inference_one_block will soon be deprecated in favour of tests!")


def print_device_info(device=None):
    """Log the selected device and, for CUDA devices, its name and memory usage.

    Code adapted from https://stackoverflow.com/a/53374933/12891528

    :param device: a device string or ``torch.device``. When falsy, defaults to
        ``"cuda"`` if CUDA is available and ``"cpu"`` otherwise.
    """
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional info when using CUDA. Query the requested device (e.g. "cuda:1")
    # instead of hard-coding index 0; a bare "cuda" resolves to the current device.
    if device.type == "cuda":
        gib = 1024**3
        logger.info(torch.cuda.get_device_name(device))
        logger.info("Memory Usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(device) / gib, 1)} GB")
        # torch.cuda.memory_cached() is a deprecated alias of memory_reserved()
        logger.info(f"Cached:   {round(torch.cuda.memory_reserved(device) / gib, 1)} GB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
    parser.add_argument("--layer_index", default=0, type=int, help="Index of this block within the model")
    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    config = BloomConfig.from_json_file(args.config)
    block = BloomBlock(config, args.layer_index)
    if args.state_dict is not None:
        # --state_dict used to be parsed but silently ignored. Load on CPU first,
        # then move the whole block to the target device below.
        # NOTE: torch.load unpickles the file; only pass trusted checkpoints.
        block.load_state_dict(torch.load(args.state_dict, map_location="cpu"))
    # eval() turns off training-only behavior (e.g. dropout, if the block uses
    # any) so the run reflects real inference.
    block = block.to(args.device).eval()

    cache = None

    # Feed one random hidden state of shape (1, 1, hidden_size) per step, passing
    # the returned cache back in. The alibi tensor is rebuilt for i + 1 positions
    # each step (presumably the length covered by the cache — confirm in src.bloom.ops).
    with torch.no_grad():
        for i in trange(args.num_steps):
            dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
            alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
            # Call the module (not .forward) so registered hooks run.
            _outputs, cache = block(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)

    print_device_info(args.device)