Small fix for an issue where f16 CublasLinear layers weren't being used even when available.
float8_quantize.py (CHANGED, +1 -1)
@@ -336,7 +336,7 @@ def recursive_swap_linears(

 @torch.inference_mode()
 def swap_to_cublaslinear(model: nn.Module):
-    if not isinstance(CublasLinear, torch.nn.Module):
+    if not isinstance(CublasLinear, type(torch.nn.Module)):
         return
     for name, child in model.named_children():
         if isinstance(child, nn.Linear) and not isinstance(
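Why the change matters: CublasLinear is a class object, so isinstance(CublasLinear, torch.nn.Module) is always False (a class is not an instance of nn.Module), and the old guard returned before swapping any layers. type(torch.nn.Module) is nn.Module's metaclass (plain type), so the new check asks whether CublasLinear is actually a class, which is True whenever the cublas_ops import succeeded. A minimal sketch of the two checks follows; the CublasLinear class below is a hypothetical stand-in, since cublas_ops may not be installed:

import torch
import torch.nn as nn

# Hypothetical stand-in for cublas_ops.CublasLinear; only the class-vs-instance
# distinction matters for the guard being fixed.
class CublasLinear(nn.Linear):
    pass

# Old guard: a class object is never an nn.Module *instance*, so this is
# always False, and "not ..." always triggered the early return.
print(isinstance(CublasLinear, torch.nn.Module))        # False

# New guard: type(nn.Module) is nn.Module's metaclass (plain type), so this
# asks "is CublasLinear a class?" (True when the import succeeded), letting
# swap_to_cublaslinear proceed to replace eligible nn.Linear children.
print(isinstance(CublasLinear, type(torch.nn.Module)))  # True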