import torch.nn as nn | |
import torch | |
def model_device(m: nn.Module) -> torch.device:
    """Return the device on which *m* keeps its tensors.

    The device of the first parameter yielded by ``m.parameters()`` is used.
    Modules without parameters (e.g. activation layers, or modules that only
    register buffers) fall back to the first tensor from ``m.buffers()``.
    Only a single tensor is inspected, so a module whose tensors are spread
    over several devices reports the device of whichever tensor comes first.

    Args:
        m: the module to inspect.

    Returns:
        The ``torch.device`` of the first parameter (or buffer) of *m*.

    Raises:
        ValueError: if *m* has neither parameters nor buffers, so there is no
            tensor from which a device could be inferred.
    """
    # ``parameters()`` / ``buffers()`` already return iterators, so ``next``
    # can be applied directly; the ``None`` default avoids leaking a bare
    # StopIteration (which PEP 479 turns into RuntimeError inside generators).
    tensor = next(m.parameters(), None)
    if tensor is None:
        tensor = next(m.buffers(), None)
    if tensor is None:
        raise ValueError(
            f"cannot infer device: {type(m).__name__} has no parameters or buffers"
        )
    return tensor.device
def model_numel(m: nn.Module, requires_grad=False) -> int:
    """Count the scalar elements held in the parameters of *m*.

    Args:
        m: the module whose parameters (every tensor yielded by
            ``m.parameters()``) are counted.
        requires_grad: if truthy, skip parameters whose ``requires_grad``
            flag is off, i.e. count only trainable parameters.

    Returns:
        The summed ``numel()`` of the selected parameters; ``0`` when no
        parameter qualifies.
    """
    total = 0
    for param in m.parameters():
        # Frozen parameters are excluded only when the caller asked for
        # trainable ones.
        if requires_grad and not param.requires_grad:
            continue
        total += param.numel()
    return total