Plonk / utils /model_utils.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
raw
history blame
531 Bytes
def print_trainable_parameters(model):
    """
    Print the number and percentage of trainable parameters in ``model``.

    Useful for tracking the fraction of parameters being trained, e.g. with LoRA.

    Args:
        model: An object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where each ``param`` provides ``numel()``
            and a ``requires_grad`` attribute (e.g. a ``torch.nn.Module``).

    Returns:
        A ``(trainable_params, all_param)`` tuple of element counts, so callers
        can use the numbers without parsing the printed line.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A model with no parameters would otherwise raise ZeroDivisionError here.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
    return trainable_params, all_param