def print_trainable_parameters(model, vb=0): | |
trainable_params = 0 | |
all_param = 0 | |
for _, param in model.named_parameters(): | |
all_param += param.numel() | |
if param.requires_grad: | |
trainable_params += param.numel() | |
if vb > 0: | |
print(_, param.requires_grad, param.numel()) | |
print( | |
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}" | |
) | |