Fix issues with loading F8Linear from state dict when init_scale not initialized & loaded from meta device
float8_quantize.py (+30 -1)
```diff
@@ -125,7 +125,7 @@ class F8Linear(nn.Module):
         ) and sd["weight"] == torch.zeros_like(sd["weight"]):
             w = sd["weight"]
             # Set the init values as if it's already quantized float8_data
-            self.float8_data = sd["float8_data"]
+            self._buffers["float8_data"] = sd["float8_data"]
             self._parameters["weight"] = nn.Parameter(
                 torch.zeros(
                     1,
```
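The first hunk swaps plain attribute assignment for a direct write into `self._buffers`. A plausible reason, though the commit message does not spell it out: if `float8_data` was never registered as a buffer (for example when the module was materialized from the meta device before its quantized data existed), `self.float8_data = ...` creates an ordinary Python attribute that `state_dict()` and `.to()` will not track, whereas writing into `_buffers` registers it as a proper buffer. A minimal sketch of that `nn.Module` behavior:

```python
import torch
import torch.nn as nn

class M(nn.Module):
    # No buffer registered in __init__, mimicking a module whose
    # quantized data does not exist yet at construction time.
    pass

m = M()
# "float8_data" is not in m._buffers, so nn.Module.__setattr__ stores
# this as an ordinary, untracked Python attribute.
m.float8_data = torch.zeros(4)
print("float8_data" in m.state_dict())   # False

m2 = M()
# Writing into _buffers registers the tensor as a persistent buffer
# (register_buffer() is the public API with the same effect).
m2._buffers["float8_data"] = torch.zeros(4)
print("float8_data" in m2.state_dict())  # True
```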
```diff
@@ -156,6 +156,31 @@ class F8Linear(nn.Module):
             self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
             self.input_scale_initialized = True
             self.trial_index = self.num_scale_trials
+        elif "scale" in sd and "scale_reciprocal" in sd:
+            self.scale = sd["scale"].float()
+            self.input_scale = (
+                sd["input_scale"].float() if "input_scale" in sd else None
+            )
+            self.scale_reciprocal = sd["scale_reciprocal"].float()
+            self.input_scale_reciprocal = (
+                sd["input_scale_reciprocal"].float()
+                if "input_scale_reciprocal" in sd
+                else None
+            )
+            self.input_scale_initialized = (
+                True if "input_scale" in sd else False
+            )
+            self.trial_index = (
+                self.num_scale_trials if "input_scale" in sd else 0
+            )
+            self.input_amax_trials = torch.zeros(
+                self.num_scale_trials,
+                requires_grad=False,
+                dtype=torch.float32,
+                device=self.weight.device,
+            )
+            self.input_scale_initialized = False
+            self.trial_index = 0
         else:
             # If scales are not initialized, reset trials
             self.input_scale_initialized = False
```
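The new `elif` branch accepts checkpoints that carry the weight scales (`scale`, `scale_reciprocal`) but not necessarily the input-scale statistics. Note that the branch ends by unconditionally resetting `input_scale_initialized` and `trial_index`, so input-scale calibration restarts after such a load. The meta-device flow this fix targets follows the standard PyTorch pattern, sketched below with a stand-in `nn.Linear` in place of the real F8Linear-swapped network:

```python
import torch
import torch.nn as nn

# Build the model under the meta device so no real storage is allocated,
# then materialize it from a checkpoint. With assign=True,
# load_state_dict() replaces the meta tensors with the checkpoint's
# tensors instead of copying into them (available since PyTorch 2.1).
with torch.device("meta"):
    model = nn.Linear(8, 8)  # stand-in for the quantized network

sd = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}  # toy checkpoint
model.load_state_dict(sd, assign=True)
print(model.weight.device)  # cpu: the meta tensors were replaced, not copied
```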
```diff
@@ -292,6 +317,7 @@ def recursive_swap_linears(
     float8_dtype=torch.float8_e4m3fn,
     input_float8_dtype=torch.float8_e5m2,
     quantize_modulation: bool = True,
+    ignore_keys: list[str] = [],
 ) -> None:
     """
     Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
```
```diff
@@ -309,6 +335,8 @@ def recursive_swap_linears(
         all linear layers in the model will be using 8-bit quantization.
     """
     for name, child in model.named_children():
+        if name in ignore_keys:
+            continue
         if isinstance(child, Modulation) and not quantize_modulation:
             continue
         if isinstance(child, nn.Linear) and not isinstance(
```
```diff
@@ -331,6 +359,7 @@ def recursive_swap_linears(
                 float8_dtype=float8_dtype,
                 input_float8_dtype=input_float8_dtype,
                 quantize_modulation=quantize_modulation,
+                ignore_keys=ignore_keys,
             )
 
 
```
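The remaining hunks add an `ignore_keys` argument and thread it through the recursive call, so any child module whose name appears in the list is skipped. Since the check runs at every level of the recursion, names are matched per child, not as fully qualified dotted paths. Also note that `list[str] = []` is a mutable default argument; it is harmless here because the list is only read, but a `None` sentinel is the usual idiom. A usage sketch (the positional model argument and the module names are assumptions; only the keyword arguments visible in this diff are known):

```python
import torch
import torch.nn as nn
from float8_quantize import recursive_swap_linears

# Hypothetical network whose "final_layer" should stay in full precision.
model = nn.Sequential()
model.add_module("blocks", nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)))
model.add_module("final_layer", nn.Linear(16, 4))

recursive_swap_linears(
    model,
    float8_dtype=torch.float8_e4m3fn,
    input_float8_dtype=torch.float8_e5m2,
    quantize_modulation=True,
    ignore_keys=["final_layer"],  # this child is skipped at every level
)
```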