"""Copyright PolyAI Limited."""
import logging
import pdb
import sys
import traceback
from functools import wraps
from time import time
from typing import List

import torch

from .symbol_table import SymbolTable


def load_checkpoint(ckpt_path: str) -> dict:
    """Load a checkpoint's ``state_dict`` onto the CPU.

    Vocoder parameters are dropped, so only the acoustic-model weights
    are returned.
    """
    state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    new_state_dict = dict()
    for p_name, p_value in state_dict.items():
        # Skip vocoder weights.
        if p_name.startswith("vocoder"):
            continue

        new_state_dict[p_name] = p_value

    return new_state_dict

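# Usage sketch (not part of the original module; ``MyModel`` and ``model.ckpt``
# are hypothetical placeholders): the filtered state_dict can be loaded with
# strict=False, since the vocoder keys have been stripped out.
#
#     model = MyModel()
#     model.load_state_dict(load_checkpoint("model.ckpt"), strict=False)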

def breakpoint_on_error(fn):
    """Decorator that starts a ``pdb`` post-mortem session on error.

    Args:
        fn: the function to wrap.

    Returns:
        The wrapped function.
    """

    @wraps(fn)
    def inner(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Standard Python way of dropping into a debugger on error.
            extype, value, tb = sys.exc_info()
            print(f"extype={extype},\nvalue={value}")
            traceback.print_exc()
            pdb.post_mortem(tb)

    return inner

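# Usage sketch (not part of the original module; ``risky_step`` is a
# hypothetical example): an unhandled exception inside the decorated function
# prints the traceback and drops into a pdb post-mortem session.
#
#     @breakpoint_on_error
#     def risky_step(batch):
#         return batch["missing_key"]  # raises KeyError, then pdb starts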

def measure_duration(f):
    """Decorator that logs the wall-clock run time of ``f`` at DEBUG level."""

    @wraps(f)
    def wrap(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        logging.debug("func:%r took: %2.4f sec", f.__name__, te - ts)
        return result

    return wrap

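# Usage sketch (not part of the original module; ``preprocess`` is a
# hypothetical example): enable DEBUG logging to see the timing message.
#
#     logging.basicConfig(level=logging.DEBUG)
#
#     @measure_duration
#     def preprocess(items):
#         return [itm.strip() for itm in items]
#
#     preprocess(["  a ", " b  "])
#     # logs something like: func:'preprocess' took: 0.0001 sec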

def split_metapath(in_paths: List[str]) -> List[str]:
    """Return a copy of the provided metadata paths."""
    other_paths = []

    for itm_path in in_paths:
        other_paths.append(itm_path)

    return other_paths
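

# Usage sketch (not part of the original module): as written, the function
# simply copies the list of metadata paths it is given.
#
#     paths = split_metapath(["data/train.json", "data/dev.json"])
#     assert paths == ["data/train.json", "data/dev.json"]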