import os
import warnings
from argparse import ArgumentParser

import joblib
import numpy as np
import sentencepiece as spm
import soundfile as sf
import torch
from s3prl.nn import S3PRLUpstream
from tqdm import tqdm
class ApplyKmeans(object):
    """Assign each feature frame to the index of its nearest k-means centroid.

    The squared Euclidean distance is expanded as
    ``|x|^2 - 2 * x @ C + |C|^2`` so that all frame/centroid distances are
    obtained from a single matrix product.

    Args:
        km_path: Path to a joblib-serialized model exposing
            ``cluster_centers_`` (presumably a scikit-learn k-means model).
        use_gpu: If true and CUDA is available, the torch copies of the
            centroids are kept on the GPU. The NumPy code path is unaffected.
    """

    def __init__(self, km_path, use_gpu):
        # NOTE(security): joblib.load unpickles the file and can therefore
        # execute arbitrary code; only load models from trusted sources.
        self.km_model = joblib.load(km_path)

        # Centroids are stored transposed so that ``x @ C`` yields the cross
        # term for every centroid at once; Cnorm holds the per-centroid
        # squared norms as a single row.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if use_gpu and torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        """Return the nearest-centroid index for every row of ``x``.

        Args:
            x: 2-D ``torch.Tensor`` or NumPy array with one feature vector
                per row.

        Returns:
            A 1-D NumPy array with one centroid index per row of ``x``.
        """
        if isinstance(x, torch.Tensor):
            # torch.matmul does not promote mixed dtypes (e.g. float32
            # features against float64 centroids), so align the dtype as well
            # as the device. The NumPy branch below promotes implicitly.
            x = x.to(device=self.C.device, dtype=self.C.dtype)
            dist = (
                x.pow(2).sum(1, keepdim=True)
                - 2 * torch.matmul(x, self.C)
                + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()

        dist = (
            (x**2).sum(1, keepdims=True)
            - 2 * np.matmul(x, self.C_np)
            + self.Cnorm_np
        )
        return np.argmin(dist, axis=1)
def streaming_extract(wav, window_size=60):
    """Split a waveform into consecutive windows packaged as model batches.

    Args:
        wav: 1-D array of audio samples; must expose ``shape`` and support
            slicing.
        window_size: Window length. It is multiplied by 16 to obtain the
            number of samples per window (presumably milliseconds at
            16 kHz -- TODO confirm).

    Returns:
        A list with one ``(audio, length)`` tuple per window, in order.
        ``audio`` is the window with a leading batch axis of size 1 and
        ``length`` is a one-element tensor holding the number of samples
        actually present in that window.
    """
    samples_per_window = window_size * 16
    chunk_audios = []
    for start in tqdm(range(0, wav.shape[0], samples_per_window)):
        chunk = torch.tensor(wav[start : start + samples_per_window])
        # The final window may be shorter than the nominal size, so report
        # its real length instead of the nominal one.
        chunk_len = torch.tensor([chunk.shape[0]])
        chunk_audios.append((chunk.unsqueeze(0), chunk_len))
    return chunk_audios
# Sample rate the SSL upstream is expected to consume.
# NOTE(review): 16 kHz matches the ``window_size * 16`` samples-per-window
# arithmetic used for streaming; confirm against the checkpoint in use.
EXPECTED_SAMPLE_RATE = 16000

# Index of the hidden-state layer whose features are quantized.
# NOTE(review): presumably the layer the k-means model was trained on.
HIDDEN_LAYER = 20


def load_unit_to_char(path, encoding="utf-8"):
    """Read the unit-id -> character table from a whitespace-separated file.

    Each non-blank line starts with an integer unit id followed by the
    character assigned to it; any further fields on the line are ignored.

    Args:
        path: Path to the table file.
        encoding: Text encoding of the file. Given explicitly because the
            table holds CJK characters and the platform default may differ.

    Returns:
        dict mapping ``int`` unit ids to their character strings.
    """
    unit_to_char = {}
    with open(path, encoding=encoding) as table:
        for line in table:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            unit_to_char[int(fields[0])] = fields[1]
    return unit_to_char


def main():
    """Quantize one audio file into k-means units, characters and BPE pieces."""
    parser = ArgumentParser()
    parser.add_argument("--km_model", default='./km_2000.mdl', help="Path to the kmeans model")
    parser.add_argument("--bpe_model", default='./bpe.model', help="Path to the bpe model")
    parser.add_argument("--audio", required=True, help="Path to the audio file")
    parser.add_argument("-s", action='store_true', help="Streaming mode")
    args = parser.parse_args()

    if args.s:
        # Raise explicitly (an assert would be stripped under ``python -O``)
        # and do so before any model is loaded.
        raise NotImplementedError("streaming mode is still developing")

    apply_kmeans = ApplyKmeans(args.km_model, use_gpu=True)
    ssl_model = S3PRLUpstream("hf_hubert_custom", path_or_url='TencentGameMate/chinese-hubert-large')
    ssl_model.eval()
    sp = spm.SentencePieceProcessor(model_file=args.bpe_model)
    unit_to_char = load_unit_to_char('distinct_cjk_token_lists')

    wav, sr = sf.read(args.audio)
    if wav.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files;
        # the model takes a single waveform, so average the channels.
        wav = wav.mean(axis=1)
    if sr != EXPECTED_SAMPLE_RATE:
        warnings.warn(
            f"{args.audio} has sample rate {sr} Hz but {EXPECTED_SAMPLE_RATE} Hz "
            "is expected; the extracted units may be unreliable"
        )

    with torch.no_grad():
        # from_numpy avoids the slow list-of-ndarray tensor construction.
        all_hs, _ = ssl_model(
            torch.from_numpy(wav).unsqueeze(0), torch.tensor([wav.shape[0]])
        )
        ssl_units = apply_kmeans(all_hs[HIDDEN_LAYER][0, :, :].numpy())

    print(ssl_units)
    ssl_char = "".join(unit_to_char[c] for c in ssl_units)
    ssl_char_bpe = sp.encode(ssl_char, out_type=str)
    print(ssl_char)
    print(ssl_char_bpe)


if __name__ == "__main__":
    main()