import os
from typing import Callable, List, Optional

import numpy as np
import torch
from torch.nn import functional as F

# Mapping from numpy scalar dtypes to their torch equivalents.
np_dtype_to_torch_dtype = {
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int32: torch.int32,
    np.int64: torch.int64,
    bool: torch.bool,
}
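
# Illustrative sketch (not in the original file): translating a numpy array's
# dtype when allocating a matching torch tensor. Works for the float/int dtypes
# listed above; numpy bool arrays report np.bool_ rather than bool, so they
# would need a separate lookup.
def _example_dtype_lookup(arr: np.ndarray) -> torch.Tensor:
    return torch.empty(arr.shape, dtype=np_dtype_to_torch_dtype[arr.dtype.type])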

class IndexDiff:
    """Diff between a set of wanted indices and one cache level: `off_elements`
    are wanted elements missing from the cache, `off_positions` are cache slots
    that may be overwritten to hold them, and `on_positions` are cache slots
    whose elements are still wanted."""
    def __init__(self, off_elements: Optional[torch.Tensor] = None, off_positions: Optional[torch.Tensor] = None, on_positions: Optional[torch.Tensor] = None):
        self.off_elements = off_elements
        self.off_positions = off_positions
        self.on_positions = on_positions

def batch_copy(sources: List[torch.Tensor], copy_stream, indices=None, device='cuda'):
    """Copy each source tensor (optionally sliced by `indices`) to `device` on
    `copy_stream`. The copies are non-blocking, so the caller must synchronize
    `copy_stream` before reading the returned tensors."""
    with torch.cuda.stream(copy_stream):
        out = []
        for src in sources:
            indexed = src[indices] if indices is not None else src
            dst = torch.empty(indexed.shape, device=device, dtype=src.dtype)
            dst.copy_(indexed, non_blocking=True)
            out.append(dst)
    return tuple(out)
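
# Illustrative sketch (not in the original file): using batch_copy with a
# dedicated copy stream. Shapes and names are hypothetical; the sources live in
# pinned CPU memory.
def _example_batch_copy():
    copy_stream = torch.cuda.Stream()
    weights = [torch.randn(1024, 1024, pin_memory=True) for _ in range(2)]
    rows = torch.tensor([0, 3, 7])
    w0, w1 = batch_copy(weights, copy_stream, indices=rows, device='cuda')
    # The copies are queued on copy_stream; wait for them before reading.
    copy_stream.synchronize()
    return w0, w1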

def mmap_to_tensor(torch_wrapped_mmap, pin_memory=False) -> torch.Tensor:
    """Materialize a tensor that wraps a memory-mapped file into regular
    (optionally pinned) CPU memory."""
    out = torch.empty(torch_wrapped_mmap.shape, dtype=torch_wrapped_mmap.dtype, device='cpu', pin_memory=pin_memory)
    out.copy_(torch_wrapped_mmap)
    return out
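
# Illustrative sketch (not in the original file): wrapping an on-disk numpy
# memmap as a torch tensor and materializing it into pinned RAM. The file name
# and shape are hypothetical; mode='c' (copy-on-write) keeps the buffer
# writable in memory so torch.from_numpy accepts it without a warning.
def _example_mmap_to_tensor():
    arr = np.memmap('weights.bin', dtype=np.float32, mode='c', shape=(4096, 4096))
    wrapped = torch.from_numpy(arr)  # zero-copy view over the mapped file
    return mmap_to_tensor(wrapped, pin_memory=True)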

# Assuming that each entry of cached_indices_list is a step down the memory
# hierarchy, compute the diff at each level of the hierarchy.
#   e.g. the first iteration computes the indices that the GPU does not have,
#        and the second computes the indices *of that diff* that the CPU does not have.
def compute_index_diffs(new_indices: torch.Tensor, cached_indices_list: List[torch.Tensor], pin_memory=True) -> List[IndexDiff]:
    diffs = []
    current_diff = new_indices
    for cached_indices in cached_indices_list:
        if current_diff.size(0) == 0:
            # No need to go further down the hierarchy
            break

        # Elements of the sought indices that this cache level does not hold
        off_elements = torch.tensor(
            list(set(current_diff.tolist()).difference(set(cached_indices.tolist()))),
            device='cpu',
            dtype=torch.int32,
            pin_memory=pin_memory
        )
        # Mask over the cached indices: True where the cached element is one of
        # the sought indices (stays on); the complement marks replaceable slots.
        on_position_mask = torch.isin(cached_indices, current_diff, assume_unique=True)
        on_positions = torch.nonzero(on_position_mask).flatten()
        # Keep only as many replaceable slots as there are incoming elements
        off_positions = torch.nonzero(~on_position_mask).flatten()[:off_elements.size(0)]

        diffs.append(IndexDiff(off_elements, off_positions, on_positions))
        current_diff = off_elements
    return diffs
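
# Illustrative sketch (not in the original file): a two-level hierarchy where
# the GPU caches {1, 2, 3}, the CPU caches {3, 4, 5}, and the new step wants
# {2, 3, 4}. pin_memory=False keeps the example runnable without CUDA.
def _example_compute_index_diffs():
    new_indices = torch.tensor([2, 3, 4], dtype=torch.int32)
    gpu_cached = torch.tensor([1, 2, 3], dtype=torch.int32)
    cpu_cached = torch.tensor([3, 4, 5], dtype=torch.int32)
    diffs = compute_index_diffs(new_indices, [gpu_cached, cpu_cached], pin_memory=False)
    # diffs[0].off_elements == [4]: the GPU is missing only element 4, and
    # diffs[0].off_positions == [0]: slot 0 (holding 1) can be overwritten.
    # diffs[1] then checks only [4] against the CPU cache, which holds it,
    # so diffs[1].off_elements is empty and nothing deeper is fetched.
    return diffs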

def topk_and_threshold(x, k, threshold=1):
    """Indices of the top-k entries of x whose values strictly exceed threshold."""
    vals, indices = torch.topk(x, k, sorted=True)
    return indices[vals > threshold].int()

def load_mlp_sparsity_predictor(weight_path_prefix: str, layer_num: int, dtype: torch.dtype, device: str = 'cuda') -> Optional[Callable]:
    path_prefix = f'{weight_path_prefix}decoder.layers.{layer_num}.attn.mlp-sparsity-predictor.'
    return load_predictor(path_prefix, dtype, device=device)

def load_predictor(path_prefix: str, dtype: torch.dtype, device: str = 'cuda') -> Optional[Callable]:
    """Load a two-layer linear predictor from `{path_prefix}1.weight` and
    `{path_prefix}2.weight`, or return None (with a warning) if the first
    weight file is missing."""
    path = lambda i: os.path.expanduser(f'{path_prefix}{i}.weight')
    if os.path.exists(path(1)):
        l1 = torch.load(path(1)).to(device).to(dtype)
        l2 = torch.load(path(2)).to(device).to(dtype)
        return lambda x: F.linear(F.linear(x, l1), l2)
    else:
        print(f'could not find predictor at {path(1)}')
        return None
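
# Illustrative sketch (not in the original file): loading a layer's sparsity
# predictor and turning its scores into active-neuron indices. The path prefix
# and sizes are hypothetical; the predictor weights must already be on disk.
def _example_predict_active_neurons(hidden_states: torch.Tensor):
    predictor = load_mlp_sparsity_predictor('~/opt-weights/', layer_num=0, dtype=torch.float16)
    if predictor is None:
        return None
    scores = predictor(hidden_states)  # one score per MLP neuron
    # Keep at most 1024 neurons whose score clears the default threshold.
    return topk_and_threshold(scores, k=1024)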