Safetensors
aredden committed on
Commit
6d82dcc
·
1 Parent(s): 00f5d2c

Small fix for issue where f16 CublasLinear layers weren't being used even when available.

Browse files
Files changed (1) hide show
  1. float8_quantize.py +1 -1
float8_quantize.py CHANGED
@@ -336,7 +336,7 @@ def recursive_swap_linears(
336
 
337
  @torch.inference_mode()
338
  def swap_to_cublaslinear(model: nn.Module):
339
- if not isinstance(CublasLinear, torch.nn.Module):
340
  return
341
  for name, child in model.named_children():
342
  if isinstance(child, nn.Linear) and not isinstance(
 
336
 
337
  @torch.inference_mode()
338
  def swap_to_cublaslinear(model: nn.Module):
339
+ if not isinstance(CublasLinear, type(torch.nn.Module)):
340
  return
341
  for name, child in model.named_children():
342
  if isinstance(child, nn.Linear) and not isinstance(