|
import sys |
|
import warnings |
|
from bisect import bisect_right |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.optim import lr_scheduler |
|
|
|
import craftsman |
|
|
|
|
|
def get_scheduler(name):
    """Return the attribute of ``torch.optim.lr_scheduler`` called ``name``.

    Args:
        name: Attribute name to look up on the ``lr_scheduler`` module
            (for example ``"StepLR"``).

    Returns:
        Whatever object ``lr_scheduler`` exposes under ``name``.

    Raises:
        NotImplementedError: If ``lr_scheduler`` has no attribute ``name``.
    """
    if hasattr(lr_scheduler, name):
        return getattr(lr_scheduler, name)
    # Name the missing scheduler so a typo in a config is easy to spot.
    raise NotImplementedError(
        f"Unknown learning-rate scheduler {name!r}: "
        "torch.optim.lr_scheduler has no attribute with that name."
    )
|
|
|
|
|
def getattr_recursive(m, attr):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting from ``m``.

    Args:
        m: Object on which the lookup starts.
        attr: Attribute names joined by ``"."``; each one is looked up on the
            result of the previous lookup.

    Returns:
        The object reached after following every component of ``attr``.

    Raises:
        AttributeError: Propagated from ``getattr`` when a component is
            missing.
    """
    target = m
    for component in attr.split("."):
        target = getattr(target, component)
    return target
|
|
|
|
|
def get_parameters(model, name):
    """Collect the parameters found at the dotted path ``name`` inside ``model``.

    Args:
        model: Root object the path is resolved against.
        name: Dotted attribute path, resolved with ``getattr_recursive``.

    Returns:
        ``target.parameters()`` when the resolved object is an ``nn.Module``,
        the object itself when it is an ``nn.Parameter``, and an empty list
        for anything else.
    """
    target = getattr_recursive(model, name)
    if isinstance(target, nn.Module):
        return target.parameters()
    if isinstance(target, nn.Parameter):
        return target
    # Neither a module nor a parameter: nothing to optimise at this path.
    return []
|
|
|
|
|
def parse_optimizer(config, model):
    """Instantiate the optimizer described by ``config`` for ``model``.

    Args:
        config: Object exposing ``name`` (optimizer class name), ``args``
            (keyword arguments for the constructor) and, optionally,
            ``params``: a mapping from a dotted attribute path inside
            ``model`` to extra options for that parameter group.
        model: Model whose parameters are optimised.

    Returns:
        The constructed optimizer instance.
    """
    if not hasattr(config, "params"):
        # No per-group settings: optimise every parameter of the model.
        param_groups = model.parameters()
    else:
        param_groups = [
            {
                "params": get_parameters(model, group_name),
                "name": group_name,
                **group_options,
            }
            for group_name, group_options in config.params.items()
        ]
        craftsman.debug(f"Specify optimizer params: {config.params}")

    # Choose the namespace that provides the requested optimizer class.
    # Optional back-ends are imported lazily, only when actually requested.
    if config.name in ["FusedAdam"]:
        import apex

        namespace = apex.optimizers
    elif config.name in ["Adan"]:
        from craftsman.systems import optimizers as namespace
    else:
        namespace = torch.optim

    optimizer_cls = getattr(namespace, config.name)
    return optimizer_cls(param_groups, **config.args)
|
|
|
|
|
def parse_scheduler_to_instance(config, optimizer):
    """Recursively build a learning-rate scheduler instance from ``config``.

    ``config.name`` selects what is built:

    * ``"ChainedScheduler"``: every entry of ``config.schedulers`` is parsed
      recursively and the results are wrapped in
      ``lr_scheduler.ChainedScheduler``.
    * ``"Sequential"`` or ``"SequentialLR"``: every entry of
      ``config.schedulers`` is parsed recursively and the results are wrapped
      in ``lr_scheduler.SequentialLR`` using ``config.milestones``.
      ``"SequentialLR"`` is accepted so that the spelling used by
      ``parse_scheduler`` also works here.
    * anything else: the attribute of ``lr_scheduler`` with that name is
      called as ``cls(optimizer, **config.args)``.

    Args:
        config: Object exposing ``name`` plus either ``schedulers`` (and
            ``milestones`` for the sequential case) or ``args``.
        optimizer: Optimizer the scheduler(s) are attached to.

    Returns:
        The constructed scheduler instance.

    Raises:
        AttributeError: If ``config.name`` is not a composite name and
            ``lr_scheduler`` has no attribute with that name.
    """
    name = config.name
    if name in ("ChainedScheduler", "Sequential", "SequentialLR"):
        # Composite schedulers: build the children first, then wrap them.
        children = [
            parse_scheduler_to_instance(conf, optimizer)
            for conf in config.schedulers
        ]
        if name == "ChainedScheduler":
            return lr_scheduler.ChainedScheduler(children)
        return lr_scheduler.SequentialLR(
            optimizer, children, milestones=config.milestones
        )
    return getattr(lr_scheduler, name)(optimizer, **config.args)
|
|
|
|
|
def parse_scheduler(config, optimizer):
    """Build a scheduler description dict from ``config``.

    Args:
        config: Object supporting ``.get("interval", default)`` and exposing
            ``name`` plus either ``schedulers`` (and ``milestones`` for
            ``"SequentialLR"``) or ``args``.
        optimizer: Optimizer the scheduler(s) are attached to.

    Returns:
        A dict with two keys: ``"scheduler"`` (the constructed scheduler
        instance) and ``"interval"`` (``"epoch"`` or ``"step"``, defaulting
        to ``"epoch"``).

    Raises:
        AssertionError: If the configured interval is neither ``"epoch"`` nor
            ``"step"``.
        NotImplementedError: Propagated from ``get_scheduler`` for an unknown
            scheduler name.
    """
    interval = config.get("interval", "epoch")
    assert interval in ["epoch", "step"]

    kind = config.name
    if kind in ("SequentialLR", "ChainedScheduler"):
        # Composite schedulers: only the built instances of the nested
        # configs are kept; the outer interval applies to the whole thing.
        children = [
            parse_scheduler(sub_config, optimizer)["scheduler"]
            for sub_config in config.schedulers
        ]
        if kind == "SequentialLR":
            built = lr_scheduler.SequentialLR(
                optimizer, children, milestones=config.milestones
            )
        else:
            built = lr_scheduler.ChainedScheduler(children)
    else:
        built = get_scheduler(kind)(optimizer, **config.args)

    return {"scheduler": built, "interval": interval}
|
|