Spaces:
Runtime error
Runtime error
File size: 1,568 Bytes
96ee597 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
"""Copyright PolyAI Limited."""
import logging
import pdb
import sys
import traceback
from functools import wraps
from time import time
from typing import List
import torch
from .symbol_table import SymbolTable
def load_checkpoint(ckpt_path: str) -> dict:
    """Load a checkpoint's ``state_dict`` with the vocoder parameters removed.

    The checkpoint file is deserialized onto the CPU, and the mapping stored
    under its ``"state_dict"`` key is returned without any entry whose name
    starts with ``"vocoder"``. No other filtering or resizing is performed.

    Args:
        ckpt_path: path to a checkpoint file readable by ``torch.load``.

    Returns:
        A new dict mapping parameter names to their stored values, in the
        original order, minus every ``vocoder*`` entry.

    Raises:
        KeyError: if the loaded checkpoint has no ``"state_dict"`` entry.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    return {
        name: value
        for name, value in state_dict.items()
        if not name.startswith("vocoder")
    }
def breakpoint_on_error(fn):
    """Decorator that opens a post-mortem ``pdb`` session when ``fn`` raises.

    Use as a wrapper while debugging. If ``fn`` raises any ``Exception``, the
    exception type and value are printed, the traceback is printed via
    ``traceback.print_exc``, and ``pdb.post_mortem`` is started on the
    failing traceback.

    Args:
        fn: the function to wrap.

    Returns:
        The wrapped function. On success it returns ``fn``'s result. If ``fn``
        raised, it returns ``None`` once the debugger session ends; the
        exception is not re-raised.
    """

    @wraps(fn)  # preserve fn's __name__/__doc__, as measure_duration does
    def inner(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Standard Python way of creating a breakpoint on error.
            extype, value, tb = sys.exc_info()
            print(f"extype={extype},\nvalue={value}")
            traceback.print_exc()
            pdb.post_mortem(tb)

    return inner
def measure_duration(f):
    """Decorator that logs how long each call to ``f`` takes.

    After every successful call, a DEBUG-level record is emitted through the
    root ``logging`` logger in the form ``func:'<name>' took: <seconds> sec``.
    Nothing is logged if ``f`` raises.

    Args:
        f: the function to time.

    Returns:
        The wrapped function, which returns ``f``'s result unchanged.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the measured duration is unaffected by system clock adjustments.
    from time import perf_counter

    @wraps(f)
    def wrap(*args, **kw):
        ts = perf_counter()
        result = f(*args, **kw)
        te = perf_counter()
        # Lazy %-args: formatting is skipped when DEBUG logging is disabled.
        logging.debug("func:%r took: %2.4f sec", f.__name__, te - ts)
        return result

    return wrap
def split_metapath(in_paths: List[str]) -> List[str]:
    """Return the given paths as a new list (a shallow copy).

    NOTE(review): despite its name, this does not split or parse anything;
    every path is passed through unchanged. Confirm with callers whether
    metapath expansion was intended here.

    Args:
        in_paths: the paths to copy.

    Returns:
        A new list containing the same path objects in the same order.
    """
    return list(in_paths)
|