# FunSR / models/models.py
# Committed by KyanChen in 02c5426 ("add"); original file size 498 bytes.
import copy
# Registry of model classes, keyed by the name passed to ``register``.
models = {}


def register(name):
    """Build a class decorator that files the decorated class under ``name``.

    The returned decorator stores the class in the module-level ``models``
    dict and hands the class back untouched, so decoration has no effect on
    the class itself. Registering again under an existing name overwrites
    the previous entry.

    Args:
        name: Key under which the decorated class is stored in ``models``.

    Returns:
        A decorator that takes a class, records it, and returns it as-is.
    """
    def _add_to_registry(klass):
        models[name] = klass
        return klass

    return _add_to_registry
def make(model_spec, args=None, load_sd=False, *, strict=True):
    """Instantiate a registered model from a spec dict.

    Args:
        model_spec: Mapping with a ``'name'`` key that selects a class from the
            module-level ``models`` registry. It may also have an ``'args'``
            mapping of constructor keyword arguments; if absent, the class is
            called with no keyword arguments. It needs an ``'sd'`` entry when
            ``load_sd`` is true.
        args: Optional mapping whose entries are merged over the spec's
            ``'args'`` (its keys win). The spec itself is never modified.
        load_sd: If true, pass ``model_spec['sd']`` to the new model's
            ``load_state_dict`` method.
        strict: Forwarded as ``strict=`` to ``load_state_dict``. The default
            ``True`` matches the previous hard-coded behaviour.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if ``'name'`` is missing from the spec, if the name is not
            registered, or if ``load_sd`` is true and the spec has no ``'sd'``.
    """
    name = model_spec['name']
    try:
        model_cls = models[name]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what went wrong.
        raise KeyError(
            f"unknown model {name!r}; registered models: "
            f"{sorted(models, key=str)}"
        ) from None

    # A spec without an 'args' entry means "no constructor keyword arguments".
    spec_args = model_spec.get('args', {})
    if args is not None:
        # Deep-copy before merging so the overrides never leak back into the
        # caller's spec (update() would otherwise mutate it in place).
        model_args = copy.deepcopy(spec_args)
        model_args.update(args)
    else:
        model_args = spec_args

    model = model_cls(**model_args)
    if load_sd:
        model.load_state_dict(model_spec['sd'], strict=strict)
    return model