Spaces:
Running
Running
import os | |
import logging | |
from contants import config | |
class HParams:
    """Attribute- and item-accessible container for (nested) hyperparameters.

    Every keyword argument becomes an instance attribute. Mapping values are
    converted recursively into ``HParams`` so nested settings can be read as
    ``hps.section.key`` as well as ``hps["section"]["key"]``.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (not an exact type check) so dict subclasses such as
            # OrderedDict are wrapped as well.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        """Return the names of all stored hyperparameters."""
        return self.__dict__.keys()

    def items(self):
        """Return ``(name, value)`` pairs for all stored hyperparameters."""
        return self.__dict__.items()

    def values(self):
        """Return the values of all stored hyperparameters."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        # Delegates to attribute lookup, so a missing key raises AttributeError.
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
def load_checkpoint(checkpoint_path, model):
    """Load the ``'model'`` weights stored in a checkpoint file into *model*.

    Parameters present in *model* but absent from the checkpoint keep their
    current values (each such key is logged at INFO level). Checkpoint entries
    whose keys the model does not have are ignored.

    Args:
        checkpoint_path: Path handed to ``torch.load``; tensors are mapped to
            ``config.system.device``.
        model: Object exposing ``state_dict()`` / ``load_state_dict()``. If it
            has a ``module`` attribute (presumably a DataParallel-style
            wrapper), the wrapped ``model.module`` is loaded instead.

    Returns:
        The checkpoint's ``'iteration'`` entry, or ``None`` if it has none.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` key.
    """
    from torch import load

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; consider weights_only=True where the installed torch
    # version supports it.
    checkpoint_dict = load(checkpoint_path, map_location=config.system.device)
    iteration = checkpoint_dict.get('iteration', None)
    saved_state_dict = checkpoint_dict['model']

    # Load into the wrapped module when the model is a wrapper exposing `.module`.
    target = model.module if hasattr(model, 'module') else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a missing key is expected here; other errors must surface.
            logging.info(f"{k} is not in the checkpoint")
            new_state_dict[k] = v
    target.load_state_dict(new_state_dict)
    return iteration
def get_hparams_from_file(config_path):
    """Read a JSON config file and wrap its contents in an ``HParams`` object.

    Args:
        config_path: Path to a UTF-8 encoded JSON file whose top level is an
            object (it is expanded as keyword arguments).

    Returns:
        HParams built from the parsed JSON; nested objects become nested
        HParams instances.
    """
    from json import loads

    with open(config_path, 'r', encoding='utf-8') as config_file:
        raw_text = config_file.read()
    # Named `parsed` so the module-level `config` import is not shadowed.
    parsed = loads(raw_text)
    return HParams(**parsed)
def load_audio_to_torch(full_path, target_sampling_rate):
    """Load an audio file as a mono float32 ``torch.FloatTensor``.

    Args:
        full_path: Path of the audio file, passed to ``librosa.load``.
        target_sampling_rate: Sampling rate librosa resamples the audio to.

    Returns:
        Float tensor of the samples (presumably 1-D, since ``mono=True``).
    """
    import librosa
    from numpy import float32
    from torch import FloatTensor

    # The rate returned by librosa is not needed: it equals the requested one.
    samples, _returned_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
    as_float32 = samples.astype(float32)
    return FloatTensor(as_float32)
def check_is_none(*items) -> bool:
    """
    Report whether any of the given items counts as "missing".

    An item is missing if it is None, a string made only of whitespace, or
    any value whose ``str()`` is the empty string (which includes "").
    Args:
        *items: Variable number of items to check.
    Returns:
        bool: True as soon as one missing item is found, False otherwise
        (including when no items are given).
    """
    return any(
        candidate is None
        or (isinstance(candidate, str) and str(candidate).isspace())
        or str(candidate) == ""
        for candidate in items
    )
def clean_folder(folder_path):
    """Delete every file directly inside *folder_path*, leaving subfolders alone.

    The scan is not recursive: directories (and their contents) are skipped.
    ``os.path.isfile`` follows symlinks, so a symlink pointing to a file is
    also removed (the link itself, not its target).

    Raises:
        OSError: If *folder_path* cannot be listed or a file cannot be removed.
    """
    for entry_name in os.listdir(folder_path):
        candidate = os.path.join(folder_path, entry_name)
        # Files are deleted; folders are skipped.
        if not os.path.isfile(candidate):
            continue
        os.remove(candidate)