from itertools import chain

import torch
import torch.nn as nn


def model_device(m: nn.Module) -> torch.device:
    """Return the device of the first tensor registered on *m*.

    Parameters are inspected first; if the module has none, its buffers
    are used instead, so parameter-free modules that still hold state
    resolve to a device.

    Args:
        m: The module to inspect.

    Returns:
        The ``device`` attribute of the first parameter, or of the first
        buffer when the module has no parameters.

    Raises:
        StopIteration: If *m* has neither parameters nor buffers. The
            exception type is kept for backward compatibility with
            callers that already handle it.
    """
    # chain() is lazy, so buffers are only touched when there are no parameters.
    return next(chain(m.parameters(), m.buffers())).device


def model_numel(m: nn.Module, requires_grad=False) -> int:
    """Return the total number of elements across the parameters of *m*.

    Args:
        m: The module whose parameters are counted.
        requires_grad: When true, only parameters whose ``requires_grad``
            flag is set contribute to the total; otherwise every
            parameter is counted.

    Returns:
        The sum of ``numel()`` over the selected parameters (``0`` when
        nothing is selected).
    """
    return sum(
        p.numel()
        for p in m.parameters()
        if not requires_grad or p.requires_grad
    )