MAZALA2024 commited on
Commit
c3240ca
·
verified ·
1 Parent(s): 0aa8c17

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +57 -2
voice_processing.py CHANGED
@@ -47,10 +47,65 @@ def get_unique_filename(extension):
47
  return f"{uuid.uuid4()}.{extension}"
48
 
49
  def model_data(model_name):
50
- # ... (keep the existing implementation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def load_hubert():
53
- # ... (keep the existing implementation)
 
 
 
 
 
 
 
 
 
 
54
 
55
  def get_model_names():
56
  model_root = "weights"
 
47
  return f"{uuid.uuid4()}.{extension}"
48
 
49
  def model_data(model_name):
50
+ pth_path = [
51
+ f"{model_root}/{model_name}/{f}"
52
+ for f in os.listdir(f"{model_root}/{model_name}")
53
+ if f.endswith(".pth")
54
+ ][0]
55
+ print(f"Loading {pth_path}")
56
+ cpt = torch.load(pth_path, map_location="cpu")
57
+ tgt_sr = cpt["config"][-1]
58
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
59
+ if_f0 = cpt.get("f0", 1)
60
+ version = cpt.get("version", "v1")
61
+ if version == "v1":
62
+ if if_f0 == 1:
63
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
64
+ else:
65
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
66
+ elif version == "v2":
67
+ if if_f0 == 1:
68
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
69
+ else:
70
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
71
+ else:
72
+ raise ValueError("Unknown version")
73
+ del net_g.enc_q
74
+ net_g.load_state_dict(cpt["weight"], strict=False)
75
+ print("Model loaded")
76
+ net_g.eval().to(config.device)
77
+ if config.is_half:
78
+ net_g = net_g.half()
79
+ else:
80
+ net_g = net_g.float()
81
+ vc = VC(tgt_sr, config)
82
+
83
+ index_files = [
84
+ f"{model_root}/{model_name}/{f}"
85
+ for f in os.listdir(f"{model_root}/{model_name}")
86
+ if f.endswith(".index")
87
+ ]
88
+ if len(index_files) == 0:
89
+ print("No index file found")
90
+ index_file = ""
91
+ else:
92
+ index_file = index_files[0]
93
+ print(f"Index file found: {index_file}")
94
+
95
+ return tgt_sr, net_g, vc, version, index_file, if_f0
96
 
97
  def load_hubert():
98
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
99
+ ["hubert_base.pt"],
100
+ suffix="",
101
+ )
102
+ hubert_model = models[0]
103
+ hubert_model = hubert_model.to(config.device)
104
+ if config.is_half:
105
+ hubert_model = hubert_model.half()
106
+ else:
107
+ hubert_model = hubert_model.float()
108
+ return hubert_model.eval()
109
 
110
  def get_model_names():
111
  model_root = "weights"