File size: 3,139 Bytes
9a26655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import numpy as np
import joblib
from s3prl.nn import S3PRLUpstream
import soundfile as sf
from argparse import ArgumentParser
import sentencepiece as spm
import os
from tqdm import tqdm


class ApplyKmeans:
    """Assign feature frames to the index of their nearest k-means centroid.

    The centroids are read from a joblib-serialised model that exposes a
    ``cluster_centers_`` array (expected shape ``(n_clusters, dim)``, as
    scikit-learn k-means models provide). They are stored transposed, together
    with their squared norms, so that the squared Euclidean distance
    ``|x|^2 - 2 x.c + |c|^2`` to every centroid is a single matrix product.

    Args:
        km_path: Path to the joblib-serialised k-means model.
        use_gpu: If true and CUDA is available, keep the torch copies of the
            centroids on the GPU. The NumPy copies always stay on the host.
    """

    def __init__(self, km_path, use_gpu):
        # NOTE: joblib.load unpickles the file and can therefore execute
        # arbitrary code -- only load model files from a trusted source.
        self.km_model = joblib.load(km_path)
        # Transposed centroids: one centroid per column.
        self.C_np = self.km_model.cluster_centers_.transpose()
        # Squared norm of every centroid, kept as a single row for broadcasting.
        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if use_gpu and torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        """Return the nearest-centroid index for every row of ``x``.

        Args:
            x: 2-D ``torch.Tensor`` or NumPy array with one feature vector per
                row.

        Returns:
            NumPy integer array with one centroid index per row of ``x``.
        """
        if isinstance(x, torch.Tensor):
            # Align dtype as well as device: the centroids keep the dtype of
            # the stored model (often float64) while features are usually
            # float32, and torch.matmul refuses mixed dtypes.
            x = x.to(device=self.C.device, dtype=self.C.dtype)
            dist = (
                x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()

        dist = (
            (x**2).sum(1, keepdims=True)
            - 2 * np.matmul(x, self.C_np)
            + self.Cnorm_np
        )
        return np.argmin(dist, axis=1)


def streaming_extract(wav, window_size=60):
    """Split a waveform into consecutive windows shaped as model inputs.

    The waveform is cut along its first axis into windows of
    ``window_size * 16`` samples; the final window may be shorter.

    Args:
        wav: Array of audio samples, sliced along axis 0.
        window_size: Window length; each window spans ``window_size * 16``
            samples (presumably milliseconds of 16 kHz audio -- TODO confirm).

    Returns:
        A list of ``(audio, length)`` tuples, one per window. ``audio`` is the
        window as a tensor with a leading batch axis of size 1 and ``length``
        is a one-element tensor holding the window's real sample count.
    """
    samples_per_window = window_size * 16
    chunk_audios = []
    for start in tqdm(range(0, wav.shape[0], samples_per_window)):
        chunk = torch.tensor(wav[start : start + samples_per_window])
        # Report the true length: the last window can be shorter than the rest.
        chunk_audios.append((chunk.unsqueeze(0), torch.tensor([chunk.shape[0]])))

    # TODO: streaming mode is unfinished. The earlier draft fed every window
    # through the SSL model, kept hidden layer 20 of each output and
    # concatenated the results along the time axis (dim=1).
    return chunk_audios


def load_unit_to_char(path):
    """Read the unit-id to character table from a whitespace-separated file.

    Each non-blank line holds an integer unit id followed by the character it
    maps to; any further fields on the line are ignored.

    Args:
        path: Path to the token list file (read as UTF-8).

    Returns:
        Dict mapping each integer unit id to its character string.
    """
    unit_to_char = {}
    # Explicit UTF-8: the table holds CJK characters, and the platform default
    # encoding is not UTF-8 everywhere.
    with open(path, encoding="utf-8") as token_file:
        for line in token_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            unit_to_char[int(fields[0])] = fields[1]
    return unit_to_char


def main():
    """Convert one audio file to k-means units, CJK characters and BPE pieces.

    Pipeline: read the audio, run the HuBERT upstream, quantise one hidden
    layer with k-means, map every unit to a character, then BPE-encode the
    resulting string. The units, the character string and the BPE pieces are
    printed to stdout.

    Raises:
        NotImplementedError: If streaming mode (``-s``) is requested.
    """
    parser = ArgumentParser()
    parser.add_argument("--km_model", default='./km_2000.mdl', help="Path to the kmeans model")
    parser.add_argument("--bpe_model", default='./bpe.model', help="Path to the bpe model")
    parser.add_argument("--audio", required=True, help="Path to the audio file")
    parser.add_argument("-s", action='store_true', help="Streaming mode")
    args = parser.parse_args()

    # Fail before loading any model. A real exception is used instead of
    # `assert`, which is stripped when Python runs with -O.
    if args.s:
        raise NotImplementedError("streaming mode is still developing")

    # Index of the hidden layer whose features are quantised.
    hidden_layer = 20

    apply_kmeans = ApplyKmeans(args.km_model, use_gpu=True)
    ssl_model = S3PRLUpstream("hf_hubert_custom", path_or_url='TencentGameMate/chinese-hubert-large')
    ssl_model.eval()
    sp = spm.SentencePieceProcessor(model_file=args.bpe_model)
    unit_to_char = load_unit_to_char('distinct_cjk_token_lists')

    # NOTE(review): the sample rate is not checked; the model presumably
    # expects 16 kHz mono audio -- confirm and validate here if so.
    wav, _sample_rate = sf.read(args.audio)
    with torch.no_grad():
        # from_numpy + unsqueeze adds the batch axis without building a tensor
        # from a Python list of arrays, which is very slow.
        all_hs, _all_hs_len = ssl_model(
            torch.from_numpy(wav).unsqueeze(0), torch.tensor([wav.shape[0]])
        )

    ssl_units = apply_kmeans(all_hs[hidden_layer][0].numpy())
    print(ssl_units)
    ssl_char = "".join(unit_to_char[unit] for unit in ssl_units)
    ssl_char_bpe = sp.encode(ssl_char, out_type=str)
    print(ssl_char)
    print(ssl_char_bpe)


if __name__ == "__main__":
    main()