Commit d45a331 · 1 Parent(s): 3ddaa67 · committed by aredden

remove torchao dependency, quantize entirely via linear

Files changed (2):
  1. float8_quantize.py +31 -25
  2. requirements.txt +0 -1
float8_quantize.py CHANGED
@@ -1,11 +1,6 @@
 from loguru import logger
 import torch
 import torch.nn as nn
-from torchao.float8.float8_utils import (
-    amax_to_scale,
-    tensor_to_amax,
-    to_fp8_saturated,
-)
 from torch.nn import init
 import math
 from torch.compiler import is_compiling
@@ -200,42 +195,55 @@ class F8Linear(nn.Module):
     def quantize_weight(self):
         if self.weight_initialized:
             return
-        amax = tensor_to_amax(self.weight.data)
-        scale = amax_to_scale(amax, self.float8_dtype, self.weight.dtype)
-        self.float8_data = to_fp8_saturated(self.weight.data * scale, self.float8_dtype)
-        self.scale = scale.float()
-        self.weight_initialized = True
-        self.scale_reciprocal = self.scale.reciprocal().float()
+        amax = torch.max(torch.abs(self.weight.data)).float()
+        self.scale = self.amax_to_scale(amax, self.max_value)
+        self.float8_data = self.to_fp8_saturated(
+            self.weight.data, self.scale, self.max_value
+        ).to(self.float8_dtype)
+        self.scale_reciprocal = self.scale.reciprocal()
         self.weight.data = torch.zeros(
             1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
         )
+        self.weight_initialized = True

     def set_weight_tensor(self, tensor: torch.Tensor):
         self.weight.data = tensor
         self.weight_initialized = False
         self.quantize_weight()

+    def amax_to_scale(self, amax, max_val):
+        return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val)
+
+    def to_fp8_saturated(self, x, scale, max_val):
+        return (x * scale).clamp(-max_val, max_val)
+
     def quantize_input(self, x: torch.Tensor):
         if self.input_scale_initialized:
-            return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
+            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
+                self.input_float8_dtype
+            )
         elif self.trial_index < self.num_scale_trials:
-            amax = tensor_to_amax(x)
+
+            amax = torch.max(torch.abs(x)).float()
+
             self.input_amax_trials[self.trial_index] = amax
             self.trial_index += 1
-            self.input_scale = amax_to_scale(
-                self.input_amax_trials[: self.trial_index].max(),
-                self.input_float8_dtype,
-                self.weight.dtype,
+            self.input_scale = self.amax_to_scale(
+                self.input_amax_trials[: self.trial_index].max(), self.input_max_value
             )
             self.input_scale_reciprocal = self.input_scale.reciprocal()
-            return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
+            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
+                self.input_float8_dtype
+            )
         else:
-            self.input_scale = amax_to_scale(
-                self.input_amax_trials.max(), self.input_float8_dtype, self.weight.dtype
+            self.input_scale = self.amax_to_scale(
+                self.input_amax_trials.max(), self.input_max_value
             )
             self.input_scale_reciprocal = self.input_scale.reciprocal()
             self.input_scale_initialized = True
-            return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
+            return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
+                self.input_float8_dtype
+            )

     def reset_parameters(self) -> None:
         if self.weight_initialized:
@@ -263,10 +271,8 @@ class F8Linear(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.input_scale_initialized or is_compiling():
-            x = (
-                x.mul(self.input_scale)
-                .clamp(min=-self.input_max_value, max=self.input_max_value)
-                .type(self.input_float8_dtype)
+            x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to(
+                self.input_float8_dtype
             )
         else:
             x = self.quantize_input(x)
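For context on the change itself: the commit drops torchao's amax_to_scale / tensor_to_amax / to_fp8_saturated and re-implements the same per-tensor scaling scheme as two small methods on F8Linear. Below is a minimal standalone sketch of that scheme, not part of the repo: FP8_DTYPE and MAX_VALUE are hypothetical stand-ins for self.float8_dtype and self.max_value, and e4m3fn is assumed as the fp8 format.

import torch

# Hypothetical module-level stand-ins for self.float8_dtype / self.max_value (assumption).
FP8_DTYPE = torch.float8_e4m3fn
MAX_VALUE = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def amax_to_scale(amax: torch.Tensor, max_val: float) -> torch.Tensor:
    # Larger |x| -> smaller scale; the clamp guards against divide-by-zero
    # and caps the scale for near-zero tensors, mirroring the diff above.
    return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val)

def to_fp8_saturated(x: torch.Tensor, scale: torch.Tensor, max_val: float) -> torch.Tensor:
    # Scale into the representable range and saturate before the dtype cast.
    return (x * scale).clamp(-max_val, max_val)

w = torch.randn(1024, 1024, dtype=torch.bfloat16)
amax = torch.max(torch.abs(w)).float()
scale = amax_to_scale(amax, MAX_VALUE)
w_fp8 = to_fp8_saturated(w, scale, MAX_VALUE).to(FP8_DTYPE)

# Dequantize with the cached reciprocal, which is the role scale_reciprocal plays in forward().
w_deq = w_fp8.to(torch.bfloat16) * scale.reciprocal().to(torch.bfloat16)
print((w - w_deq).abs().max())  # small element-wise quantization error

Two smaller details visible in the diff: self.weight_initialized = True now moves to after the placeholder weight assignment, and the scale no longer needs the extra .float() casts because amax is computed in float32 up front.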
requirements.txt CHANGED
@@ -1,5 +1,4 @@
 git+https://github.com/aredden/torch-cublas-hgemm.git@master
-git+https://github.com/pytorch/ao.git@main
 einops
 PyTurboJPEG
 pydantic
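Back in float8_quantize.py, the quantize_input path keeps its existing warm-up behaviour: for the first num_scale_trials calls it records the activation amax, recomputes the input scale from the largest value seen so far, and only afterwards freezes it. A rough standalone illustration of that warm-up logic using the new helpers follows; the names are hypothetical, the code lives outside the class, and the same e4m3fn range is assumed.

import torch

FP8_DTYPE = torch.float8_e4m3fn
MAX_VALUE = torch.finfo(FP8_DTYPE).max

def amax_to_scale(amax: torch.Tensor, max_val: float) -> torch.Tensor:
    return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val)

# Hypothetical globals standing in for the F8Linear buffers (assumption).
num_scale_trials = 3
amax_trials = torch.zeros(num_scale_trials)
trial_index = 0
input_scale = torch.tensor(1.0)
scale_frozen = False

def calibrated_quantize(x: torch.Tensor) -> torch.Tensor:
    # Mirrors quantize_input: record amax for the first few calls, then freeze the scale.
    global trial_index, input_scale, scale_frozen
    if scale_frozen:
        pass  # reuse the frozen input_scale, the fast path forward() also takes
    elif trial_index < num_scale_trials:
        amax = torch.max(torch.abs(x)).float()
        amax_trials[trial_index] = amax
        trial_index += 1
        input_scale = amax_to_scale(amax_trials[:trial_index].max(), MAX_VALUE)
    else:
        input_scale = amax_to_scale(amax_trials.max(), MAX_VALUE)
        scale_frozen = True
    return (x * input_scale).clamp(-MAX_VALUE, MAX_VALUE).to(FP8_DTYPE)

for step in range(5):
    x = torch.randn(2, 64) * (step + 1)  # activations grow during warm-up
    calibrated_quantize(x)
    print(step, float(input_scale))  # scale shrinks as larger amax values are observed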