|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Utility functions to load from the checkpoints. |
|
Each checkpoint is a torch.saved dict with the following keys: |
|
- 'xp.cfg': the hydra config as dumped during training. This should be used
    to rebuild the object using the audiocraft.models.builders functions.
- 'best_state': a readily loadable best state for the model, including
    the conditioner. The model obtained from `xp.cfg` should be compatible
    with this state dict. In the case of an LM, the EnCodec compression model
    is not bundled in the checkpoint and must be provided separately.
|
|
|
These functions also support loading from a remote `https://` URL through the
Torch Hub API. They also allow overriding the device of the returned model; for
LMs, the dtype is then derived from the device (float32 on CPU, float16 otherwise).
|
""" |
|
|
|
from pathlib import Path |
|
import typing as tp |
|
|
|
from omegaconf import OmegaConf |
|
import torch |
|
|
|
from . import builders |
|
|
|
|
|
def _get_state_dict(file_or_url: tp.Union[Path, str], device='cpu'): |
|
|
|
file_or_url = str(file_or_url) |
|
assert isinstance(file_or_url, str) |
|
if file_or_url.startswith('https://'): |
|
return torch.hub.load_state_dict_from_url(file_or_url, map_location=device, check_hash=True) |
|
else: |
|
return torch.load(file_or_url, device) |
|
|
|
|
|
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Rebuild a compression model from a checkpoint and load its best weights.

    Args:
        file_or_url: Local checkpoint path or ``https://`` URL.
        device: Target device; written (as a string) into the config's
            ``device`` field before the model is built.
    Returns:
        The compression model, switched to eval mode.
    """
    # The checkpoint is read with the loader's default map location; the
    # weights are then copied into the model built from the patched config.
    checkpoint = _get_state_dict(file_or_url)
    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    return compression_model
|
|
|
|
|
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Rebuild a language model from a checkpoint and load its best weights.

    The checkpoint is read with the loader's default (CPU) map location, its
    ``xp.cfg`` is patched for the target device, and the model is rebuilt
    with ``builders.get_lm_model``.

    Args:
        file_or_url: Local checkpoint path or ``https://`` URL.
        device: Target device (str or ``torch.device``). On a CPU device the
            config gets ``dtype='float32'``,
            ``transformer_lm.memory_efficient=False`` and
            ``transformer_lm.custom=True``; on any other device it gets
            ``dtype='float16'``.
    Returns:
        The model in eval mode, with the patched config attached as
        ``model.cfg``.
    """
    pkg = _get_state_dict(file_or_url)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    # Compare the device *type* so indexed CPU devices ('cpu:0',
    # torch.device('cpu', 0)) also get the CPU settings instead of float16.
    if torch.device(device).type == 'cpu':
        # NOTE(review): presumably the memory-efficient attention path needs a
        # GPU and half precision is poorly supported on CPU — confirm.
        cfg.transformer_lm.memory_efficient = False
        cfg.transformer_lm.custom = True
        cfg.dtype = 'float32'
    else:
        cfg.dtype = 'float16'
    model = builders.get_lm_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    # Keep the effective config on the model so callers can inspect it.
    model.cfg = cfg
    return model
|
|