aredden committed
Commit 3ddaa67 · 1 parent: 7a7b2c1

Fix issues with loading F8Linear from state dict when init_scale not initialized & loaded from meta device

Files changed (1)
1. float8_quantize.py (+30 -1)
float8_quantize.py CHANGED
@@ -125,7 +125,7 @@ class F8Linear(nn.Module):
             ) and sd["weight"] == torch.zeros_like(sd["weight"]):
                 w = sd["weight"]
                 # Set the init values as if it's already quantized float8_data
-                self.float8_data = sd["float8_data"]
+                self._buffers["float8_data"] = sd["float8_data"]
                 self._parameters["weight"] = nn.Parameter(
                     torch.zeros(
                         1,
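The one-line change above swaps a plain attribute assignment for a direct write into the module's `_buffers` dict. This matters because `nn.Module.__setattr__` only routes a tensor into the buffer registry when a buffer with that name is already registered; if `float8_data` was never materialized (e.g. the module was built on the meta device), `self.float8_data = ...` creates an untracked attribute that `state_dict()`, `.to()`, and `.cuda()` all ignore. A minimal sketch of the difference, using a hypothetical `Demo` module rather than the repo's `F8Linear`:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    pass  # no "float8_data" buffer registered, as after a meta-device init

m1 = Demo()
m1.float8_data = torch.zeros(4)          # plain attribute: untracked
print("float8_data" in m1.state_dict())  # False -- lost on save/load

m2 = Demo()
m2._buffers["float8_data"] = torch.zeros(4)  # registered as a real buffer
print("float8_data" in m2.state_dict())      # True -- serialized, and
print(dict(m2.named_buffers()).keys())       # moved by .to()/.cuda()
```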
@@ -156,6 +156,31 @@ class F8Linear(nn.Module):
             self.input_scale_reciprocal = sd["input_scale_reciprocal"].float()
             self.input_scale_initialized = True
             self.trial_index = self.num_scale_trials
+        elif "scale" in sd and "scale_reciprocal" in sd:
+            self.scale = sd["scale"].float()
+            self.input_scale = (
+                sd["input_scale"].float() if "input_scale" in sd else None
+            )
+            self.scale_reciprocal = sd["scale_reciprocal"].float()
+            self.input_scale_reciprocal = (
+                sd["input_scale_reciprocal"].float()
+                if "input_scale_reciprocal" in sd
+                else None
+            )
+            self.input_scale_initialized = (
+                True if "input_scale" in sd else False
+            )
+            self.trial_index = (
+                self.num_scale_trials if "input_scale" in sd else 0
+            )
+            self.input_amax_trials = torch.zeros(
+                self.num_scale_trials,
+                requires_grad=False,
+                dtype=torch.float32,
+                device=self.weight.device,
+            )
+            self.input_scale_initialized = False
+            self.trial_index = 0
         else:
             # If scales are not initialized, reset trials
             self.input_scale_initialized = False
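The new `elif` branch restores the weight scale (and, when present, the input scale) from checkpoints that carry scale tensors but did not take the fully-initialized path above. Note that as committed, the branch ends by unconditionally resetting `input_scale_initialized` and `trial_index`, which overrides the conditional assignments a few lines earlier and appears to force input-scale recalibration on this path. The commit message ties the fix to modules "loaded from meta device"; the usual PyTorch pattern that produces such modules is roughly the following (a generic sketch with `nn.Linear`, not this repo's loading code; `assign=True` requires PyTorch >= 2.1):

```python
import torch
import torch.nn as nn

# Build the model skeleton without allocating real storage.
with torch.device("meta"):
    model = nn.Linear(8, 8)

# Meta tensors hold no data, so the checkpoint tensors must be
# assigned into the module rather than copied into existing storage.
sd = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}
model.load_state_dict(sd, assign=True)

print(model.weight.device)  # cpu -- materialized from the state dict
```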
@@ -292,6 +317,7 @@ def recursive_swap_linears(
     float8_dtype=torch.float8_e4m3fn,
     input_float8_dtype=torch.float8_e5m2,
     quantize_modulation: bool = True,
+    ignore_keys: list[str] = [],
 ) -> None:
     """
     Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
@@ -309,6 +335,8 @@ def recursive_swap_linears(
     all linear layers in the model will be using 8-bit quantization.
     """
     for name, child in model.named_children():
+        if name in ignore_keys:
+            continue
         if isinstance(child, Modulation) and not quantize_modulation:
             continue
         if isinstance(child, nn.Linear) and not isinstance(
@@ -331,6 +359,7 @@ def recursive_swap_linears(
         float8_dtype=float8_dtype,
         input_float8_dtype=input_float8_dtype,
         quantize_modulation=quantize_modulation,
+        ignore_keys=ignore_keys,
     )
 
 
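The last three hunks thread a new `ignore_keys` argument through `recursive_swap_linears` so that named children can be excluded from the float8 swap at every level of the recursion. A usage sketch (the `Tiny` module and its child names are hypothetical; the keyword defaults are the ones shown in the diff):

```python
import torch.nn as nn
from float8_quantize import recursive_swap_linears

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(64, 64)     # will become F8Linear
        self.final_layer = nn.Linear(64, 8)  # left as nn.Linear

model = Tiny()
# Children whose name appears in ignore_keys are skipped before any
# isinstance checks run, so "final_layer" keeps full precision.
recursive_swap_linears(model, ignore_keys=["final_layer"])
```

One small design note: the mutable default `ignore_keys: list[str] = []` is the classic Python pitfall, though it looks harmless here since the diff only ever reads the list; a `None` default with an explicit check would be the more defensive idiom.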