File size: 1,191 Bytes
f82071f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def download_dict():
    return {
        "vec768l12": {
            "url": "https://ibm.ent.box.com/shared/static/z1wgl1stco8ffooyatzdwsqn2psd9lrr",
            "output": "./pretrain/checkpoint_best_legacy_500.pt"
        },
        "vec256l9": {
            "url": "https://ibm.ent.box.com/shared/static/z1wgl1stco8ffooyatzdwsqn2psd9lrr",
            "output": "./pretrain/checkpoint_best_legacy_500.pt"
        },
        "hubertsoft": {
            "url": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
            "output": "./pretrain/hubert-soft-0d54a1f4.pt"
        },
        "whisper-ppg": {
            "url": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
            "output": "./pretrain/medium.pt"
        }
    }


def get_speech_encoder(config_path="configs/config.json"):
    import json

    with open(config_path, "r") as f:
        data = f.read()
        config = json.loads(data)
        speech_encoder = config["model"]["speech_encoder"]
        dict = download_dict()

        return dict[speech_encoder]["url"], dict[speech_encoder]["output"]