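from sklearn.cluster import AgglomerativeClustering, Birch, KMeans, SpectralClustering
from sklearn.mixture import GaussianMixture  # Gaussian mixture model (imported but not used below)
import os
import numpy as np
import yaml
import config  # project module providing `parser` and a `config` object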
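

# Sketch (hypothetical helper, not called by the main block below): the script
# imports GaussianMixture but never uses it. If you want to try it as another
# clustering option, a Gaussian mixture also yields per-sample membership
# probabilities, which the hard-assignment algorithms in the main block do not.
# covariance_type="diag" is an assumption chosen for high-dimensional
# embeddings, where a full covariance matrix per component needs many samples
# to estimate reliably.
from sklearn.mixture import GaussianMixture


def cluster_with_gmm(x, n_clusters, seed=10):
    """Fit a diagonal-covariance Gaussian mixture to an (n_samples, n_features) array.

    Returns (labels, probs): the most likely component for each sample and the
    (n_samples, n_clusters) matrix of membership probabilities.
    """
    gmm = GaussianMixture(n_components=n_clusters, covariance_type="diag", random_state=seed)
    labels = gmm.fit_predict(x)
    probs = gmm.predict_proba(x)
    return labels, probs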

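
# Sketch (hypothetical helper, not called by the main block below): the
# per-class bookkeeping in the main block groups file names by predicted label
# and keeps at most `max_per_class` of them, in input order. An equivalent
# compact form that builds one speaker's entry of the YAML result, e.g.
#   group_by_cluster(["a.wav", "b.wav", "c.wav"], [1, 0, 1], 1)
#   -> {"class0": ["b.wav"], "class1": ["a.wav"]}
def group_by_cluster(names, labels, max_per_class):
    """Return {"class0": [...], "class1": [...], ...} with up to max_per_class names each."""
    n_classes = int(max(labels)) + 1 if len(labels) else 0
    classes = [[] for _ in range(n_classes)]
    for name, label in zip(names, labels):
        classes[int(label)].append(name)
    return {f"class{i}": members[:max_per_class] for i, members in enumerate(classes)}
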
if __name__ == "__main__":
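    # Add clustering options to the project's shared argument parser.
    config.parser.add_argument(
        "-a", "--algorithm", default="k", type=str,
        help="clustering algorithm: b=Birch, s=SpectralClustering, a=AgglomerativeClustering, anything else=KMeans",
    )
    config.parser.add_argument(
        "-n", "--num_clusters", default=3, type=int, help="number of clusters per speaker"
    )
    config.parser.add_argument(
        "-r", "--range", default=4, type=int, help="maximum number of example files listed per class"
    )
    args = config.parser.parse_args()
    filelist_dict = {}  # speaker -> list of audio file paths
    yml_result = {}  # speaker -> {"class0": [...], "class1": [...], ...}
    # Rebinds the name `config` from the module to the config object it provides.
    from config import config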
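    # Group audio paths by speaker: in each line of the cleaned filelist, the
    # first two "|"-separated fields are the audio path and the speaker name.
    with open(config.preprocess_text_config.cleaned_path, mode="r", encoding="utf-8") as f:
        for line in f:
            fields = line.split("|")
            wav_path, speaker = fields[0], fields[1]
            if speaker not in filelist_dict:
                filelist_dict[speaker] = []
                yml_result[speaker] = {}
            filelist_dict[speaker].append(wav_path)
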
    for speaker in filelist_dict:
        embs = []
        wavnames = []
        print("\nspeaker: " + speaker)
        for file in filelist_dict[speaker]:
            try:
                # The emotion embedding is stored next to the audio file as "<name>.emo.npy".
                embs.append(np.expand_dims(np.load(f"{os.path.splitext(file)[0]}.emo.npy"), axis=0))
                wavnames.append(os.path.basename(file))
            except Exception as e:
                print(e)
        # Number of clusters for the clustering algorithm
        n_clusters = args.num_clusters
        if len(embs) < n_clusters:
            # The clustering algorithms need at least n_clusters samples.
            print(f"skipped: only {len(embs)} embedding(s) loaded for {n_clusters} clusters")
            continue
        # Stack into an (n_samples, n_features) matrix; unlike np.squeeze, the
        # reshape also keeps the array 2-D when there is only one sample.
        x = np.concatenate(embs, axis=0).reshape(len(embs), -1)
        if args.algorithm == "b":
            model = Birch(n_clusters=n_clusters, threshold=0.2)
        elif args.algorithm == "s":
            model = SpectralClustering(n_clusters=n_clusters)
        elif args.algorithm == "a":
            model = AgglomerativeClustering(n_clusters=n_clusters)
        else:
            model = KMeans(n_clusters=n_clusters, random_state=10)
        # Other clustering algorithms can be tried here as well.
        y_predict = model.fit_predict(x)
        classes = [[] for _ in range(y_predict.max() + 1)]
        for idx, wavname in enumerate(wavnames):
            classes[y_predict[idx]].append(wavname)

        for i in range(y_predict.max() + 1):
            print("class:", i, "number of samples in class:", len(classes[i]))
            # Keep at most args.range example files per class.
            yml_result[speaker][f"class{i}"] = classes[i][: args.range]
            for wavname in yml_result[speaker][f"class{i}"]:
                print(wavname)

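    # Write the selected example files per speaker and class; allow_unicode keeps
    # non-ASCII speaker and file names readable instead of \u-escaped.
    with open(os.path.join(config.dataset_path, "emo_clustering.yml"), "w", encoding="utf-8") as f:
        yaml.dump(yml_result, f, allow_unicode=True)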