File size: 1,393 Bytes
d5d7329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import annotations

from pathlib import Path
from typing import Any

import torch
from sklearn.cluster import KMeans


def get_cluster_model(ckpt_path: Path | str):
    with Path(ckpt_path).open("rb") as f:
        checkpoint = torch.load(
            f, map_location="cpu"
        )  # Danger of arbitrary code execution
    kmeans_dict = {}
    for spk, ckpt in checkpoint.items():
        km = KMeans(ckpt["n_features_in_"])
        km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
        km.__dict__["_n_threads"] = ckpt["_n_threads"]
        km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"]
        kmeans_dict[spk] = km
    return kmeans_dict


def check_speaker(model: Any, speaker: Any):
    if speaker not in model:
        raise ValueError(f"Speaker {speaker} not in {list(model.keys())}")


def get_cluster_result(model: Any, x: Any, speaker: Any):
    """
    x: np.array [t, 256]
    return cluster class result
    """
    check_speaker(model, speaker)
    return model[speaker].predict(x)


def get_cluster_center_result(model: Any, x: Any, speaker: Any):
    """x: np.array [t, 256]"""
    check_speaker(model, speaker)
    predict = model[speaker].predict(x)
    return model[speaker].cluster_centers_[predict]


def get_center(model: Any, x: Any, speaker: Any):
    check_speaker(model, speaker)
    return model[speaker].cluster_centers_[x]