Kit-Lemonfoot commited on
Commit
34271eb
·
verified ·
1 Parent(s): 20611fc

bit of a cheat but let's see if this works on space

Browse files
Files changed (1) hide show
  1. common/tts_model.py +250 -250
common/tts_model.py CHANGED
@@ -1,250 +1,250 @@
1
- import numpy as np
2
- import gradio as gr
3
- import torch
4
- import os
5
- import warnings
6
- from gradio.processing_utils import convert_to_16_bit_wav
7
- from typing import Dict, List, Optional, Union
8
-
9
- import utils
10
- from infer import get_net_g, infer
11
- from models import SynthesizerTrn
12
- from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
13
-
14
- from .log import logger
15
- from .constants import (
16
- DEFAULT_ASSIST_TEXT_WEIGHT,
17
- DEFAULT_LENGTH,
18
- DEFAULT_LINE_SPLIT,
19
- DEFAULT_NOISE,
20
- DEFAULT_NOISEW,
21
- DEFAULT_SDP_RATIO,
22
- DEFAULT_SPLIT_INTERVAL,
23
- DEFAULT_STYLE,
24
- DEFAULT_STYLE_WEIGHT,
25
- )
26
-
27
-
28
- class Model:
29
- def __init__(
30
- self, model_path: str, config_path: str, style_vec_path: str, device: str
31
- ):
32
- self.model_path: str = model_path
33
- self.config_path: str = config_path
34
- self.device: str = device
35
- self.style_vec_path: str = style_vec_path
36
- self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
37
- self.spk2id: Dict[str, int] = self.hps.data.spk2id
38
- self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()}
39
-
40
- self.num_styles: int = self.hps.data.num_styles
41
- if hasattr(self.hps.data, "style2id"):
42
- self.style2id: Dict[str, int] = self.hps.data.style2id
43
- else:
44
- self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)}
45
- if len(self.style2id) != self.num_styles:
46
- raise ValueError(
47
- f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})"
48
- )
49
-
50
- self.style_vectors: np.ndarray = np.load(self.style_vec_path)
51
- if self.style_vectors.shape[0] != self.num_styles:
52
- raise ValueError(
53
- f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
54
- )
55
-
56
- self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None
57
-
58
- def load_net_g(self):
59
- self.net_g = get_net_g(
60
- model_path=self.model_path,
61
- version=self.hps.version,
62
- device=self.device,
63
- hps=self.hps,
64
- )
65
-
66
- def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
67
- mean = self.style_vectors[0]
68
- style_vec = self.style_vectors[style_id]
69
- style_vec = mean + (style_vec - mean) * weight
70
- return style_vec
71
-
72
- def get_style_vector_from_audio(
73
- self, audio_path: str, weight: float = 1.0
74
- ) -> np.ndarray:
75
- from style_gen import get_style_vector
76
-
77
- xvec = get_style_vector(audio_path)
78
- mean = self.style_vectors[0]
79
- xvec = mean + (xvec - mean) * weight
80
- return xvec
81
-
82
- def infer(
83
- self,
84
- text: str,
85
- language: str = "JP",
86
- sid: int = 0,
87
- reference_audio_path: Optional[str] = None,
88
- sdp_ratio: float = DEFAULT_SDP_RATIO,
89
- noise: float = DEFAULT_NOISE,
90
- noisew: float = DEFAULT_NOISEW,
91
- length: float = DEFAULT_LENGTH,
92
- line_split: bool = DEFAULT_LINE_SPLIT,
93
- split_interval: float = DEFAULT_SPLIT_INTERVAL,
94
- assist_text: Optional[str] = None,
95
- assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
96
- use_assist_text: bool = False,
97
- style: str = DEFAULT_STYLE,
98
- style_weight: float = DEFAULT_STYLE_WEIGHT,
99
- given_tone: Optional[list[int]] = None,
100
- ) -> tuple[int, np.ndarray]:
101
- #logger.info(f"Start generating audio data from text:\n{text}")
102
- if language != "JP" and self.hps.version.endswith("JP-Extra"):
103
- raise ValueError(
104
- "The model is trained with JP-Extra, but the language is not JP"
105
- )
106
- if reference_audio_path == "":
107
- reference_audio_path = None
108
- if assist_text == "" or not use_assist_text:
109
- assist_text = None
110
-
111
- if self.net_g is None:
112
- self.load_net_g()
113
- if reference_audio_path is None:
114
- style_id = self.style2id[style]
115
- style_vector = self.get_style_vector(style_id, style_weight)
116
- else:
117
- style_vector = self.get_style_vector_from_audio(
118
- reference_audio_path, style_weight
119
- )
120
- if not line_split:
121
- with torch.no_grad():
122
- audio = infer(
123
- text=text,
124
- sdp_ratio=sdp_ratio,
125
- noise_scale=noise,
126
- noise_scale_w=noisew,
127
- length_scale=length,
128
- sid=sid,
129
- language=language,
130
- hps=self.hps,
131
- net_g=self.net_g,
132
- device=self.device,
133
- assist_text=assist_text,
134
- assist_text_weight=assist_text_weight,
135
- style_vec=style_vector,
136
- given_tone=given_tone,
137
- )
138
- else:
139
- texts = text.split("\n")
140
- texts = [t for t in texts if t != ""]
141
- audios = []
142
- with torch.no_grad():
143
- for i, t in enumerate(texts):
144
- audios.append(
145
- infer(
146
- text=t,
147
- sdp_ratio=sdp_ratio,
148
- noise_scale=noise,
149
- noise_scale_w=noisew,
150
- length_scale=length,
151
- sid=sid,
152
- language=language,
153
- hps=self.hps,
154
- net_g=self.net_g,
155
- device=self.device,
156
- assist_text=assist_text,
157
- assist_text_weight=assist_text_weight,
158
- style_vec=style_vector,
159
- )
160
- )
161
- if i != len(texts) - 1:
162
- audios.append(np.zeros(int(44100 * split_interval)))
163
- audio = np.concatenate(audios)
164
- with warnings.catch_warnings():
165
- warnings.simplefilter("ignore")
166
- audio = convert_to_16_bit_wav(audio)
167
- #logger.info("Audio data generated successfully")
168
- return (self.hps.data.sampling_rate, audio)
169
-
170
-
171
- class ModelHolder:
172
- def __init__(self, root_dir: str, device: str):
173
- self.root_dir: str = root_dir
174
- self.device: str = device
175
- self.model_files_dict: Dict[str, List[str]] = {}
176
- self.current_model: Optional[Model] = None
177
- self.model_names: List[str] = []
178
- self.models: List[Model] = []
179
- self.refresh()
180
-
181
- def refresh(self):
182
- self.model_files_dict = {}
183
- self.model_names = []
184
- self.current_model = None
185
- model_dirs = [
186
- d
187
- for d in os.listdir(self.root_dir)
188
- if os.path.isdir(os.path.join(self.root_dir, d))
189
- ]
190
- for model_name in model_dirs:
191
- model_dir = os.path.join(self.root_dir, model_name)
192
- model_files = [
193
- os.path.join(model_dir, f)
194
- for f in os.listdir(model_dir)
195
- if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
196
- ]
197
- if len(model_files) == 0:
198
- logger.warning(
199
- f"No model files found in {self.root_dir}/{model_name}, so skip it"
200
- )
201
- continue
202
- self.model_files_dict[model_name] = model_files
203
- self.model_names.append(model_name)
204
-
205
- def load_model_gr(
206
- self, model_name: str, model_path: str
207
- ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]:
208
- if model_name not in self.model_files_dict:
209
- raise ValueError(f"Model `{model_name}` is not found")
210
- if model_path not in self.model_files_dict[model_name]:
211
- raise ValueError(f"Model file `{model_path}` is not found")
212
- if (
213
- self.current_model is not None
214
- and self.current_model.model_path == model_path
215
- ):
216
- # Already loaded
217
- speakers = list(self.current_model.spk2id.keys())
218
- styles = list(self.current_model.style2id.keys())
219
- return (
220
- gr.Dropdown(choices=styles, value=styles[0]),
221
- gr.Button(interactive=True, value="音声合成"),
222
- gr.Dropdown(choices=speakers, value=speakers[0]),
223
- )
224
- self.current_model = Model(
225
- model_path=model_path,
226
- config_path=os.path.join(self.root_dir, model_name, "config.json"),
227
- style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
228
- device=self.device,
229
- )
230
- speakers = list(self.current_model.spk2id.keys())
231
- styles = list(self.current_model.style2id.keys())
232
- return (
233
- gr.Dropdown(choices=styles, value=styles[0]),
234
- gr.Button(interactive=True, value="音声合成"),
235
- gr.Dropdown(choices=speakers, value=speakers[0]),
236
- )
237
-
238
- def update_model_files_gr(self, model_name: str) -> gr.Dropdown:
239
- model_files = self.model_files_dict[model_name]
240
- return gr.Dropdown(choices=model_files, value=model_files[0])
241
-
242
- def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]:
243
- self.refresh()
244
- initial_model_name = self.model_names[0]
245
- initial_model_files = self.model_files_dict[initial_model_name]
246
- return (
247
- gr.Dropdown(choices=self.model_names, value=initial_model_name),
248
- gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
249
- gr.Button(interactive=False), # For tts_button
250
- )
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import warnings
6
+ from gradio.processing_utils import convert_to_16_bit_wav
7
+ from typing import Dict, List, Optional, Union
8
+
9
+ import utils
10
+ from infer import get_net_g, infer
11
+ from models import SynthesizerTrn
12
+ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
13
+
14
+ from .log import logger
15
+ from .constants import (
16
+ DEFAULT_ASSIST_TEXT_WEIGHT,
17
+ DEFAULT_LENGTH,
18
+ DEFAULT_LINE_SPLIT,
19
+ DEFAULT_NOISE,
20
+ DEFAULT_NOISEW,
21
+ DEFAULT_SDP_RATIO,
22
+ DEFAULT_SPLIT_INTERVAL,
23
+ DEFAULT_STYLE,
24
+ DEFAULT_STYLE_WEIGHT,
25
+ )
26
+
27
+
28
+ class Model:
29
+ def __init__(
30
+ self, model_path: str, config_path: str, style_vec_path: str, device: str
31
+ ):
32
+ self.model_path: str = model_path
33
+ self.config_path: str = config_path
34
+ self.device: str = device
35
+ self.style_vec_path: str = style_vec_path
36
+ self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
37
+ self.spk2id: Dict[str, int] = self.hps.data.spk2id
38
+ self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()}
39
+
40
+ self.num_styles: int = self.hps.data.num_styles
41
+ if hasattr(self.hps.data, "style2id"):
42
+ self.style2id: Dict[str, int] = self.hps.data.style2id
43
+ else:
44
+ self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)}
45
+ if len(self.style2id) != self.num_styles:
46
+ raise ValueError(
47
+ f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})"
48
+ )
49
+
50
+ self.style_vectors: np.ndarray = np.load(self.style_vec_path)
51
+ if self.style_vectors.shape[0] != self.num_styles:
52
+ raise ValueError(
53
+ f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
54
+ )
55
+
56
+ self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None
57
+
58
+ def load_net_g(self):
59
+ self.net_g = get_net_g(
60
+ model_path=self.model_path,
61
+ version=self.hps.version,
62
+ device=self.device,
63
+ hps=self.hps,
64
+ )
65
+
66
+ def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
67
+ mean = self.style_vectors[0]
68
+ style_vec = self.style_vectors[style_id]
69
+ style_vec = mean + (style_vec - mean) * weight
70
+ return style_vec
71
+
72
+ def get_style_vector_from_audio(
73
+ self, audio_path: str, weight: float = 1.0
74
+ ) -> np.ndarray:
75
+ from style_gen import get_style_vector
76
+
77
+ xvec = get_style_vector(audio_path)
78
+ mean = self.style_vectors[0]
79
+ xvec = mean + (xvec - mean) * weight
80
+ return xvec
81
+
82
+ def infer(
83
+ self,
84
+ text: str,
85
+ language: str = "JP",
86
+ sid: int = 0,
87
+ reference_audio_path: Optional[str] = None,
88
+ sdp_ratio: float = DEFAULT_SDP_RATIO,
89
+ noise: float = DEFAULT_NOISE,
90
+ noisew: float = DEFAULT_NOISEW,
91
+ length: float = DEFAULT_LENGTH,
92
+ line_split: bool = DEFAULT_LINE_SPLIT,
93
+ split_interval: float = DEFAULT_SPLIT_INTERVAL,
94
+ assist_text: Optional[str] = None,
95
+ assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
96
+ use_assist_text: bool = False,
97
+ style: str = DEFAULT_STYLE,
98
+ style_weight: float = DEFAULT_STYLE_WEIGHT,
99
+ given_tone: Optional[list[int]] = None,
100
+ ) -> tuple[int, np.ndarray]:
101
+ #logger.info(f"Start generating audio data from text:\n{text}")
102
+ if language != "JP" and self.hps.version.endswith("JP-Extra"):
103
+ raise ValueError(
104
+ "The model is trained with JP-Extra, but the language is not JP"
105
+ )
106
+ if reference_audio_path == "":
107
+ reference_audio_path = None
108
+ if assist_text == "" or not use_assist_text:
109
+ assist_text = None
110
+
111
+ if self.net_g is None:
112
+ self.load_net_g()
113
+ if reference_audio_path is None:
114
+ style_id = self.style2id[style]
115
+ style_vector = self.get_style_vector(style_id, style_weight)
116
+ else:
117
+ style_vector = self.get_style_vector_from_audio(
118
+ reference_audio_path, style_weight
119
+ )
120
+ if not line_split:
121
+ with torch.no_grad():
122
+ audio = infer(
123
+ text=text,
124
+ sdp_ratio=sdp_ratio,
125
+ noise_scale=noise,
126
+ noise_scale_w=noisew,
127
+ length_scale=length,
128
+ sid=sid,
129
+ language=language,
130
+ hps=self.hps,
131
+ net_g=self.net_g,
132
+ device=self.device,
133
+ assist_text=assist_text,
134
+ assist_text_weight=assist_text_weight,
135
+ style_vec=style_vector,
136
+ given_tone=given_tone,
137
+ )
138
+ else:
139
+ texts = text.split("\n")
140
+ texts = [t for t in texts if t != ""]
141
+ audios = []
142
+ with torch.no_grad():
143
+ for i, t in enumerate(texts):
144
+ audios.append(
145
+ infer(
146
+ text=t,
147
+ sdp_ratio=sdp_ratio,
148
+ noise_scale=noise,
149
+ noise_scale_w=noisew,
150
+ length_scale=length,
151
+ sid=sid,
152
+ language=language,
153
+ hps=self.hps,
154
+ net_g=self.net_g,
155
+ device=self.device,
156
+ assist_text=assist_text,
157
+ assist_text_weight=assist_text_weight,
158
+ style_vec=style_vector,
159
+ )
160
+ )
161
+ if i != len(texts) - 1:
162
+ audios.append(np.zeros(int(44100 * split_interval)))
163
+ audio = np.concatenate(audios)
164
+ with warnings.catch_warnings():
165
+ warnings.simplefilter("ignore")
166
+ audio = convert_to_16_bit_wav(audio)
167
+ #logger.info("Audio data generated successfully")
168
+ return (self.hps.data.sampling_rate, audio)
169
+
170
+
171
+ class ModelHolder:
172
+ def __init__(self, root_dir: str, device: str):
173
+ self.root_dir: str = root_dir
174
+ self.device: str = device
175
+ self.model_files_dict: Dict[str, List[str]] = {}
176
+ self.current_model: Optional[Model] = None
177
+ self.model_names: List[str] = []
178
+ self.models: List[Model] = []
179
+ self.refresh()
180
+
181
+ def refresh(self):
182
+ self.model_files_dict = {}
183
+ self.model_names = []
184
+ self.current_model = None
185
+ model_dirs = [
186
+ d
187
+ for d in os.listdir(self.root_dir)
188
+ if os.path.isdir(os.path.join(self.root_dir, d))
189
+ ]
190
+ for model_name in model_dirs:
191
+ model_dir = os.path.join(self.root_dir, model_name)
192
+ model_files = [
193
+ os.path.join(model_dir, f)
194
+ for f in os.listdir(model_dir)
195
+ if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
196
+ ]
197
+ if len(model_files) == 0:
198
+ logger.warning(
199
+ f"No model files found in {self.root_dir}/{model_name}, so skip it"
200
+ )
201
+ continue
202
+ self.model_files_dict[model_name] = model_files
203
+ self.model_names.append(model_name)
204
+
205
+ def load_model_gr(
206
+ self, model_name: str, model_path: str
207
+ ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]:
208
+ if model_name not in self.model_files_dict:
209
+ raise ValueError(f"Model `{model_name}` is not found")
210
+ #if model_path not in self.model_files_dict[model_name]:
211
+ # raise ValueError(f"Model file `{model_path}` is not found")
212
+ if (
213
+ self.current_model is not None
214
+ and self.current_model.model_path == model_path
215
+ ):
216
+ # Already loaded
217
+ speakers = list(self.current_model.spk2id.keys())
218
+ styles = list(self.current_model.style2id.keys())
219
+ return (
220
+ gr.Dropdown(choices=styles, value=styles[0]),
221
+ gr.Button(interactive=True, value="音声合成"),
222
+ gr.Dropdown(choices=speakers, value=speakers[0]),
223
+ )
224
+ self.current_model = Model(
225
+ model_path=model_path,
226
+ config_path=os.path.join(self.root_dir, model_name, "config.json"),
227
+ style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
228
+ device=self.device,
229
+ )
230
+ speakers = list(self.current_model.spk2id.keys())
231
+ styles = list(self.current_model.style2id.keys())
232
+ return (
233
+ gr.Dropdown(choices=styles, value=styles[0]),
234
+ gr.Button(interactive=True, value="音声合成"),
235
+ gr.Dropdown(choices=speakers, value=speakers[0]),
236
+ )
237
+
238
+ def update_model_files_gr(self, model_name: str) -> gr.Dropdown:
239
+ model_files = self.model_files_dict[model_name]
240
+ return gr.Dropdown(choices=model_files, value=model_files[0])
241
+
242
+ def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]:
243
+ self.refresh()
244
+ initial_model_name = self.model_names[0]
245
+ initial_model_files = self.model_files_dict[initial_model_name]
246
+ return (
247
+ gr.Dropdown(choices=self.model_names, value=initial_model_name),
248
+ gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
249
+ gr.Button(interactive=False), # For tts_button
250
+ )