# Source: MambaFaceKISS-hf / torch_utils.py (author: Maykeye)
# Commit 1a030c8 — "Initial commit"
import torch.nn as nn
import torch
def model_device(m: nn.Module) -> torch.device:
    """Return the device on which ``m`` keeps its tensors.

    The device of the first parameter is used. If ``m`` has no parameters,
    the first buffer is used instead. Only that first tensor is inspected;
    a module spread across several devices is not detected.

    Raises:
        ValueError: if ``m`` has neither parameters nor buffers, so no
            device can be inferred.
    """
    # next(..., None) instead of a bare next(): an escaping StopIteration
    # would be misread as "iteration finished" by an enclosing generator
    # (and converted to RuntimeError there under PEP 479).
    tensor = next(m.parameters(), None)
    if tensor is None:
        # Buffer-only modules (e.g. non-affine norm layers) still have a
        # well-defined device.
        tensor = next(m.buffers(), None)
    if tensor is None:
        raise ValueError(
            f"cannot infer device: {type(m).__name__} has no parameters "
            "or buffers"
        )
    return tensor.device
def model_numel(m: nn.Module, requires_grad=False):
    """Count the scalar elements held in the parameters of ``m``.

    Args:
        m: module whose parameters are counted.
        requires_grad: when truthy, only parameters that have
            ``requires_grad`` set contribute to the total; otherwise
            every parameter is counted.

    Returns:
        The total element count as an int (0 for a parameterless module).
    """
    only_trainable = bool(requires_grad)
    count = 0
    for param in m.parameters():
        if only_trainable and not param.requires_grad:
            continue
        count += param.numel()
    return count