petil777 committed
Commit 649bc8a · 1 Parent(s): 74a5e35

Upload weights.py

Files changed (1)
  1. weights.py +191 -0
weights.py ADDED
@@ -0,0 +1,191 @@
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
from huggingface_hub import hf_hub_download
import json


class Weights:
    """Resolves tensor names (with optional prefix and aliases) to safetensors files
    and returns full or tensor-parallel-sharded tensors."""

    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        # Build a routing table mapping every tensor name to the file that holds it.
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename):
        # Open each safetensors file lazily and cache the handle.
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        # Try the raw name, then the prefixed name; fall back to aliases of each.
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        # Return this rank's contiguous block along `dim`, without checking that
        # the size divides evenly across shards.
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str):
        # Shard a prepacked quantized qkv tensor along dim 1, keeping this rank's
        # slice of each of Q, K and V.
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[:, start:stop]
        k = slice_[:, start + single_size : stop + single_size]
        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        """
        Highly specific to the case where the underlying tensor is a simple
        concatenation of Q, K, V rather than Q, K, V already alternating
        within the main tensor.
        """
        slice_ = self._get_slice(f"{prefix}.weight")
        total_size = slice_.get_shape()[0]
        assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[start:stop]
        k = slice_[start + single_size : stop + single_size]
        v = slice_[start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=0)
        weight = weight.to(device=self.device)
        weight = weight.to(dtype=self.dtype)
        return weight

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        # Shard each prefix along dim 0, then concatenate the shards along `dim`.
        w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        weight = torch.cat(w, dim=dim)
        return weight

    def get_tensor_shard(self, var, dim):
        # Shard an already-loaded tensor (rather than a safetensors slice) along `dim`.
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
        weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight
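
For reference, a minimal sketch of how the class above might be driven on a single process. It is not part of the uploaded file: the shard filenames, tensor names, and alias are hypothetical, and the tiny stand-in process group only mimics the two methods (size and rank) that Weights actually calls; a real tensor-parallel launch would pass a torch.distributed process group instead.

# Illustrative sketch only; assumes weights.py is importable and that the
# hypothetical shard files below exist and contain the named tensors.
from pathlib import Path

import torch

from weights import Weights


class SingleProcessGroup:
    # Stand-in exposing just the two methods Weights uses.
    def size(self) -> int:
        return 1

    def rank(self) -> int:
        return 0


filenames = [
    Path("model-00001-of-00002.safetensors"),  # hypothetical shard
    Path("model-00002-of-00002.safetensors"),  # hypothetical shard
]

weights = Weights(
    filenames,
    device=torch.device("cpu"),
    dtype=torch.float16,
    process_group=SingleProcessGroup(),
    # Hypothetical alias: asking for lm_head.weight falls back to the tied embedding.
    aliases={"lm_head.weight": ["model.embed_tokens.weight"]},
)

# Full tensor, cast to the target dtype (int32/int64 are left untouched).
emb = weights.get_tensor("model.embed_tokens.weight")

# Shard along dim 0; with world_size == 1 this is simply the whole tensor.
q_proj = weights.get_sharded("model.layers.0.self_attn.q_proj.weight", dim=0)

# Shard along dim 1; the quantize argument is accepted but unused in this file.
o_proj = weights.get_multi_weights_row("model.layers.0.self_attn.o_proj", quantize="")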
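
The packed-QKV sharding in get_weights_col_packed_qkv (and _get_qweight, which does the same along dim 1) is easiest to follow with concrete numbers. The sizes below are made up, a fused projection of 3 x 4096 rows stored as [Q | K | V] and split over 2 shards, and simply replay the method's arithmetic.

# Replays the slice arithmetic of get_weights_col_packed_qkv with made-up sizes.
total_size = 12288                  # fused qkv rows, laid out as [Q | K | V]
single_size = total_size // 3       # 4096 rows each for Q, K and V
world_size, rank = 2, 1             # second of two shards

block_size = single_size // world_size  # 2048 rows per rank per projection
start = rank * block_size               # 2048
stop = (rank + 1) * block_size          # 4096

q_rows = (start, stop)                                      # (2048, 4096)
k_rows = (start + single_size, stop + single_size)          # (6144, 8192)
v_rows = (start + 2 * single_size, stop + 2 * single_size)  # (10240, 12288)
# The three row ranges are concatenated along dim 0, so each rank ends up with
# its own contiguous slice of Q, K and V rather than a third of the packed tensor.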