diff --git a/.env.webui b/.env.webui index a807cd0298c85378bbefe6e65990411d0b9ee51b..22d73786f170f6fa0741125fe3841a6eed8f0128 100644 --- a/.env.webui +++ b/.env.webui @@ -17,5 +17,5 @@ TTS_MAX_LEN=1000 SSML_MAX_LEN=3000 MAX_BATCH_SIZE=12 -V_GIT_TAG="🤗hf(0.5.6-rc)" +V_GIT_TAG="🤗hf(0.6.1-rc)" V_GIT_COMMIT=main diff --git a/language/zh-CN.json b/language/zh-CN.json index 31e5890a575dcac42c44b75b7675650a38f22553..f4f41cf038fe73c2194b311b9357b8f5a3b77d6d 100644 --- a/language/zh-CN.json +++ b/language/zh-CN.json @@ -80,6 +80,9 @@ "readme": "readme", "changelog": "changelog", "💼Speaker file": "💼音色文件", + "🎛️Spliter": "🎛️分割器配置", + "eos": "句尾词", + "Spliter Threshold": "分割器阈值", "TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"], "SSML_SPLITER_GUIDE": [ "- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`", diff --git a/modules/ChatTTS/ChatTTS/core.py b/modules/ChatTTS/ChatTTS/core.py index 225de72f07cb237cfa7210872d316936afe808f8..549973e0c5dcdf9869ae1237a65fc7762ceae244 100644 --- a/modules/ChatTTS/ChatTTS/core.py +++ b/modules/ChatTTS/ChatTTS/core.py @@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code from huggingface_hub import snapshot_download -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.ERROR) class Chat: diff --git a/modules/SynthesizeSegments.py b/modules/SynthesizeSegments.py index e7589ab85d6ebdbd3618d6a668a0bf60c6a51b56..de8a7778a27b2d89e6058a15716a0c53538c9a71 100644 --- a/modules/SynthesizeSegments.py +++ b/modules/SynthesizeSegments.py @@ -1,8 +1,10 @@ +import copy from box import Box from pydub import AudioSegment from typing import List, Union from scipy.io.wavfile import write import io +from modules.SentenceSplitter import SentenceSplitter from modules.api.utils import calc_spk_style from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext from modules.utils import rng @@ -56,27 +58,27 @@ def to_number(value, t, default=0): class TTSAudioSegment(Box): - text: str - temperature: float - top_P: float - top_K: int - spk: int - infer_seed: int - prompt1: str - prompt2: str - prefix: str - - _type: str - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._type = kwargs.get("_type", "voice") + self.text = kwargs.get("text", "") + self.temperature = kwargs.get("temperature", 0.3) + self.top_P = kwargs.get("top_P", 0.5) + self.top_K = kwargs.get("top_K", 20) + self.spk = kwargs.get("spk", -1) + self.infer_seed = kwargs.get("infer_seed", -1) + self.prompt1 = kwargs.get("prompt1", "") + self.prompt2 = kwargs.get("prompt2", "") + self.prefix = kwargs.get("prefix", "") class SynthesizeSegments: - def __init__(self, batch_size: int = 8): + def __init__(self, batch_size: int = 8, eos="", spliter_thr=100): self.batch_size = batch_size self.batch_default_spk_seed = rng.np_rng() self.batch_default_infer_seed = rng.np_rng() + self.eos = eos + self.spliter_thr = spliter_thr def segment_to_generate_params( self, segment: Union[SSMLSegment, SSMLBreak] @@ -85,9 +87,11 @@ class SynthesizeSegments: return TTSAudioSegment(_type="break") if segment.get("params", None) is not None: - return TTSAudioSegment(**segment.get("params")) + params = segment.get("params") + text = segment.get("text", None) or segment.text or "" + return TTSAudioSegment(**params, text=text) - text = segment.get("text", "") + text = segment.get("text", None) or segment.text or "" is_end = segment.get("is_end", False) text = str(text).strip() @@ -156,7 +160,7 @@ class SynthesizeSegments: for i in range(0, len(bucket), self.batch_size): batch = bucket[i : i + 
self.batch_size] param_arr = [self.segment_to_generate_params(segment) for segment in batch] - texts = [params.text for params in param_arr] + texts = [params.text + self.eos for params in param_arr] params = param_arr[0] audio_datas = generate_audio.generate_audio_batch( @@ -204,9 +208,38 @@ class SynthesizeSegments: return buckets + def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]): + """ + Split the text of each segment into multiple segments using the sentence spliter + """ + spliter = SentenceSplitter(threshold=self.spliter_thr) + ret_segments: List[Union[SSMLSegment, SSMLBreak]] = [] + + for segment in segments: + if isinstance(segment, SSMLBreak): + ret_segments.append(segment) + continue + + text = segment.text + if not text: + continue + + sentences = spliter.parse(text) + for sentence in sentences: + ret_segments.append( + SSMLSegment( + text=sentence, + attrs=segment.attrs.copy(), + params=copy.copy(segment.params), + ) + ) + + return ret_segments + def synthesize_segments( self, segments: List[Union[SSMLSegment, SSMLBreak]] ) -> List[AudioSegment]: + segments = self.split_segments(segments) audio_segments = [None] * len(segments) buckets = self.bucket_segments(segments) diff --git a/modules/api/api_setup.py b/modules/api/api_setup.py index e7de2a62e4131afe6fb5db0280feca8288f4d79d..bfe07f4de7b7f9ddc4a2c625579f1ecd07aa1e2a 100644 --- a/modules/api/api_setup.py +++ b/modules/api/api_setup.py @@ -18,6 +18,7 @@ from modules.api.impl import ( speaker_api, ping_api, models_api, + xtts_v2_api, ) logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ def create_api(app, exclude=[]): google_api.setup(app_mgr) openai_api.setup(app_mgr) refiner_api.setup(app_mgr) + xtts_v2_api.setup(app_mgr) return app_mgr @@ -42,9 +44,9 @@ def create_api(app, exclude=[]): def setup_model_args(parser: argparse.ArgumentParser): parser.add_argument("--compile", action="store_true", help="Enable model compile") parser.add_argument( - "--half", + "--no_half", action="store_true", - help="Enable half precision for model inference", + help="Disable half precision for model inference", ) parser.add_argument( "--off_tqdm", @@ -82,7 +84,7 @@ def process_model_args(args): compile = env.get_and_update_env(args, "compile", False, bool) device_id = env.get_and_update_env(args, "device_id", None, str) use_cpu = env.get_and_update_env(args, "use_cpu", [], list) - half = env.get_and_update_env(args, "half", False, bool) + no_half = env.get_and_update_env(args, "no_half", False, bool) off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool) debug_generate = env.get_and_update_env(args, "debug_generate", False, bool) diff --git a/modules/api/impl/google_api.py b/modules/api/impl/google_api.py index cd66e036578866f35b2cf8a9f1f559e1e2ca7e90..6244bc9fa8ab4c625b5ec8a04982be20798f8ed5 100644 --- a/modules/api/impl/google_api.py +++ b/modules/api/impl/google_api.py @@ -13,6 +13,7 @@ from modules.Enhancer.ResembleEnhance import ( ) from modules.api.Api import APIManager from modules.synthesize_audio import synthesize_audio +from modules.utils import audio from modules.utils.audio import apply_prosody_to_audio_data from modules.normalization import text_normalize @@ -44,6 +45,9 @@ class VoiceSelectionParams(BaseModel): topK: int = 20 seed: int = 42 + # end_of_sentence + eos: str = "[uv_break]" + class AudioConfig(BaseModel): audioEncoding: api_utils.AudioFormat = "mp3" @@ -87,6 +91,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): language_code = voice.languageCode voice_name = voice.name infer_seed =
voice.seed or 42 + eos = voice.eos or "[uv_break]" audio_format = audioConfig.audioEncoding or "mp3" speaking_rate = audioConfig.speakingRate or 1 pitch = audioConfig.pitch or 0 @@ -94,11 +99,9 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): batch_size = audioConfig.batchSize or 1 - # TODO spliter_threshold spliter_threshold = audioConfig.spliterThreshold or 100 - # TODO sample_rate - sample_rate_hertz = audioConfig.sampleRateHertz or 24000 + sample_rate = audioConfig.sampleRateHertz or 24000 params = api_utils.calc_spk_style(spk=voice.name, style=voice.style) @@ -137,10 +140,10 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): prefix=params.get("prefix", ""), batch_size=batch_size, spliter_threshold=spliter_threshold, + end_of_sentence=eos, ) elif input.ssml: - # 处理SSML合成逻辑 parser = create_ssml_parser() segments = parser.parse(input.ssml) for seg in segments: @@ -151,17 +154,13 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): status_code=422, detail="The SSML text is empty or parsing failed." ) - synthesize = SynthesizeSegments(batch_size=batch_size) + synthesize = SynthesizeSegments( + batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold + ) audio_segments = synthesize.synthesize_segments(segments) combined_audio = combine_audio_segments(audio_segments) - buffer = io.BytesIO() - combined_audio.export(buffer, format="wav") - - buffer.seek(0) - - audio_data = buffer.read() - + sample_rate, audio_data = audio.pydub_to_np(combined_audio) else: raise HTTPException( status_code=422, detail="Either text or SSML input must be provided." diff --git a/modules/api/impl/openai_api.py b/modules/api/impl/openai_api.py index f1e21e1241c958047d3b8488105981528fde82eb..7c0e012c093c13bb6ce19415a3e504cc3fcafe46 100644 --- a/modules/api/impl/openai_api.py +++ b/modules/api/impl/openai_api.py @@ -41,6 +41,8 @@ class AudioSpeechRequest(BaseModel): spliter_threshold: float = Field( 100, ge=10, le=1024, description="Threshold for sentence spliter" ) + # end of sentence + eos: str = "[uv_break]" async def openai_speech_api( @@ -52,6 +54,7 @@ async def openai_speech_api( input_text = request.input voice = request.voice style = request.style + eos = request.eos response_format = request.response_format batch_size = request.batch_size spliter_threshold = request.spliter_threshold @@ -95,6 +98,7 @@ async def openai_speech_api( prompt1=prompt1, prompt2=prompt2, prefix=prefix, + end_of_sentence=eos, ) if speed != 1: diff --git a/modules/api/impl/ssml_api.py b/modules/api/impl/ssml_api.py index 2696470d6afaa6f5ba6bac9a75b36e2bd6164ce8..c6277b6214fe18a5f9e271c766f1f16d9d3f981f 100644 --- a/modules/api/impl/ssml_api.py +++ b/modules/api/impl/ssml_api.py @@ -26,8 +26,13 @@ class SSMLRequest(BaseModel): # NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪 batch_size: int = 4 + # end of sentence + eos: str = "[uv_break]" -async def synthesize_ssml( + spliter_thr: int = 100 + + +async def synthesize_ssml_api( request: SSMLRequest = Body( ..., description="JSON body with SSML string and format" ) @@ -36,12 +41,19 @@ async def synthesize_ssml( ssml = request.ssml format = request.format.lower() batch_size = request.batch_size + eos = request.eos + spliter_thr = request.spliter_thr if batch_size < 1: raise HTTPException( status_code=400, detail="Batch size must be greater than 0." ) + if spliter_thr < 50: + raise HTTPException( + status_code=400, detail="Spliter threshold must be greater than 50." 
+ ) + if not ssml or ssml == "": raise HTTPException(status_code=400, detail="SSML content is required.") @@ -55,7 +67,9 @@ async def synthesize_ssml( for seg in segments: seg["text"] = text_normalize(seg["text"], is_end=True) - synthesize = SynthesizeSegments(batch_size) + synthesize = SynthesizeSegments( + batch_size=batch_size, eos=eos, spliter_thr=spliter_thr + ) audio_segments = synthesize.synthesize_segments(segments) combined_audio = combine_audio_segments(audio_segments) buffer = io.BytesIO() @@ -77,4 +91,4 @@ async def synthesize_ssml( def setup(api_manager: APIManager): - api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml) + api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml_api) diff --git a/modules/api/impl/tts_api.py b/modules/api/impl/tts_api.py index 7330b4820612341e8ca57fea0a04b3dbe3eadfa5..b91f5493ca898c5d7522110004a65443849536b8 100644 --- a/modules/api/impl/tts_api.py +++ b/modules/api/impl/tts_api.py @@ -38,6 +38,7 @@ class TTSParams(BaseModel): prefix: str = Query("", description="Text prefix for inference") bs: str = Query("8", description="Batch size for inference") thr: str = Query("100", description="Threshold for sentence spliter") + eos: str = Query("", description="End of sentence str") async def synthesize_tts(params: TTSParams = Depends()): @@ -87,6 +88,7 @@ async def synthesize_tts(params: TTSParams = Depends()): prefix = params.prefix or calc_params.get("prefix", params.prefix) prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1) prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2) + eos = params.eos or "" batch_size = int(params.bs) threshold = int(params.thr) @@ -103,6 +105,7 @@ async def synthesize_tts(params: TTSParams = Depends()): prefix=prefix, batch_size=batch_size, spliter_threshold=threshold, + end_of_sentence=eos, ) buffer = io.BytesIO() diff --git a/modules/api/impl/xtts_v2_api.py b/modules/api/impl/xtts_v2_api.py new file mode 100644 index 0000000000000000000000000000000000000000..0b660562b2ae1041f46c62fe17355040723a6590 --- /dev/null +++ b/modules/api/impl/xtts_v2_api.py @@ -0,0 +1,160 @@ +import io +from fastapi import HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from modules.api import utils as api_utils +from modules.api.Api import APIManager + +import soundfile as sf + +from modules import config +from modules.normalization import text_normalize +from modules.speaker import speaker_mgr +from modules.synthesize_audio import synthesize_audio + +import logging + +from modules.utils.audio import apply_prosody_to_audio_data + +logger = logging.getLogger(__name__) + + +class XTTS_V2_Settings: + def __init__(self): + self.stream_chunk_size = 100 + self.temperature = 0.3 + self.speed = 1 + self.length_penalty = 0.5 + self.repetition_penalty = 1.0 + self.top_p = 0.7 + self.top_k = 20 + self.enable_text_splitting = True + + +class TTSSettingsRequest(BaseModel): + stream_chunk_size: int + temperature: float + speed: float + length_penalty: float + repetition_penalty: float + top_p: float + top_k: int + enable_text_splitting: bool + + +class SynthesisRequest(BaseModel): + text: str + speaker_wav: str + language: str + + +def setup(app: APIManager): + XTTSV2 = XTTS_V2_Settings() + + @app.get("/v1/xtts_v2/speakers") + async def speakers(): + spks = speaker_mgr.list_speakers() + return [ + { + "name": spk.name, + "voice_id": spk.id, + # TODO: 也许可以放一个 "/v1/tts" 接口地址在这里 + "preview_url": "", + } + for spk in spks + ] + + 
@app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse) + async def tts_to_audio(request: SynthesisRequest): + text = request.text + # speaker_wav 就是 speaker id 。。。 + voice_id = request.speaker_wav + language = request.language + + spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker( + voice_id + ) + if spk is None: + raise HTTPException(status_code=400, detail="Invalid speaker id") + + text = text_normalize(text, is_end=True) + sample_rate, audio_data = synthesize_audio( + text=text, + temperature=XTTSV2.temperature, + # length_penalty=XTTSV2.length_penalty, + # repetition_penalty=XTTSV2.repetition_penalty, + top_P=XTTSV2.top_p, + top_K=XTTSV2.top_k, + spk=spk, + spliter_threshold=XTTSV2.stream_chunk_size, + # TODO 支持设置 batch_size + batch_size=4, + end_of_sentence="[uv_break]", + ) + + if XTTSV2.speed: + audio_data = apply_prosody_to_audio_data( + audio_data, + rate=XTTSV2.speed, + sr=sample_rate, + ) + + # to mp3 + buffer = io.BytesIO() + sf.write(buffer, audio_data, sample_rate, format="wav") + buffer.seek(0) + + buffer = api_utils.wav_to_mp3(buffer) + + return StreamingResponse(buffer, media_type="audio/mpeg") + + @app.get("/v1/xtts_v2/tts_stream") + async def tts_stream(): + raise HTTPException(status_code=501, detail="Not implemented") + + @app.post("/v1/xtts_v2/set_tts_settings") + async def set_tts_settings(request: TTSSettingsRequest): + try: + if request.stream_chunk_size < 50: + raise HTTPException( + status_code=400, detail="stream_chunk_size must be greater than 0" + ) + if request.temperature < 0: + raise HTTPException( + status_code=400, detail="temperature must be greater than 0" + ) + if request.speed < 0: + raise HTTPException( + status_code=400, detail="speed must be greater than 0" + ) + if request.length_penalty < 0: + raise HTTPException( + status_code=400, detail="length_penalty must be greater than 0" + ) + if request.repetition_penalty < 0: + raise HTTPException( + status_code=400, detail="repetition_penalty must be greater than 0" + ) + if request.top_p < 0: + raise HTTPException( + status_code=400, detail="top_p must be greater than 0" + ) + if request.top_k < 0: + raise HTTPException( + status_code=400, detail="top_k must be greater than 0" + ) + + XTTSV2.stream_chunk_size = request.stream_chunk_size + XTTSV2.temperature = request.temperature + XTTSV2.speed = request.speed + XTTSV2.length_penalty = request.length_penalty + XTTSV2.repetition_penalty = request.repetition_penalty + XTTSV2.top_p = request.top_p + XTTSV2.top_k = request.top_k + XTTSV2.enable_text_splitting = request.enable_text_splitting + return {"message": "Settings successfully applied"} + except Exception as e: + if isinstance(e, HTTPException): + raise e + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/modules/devices/devices.py b/modules/devices/devices.py index c9e1862cdd427f19edf669b2b6c4daee9d6f3340..e11f6eba5da29af778ee1d254778acaa7492b6bb 100644 --- a/modules/devices/devices.py +++ b/modules/devices/devices.py @@ -127,7 +127,7 @@ def reset_device(): global dtype_gpt global dtype_decoder - if config.runtime_env_vars.half: + if not config.runtime_env_vars.no_half: dtype = torch.float16 dtype_dvae = torch.float16 dtype_vocos = torch.float16 diff --git a/modules/finetune/__init__.py b/modules/finetune/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/finetune/model/__init__.py b/modules/finetune/model/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/finetune/model/encoder.py b/modules/finetune/model/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d445ad83d07ebdd9bfaa25e9e6d9c64001471d08 --- /dev/null +++ b/modules/finetune/model/encoder.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn + +from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder + +from .wavenet import WaveNet + + +def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]: + return { + "idim": decoder.conv_out.out_channels, + "odim": decoder.conv_in[0].in_channels, + "n_layer": len(decoder.decoder_block), + "bn_dim": decoder.conv_in[0].out_channels, + "hidden": decoder.conv_in[2].out_channels, + "kernel": decoder.decoder_block[0].dwconv.kernel_size[0], + "dilation": decoder.decoder_block[0].dwconv.dilation[0], + "down": decoder.up, + } + + +class DVAEEncoder(nn.Module): + def __init__( + self, + idim: int, + odim: int, + n_layer: int = 12, + bn_dim: int = 64, + hidden: int = 256, + kernel: int = 7, + dilation: int = 2, + down: bool = False, + ) -> None: + super().__init__() + self.wavenet = WaveNet( + input_channels=100, + residual_channels=idim, + residual_layers=20, + dilation_cycle=4, + ) + self.conv_in_transpose = nn.ConvTranspose1d( + idim, hidden, kernel_size=1, bias=False + ) + # nn.Sequential( + # nn.ConvTranspose1d(100, idim, 3, 1, 1, bias=False), + # nn.ConvTranspose1d(idim, hidden, kernel_size=1, bias=False) + # ) + self.encoder_block = nn.ModuleList( + [ + ConvNeXtBlock( + hidden, + hidden * 4, + kernel, + dilation, + ) + for _ in range(n_layer) + ] + ) + self.conv_out_transpose = nn.Sequential( + nn.Conv1d(hidden, bn_dim, 3, 1, 1), + nn.GELU(), + nn.Conv1d(bn_dim, odim, 3, 1, 1), + ) + + def forward( + self, + audio_mel_specs: torch.Tensor, # (batch_size, audio_len*2, 100) + audio_attention_mask: torch.Tensor, # (batch_size, audio_len) + conditioning=None, + ) -> torch.Tensor: + mel_attention_mask = ( + audio_attention_mask.unsqueeze(-1).repeat(1, 1, 2).flatten(1) + ) + x: torch.Tensor = self.wavenet( + audio_mel_specs.transpose(1, 2) + ) # (batch_size, idim, audio_len*2) + x = x * mel_attention_mask.unsqueeze(1) + x = self.conv_in_transpose(x) # (batch_size, hidden, audio_len*2) + for f in self.encoder_block: + x = f(x, conditioning) + x = self.conv_out_transpose(x) # (batch_size, odim, audio_len*2) + x = ( + x.view(x.size(0), x.size(1), 2, x.size(2) // 2) + .permute(0, 3, 1, 2) + .flatten(2) + ) + return x # (batch_size, audio_len, audio_dim=odim*2) diff --git a/modules/finetune/model/wavenet.py b/modules/finetune/model/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..828aa41c8afdbbb59f038d5053357a317721b5c0 --- /dev/null +++ b/modules/finetune/model/wavenet.py @@ -0,0 +1,227 @@ +"""https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/vqgan/modules/wavenet.py""" + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class DiffusionEmbedding(nn.Module): + """Diffusion Step Embedding""" + + def __init__(self, d_denoiser): + super(DiffusionEmbedding, self).__init__() + self.dim = d_denoiser + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = 
x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class LinearNorm(nn.Module): + """LinearNorm Projection""" + + def __init__(self, in_features, out_features, bias=False): + super(LinearNorm, self).__init__() + self.linear = nn.Linear(in_features, out_features, bias) + + nn.init.xavier_uniform_(self.linear.weight) + if bias: + nn.init.constant_(self.linear.bias, 0.0) + + def forward(self, x): + x = self.linear(x) + return x + + +class ConvNorm(nn.Module): + """1D Convolution""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + ): + super(ConvNorm, self).__init__() + + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, signal): + conv_signal = self.conv(signal) + + return conv_signal + + +class ResidualBlock(nn.Module): + """Residual Block""" + + def __init__( + self, + residual_channels, + use_linear_bias=False, + dilation=1, + condition_channels=None, + ): + super(ResidualBlock, self).__init__() + self.conv_layer = ConvNorm( + residual_channels, + 2 * residual_channels, + kernel_size=3, + stride=1, + padding=dilation, + dilation=dilation, + ) + + if condition_channels is not None: + self.diffusion_projection = LinearNorm( + residual_channels, residual_channels, use_linear_bias + ) + self.condition_projection = ConvNorm( + condition_channels, 2 * residual_channels, kernel_size=1 + ) + + self.output_projection = ConvNorm( + residual_channels, 2 * residual_channels, kernel_size=1 + ) + + def forward(self, x, condition=None, diffusion_step=None): + y = x + + if diffusion_step is not None: + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + y = y + diffusion_step + + y = self.conv_layer(y) + + if condition is not None: + condition = self.condition_projection(condition) + y = y + condition + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + + return (x + residual) / math.sqrt(2.0), skip + + +class WaveNet(nn.Module): + def __init__( + self, + input_channels: Optional[int] = None, + output_channels: Optional[int] = None, + residual_channels: int = 512, + residual_layers: int = 20, + dilation_cycle: Optional[int] = 4, + is_diffusion: bool = False, + condition_channels: Optional[int] = None, + ): + super().__init__() + + # Input projection + self.input_projection = None + if input_channels is not None and input_channels != residual_channels: + self.input_projection = ConvNorm( + input_channels, residual_channels, kernel_size=1 + ) + + if input_channels is None: + input_channels = residual_channels + + self.input_channels = input_channels + + # Residual layers + self.residual_layers = nn.ModuleList( + [ + ResidualBlock( + residual_channels=residual_channels, + use_linear_bias=False, + dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1, + condition_channels=condition_channels, + ) + for i in range(residual_layers) + ] + ) + + # Skip projection + self.skip_projection = ConvNorm( + residual_channels, residual_channels, kernel_size=1 + ) + + # Output projection + self.output_projection = None + if output_channels is not None and 
output_channels != residual_channels: + self.output_projection = ConvNorm( + residual_channels, output_channels, kernel_size=1 + ) + + if is_diffusion: + self.diffusion_embedding = DiffusionEmbedding(residual_channels) + self.mlp = nn.Sequential( + LinearNorm(residual_channels, residual_channels * 4, False), + Mish(), + LinearNorm(residual_channels * 4, residual_channels, False), + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + if getattr(m, "bias", None) is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, t=None, condition=None): + if self.input_projection is not None: + x = self.input_projection(x) + x = F.silu(x) + + if t is not None: + t = self.diffusion_embedding(t) + t = self.mlp(t) + + skip = [] + for layer in self.residual_layers: + x, skip_connection = layer(x, condition, t) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + + if self.output_projection is not None: + x = F.silu(x) + x = self.output_projection(x) + + return x diff --git a/modules/finetune/train_gpt.py b/modules/finetune/train_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..9642d37016a0b99348d11a79dcdf7b2bdd3c0aef --- /dev/null +++ b/modules/finetune/train_gpt.py @@ -0,0 +1,247 @@ +import functools +import torch +import transformers +import peft +from transformers.trainer_pt_utils import LabelSmoother +from modules.finetune.utils.dataset import AudioCollator +from modules.finetune.utils.logger import MetricLogger +from modules.finetune.utils.model import quantize +from modules.finetune.utils.output import ansi, get_ansi_len, output_iter + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +def train_gpt_lora( + chat, + dataset, + decoder_encoder, + dvae_encoder, + batch_size=16, + epochs=10, + train_text=True, + speaker_embeds=None, + lora_r=8, + lora_alpha=16, +): + if speaker_embeds is None: + speaker_embeds = {} + + tokenizer = chat.pretrain_models["tokenizer"] + decoder_decoder = chat.pretrain_models["decoder"] + decoder_decoder.eval().requires_grad_(False) + decoder_encoder.to(device=dataset.device).eval().requires_grad_(False) + dvae_decoder = chat.pretrain_models["dvae"] + dvae_decoder.eval().requires_grad_(False) + dvae_encoder.to(device=dataset.device).eval().requires_grad_(False) + + gpt = chat.pretrain_models["gpt"] + gpt.train().requires_grad_() + + # Add LoRA to GPT model + lora_config = peft.LoraConfig(r=lora_r, lora_alpha=lora_alpha) + gpt.gpt = peft.get_peft_model(gpt.gpt, lora_config) + + speaker_embeds = { + speaker: torch.randn(768, device=dataset.device, requires_grad=True) + for speaker in dataset.speakers + } | speaker_embeds + + for speaker_embed in speaker_embeds.values(): + std, mean = chat.pretrain_models["spk_stat"].chunk(2) + speaker_embed.data = speaker_embed.data * std + mean + + SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]") + AUDIO_EOS_TOKEN_ID = 0 + AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID + + train_params = list(gpt.parameters()) + list(speaker_embeds.values()) + optimizer = torch.optim.Adam( + gpt.parameters(), lr=1e-3, weight_decay=0, betas=[0.9, 0.95], eps=1e-5 + ) + optimizer.add_param_group({"params": speaker_embeds.values(), "lr": 1e-1}) + + loss_fn = torch.nn.CrossEntropyLoss() + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7) + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
+ ) + logger = MetricLogger() + logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None) + + for _epoch in range(epochs): + _epoch += 1 + logger.reset() + header = "{blue_light}{0}: {1}{reset}".format( + "Epoch", output_iter(_epoch, epochs), **ansi + ) + header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header)) + iterator = logger.log_every(loader, header=header, tqdm_header="Batch") + + for batch in iterator: + speakers = batch["speaker"] + text_input_ids = batch["text_input_ids"] + text_attention_mask = batch["text_attention_mask"] + audio_mel_specs = batch["audio_mel_specs"] + audio_attention_mask = batch["audio_attention_mask"] + + batch_size, text_len = text_attention_mask.size() + + dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask) + _, dvae_audio_input_ids = quantize( + dvae_decoder.vq_layer.quantizer, dvae_audio_latents + ) + dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID + + extended_audio_attention_mask = torch.cat( + [ + audio_attention_mask, + torch.zeros( + (batch_size, 1), + dtype=audio_attention_mask.dtype, + device=audio_attention_mask.device, + ), + ], + dim=1, + ) + extended_audio_input_ids = torch.cat( + [ + dvae_audio_input_ids, + AUDIO_PAD_TOKEN_ID + * torch.ones( + (batch_size, 1, gpt.num_vq), + dtype=dvae_audio_input_ids.dtype, + device=dvae_audio_input_ids.device, + ), + ], + dim=1, + ) + + indices = audio_attention_mask.int().sum(dim=1) + for i in range(batch_size): + extended_audio_attention_mask[i, indices[i]] = 1 + extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID + + input_ids = torch.cat( + [ + text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq), + extended_audio_input_ids, + ], + dim=1, + ) + attention_mask = torch.cat( + [text_attention_mask, extended_audio_attention_mask], dim=1 + ) + text_mask = torch.cat( + [ + torch.ones_like(text_attention_mask, dtype=bool), + torch.zeros_like(extended_audio_attention_mask, dtype=bool), + ], + dim=1, + ) + labels = input_ids.clone() + labels[~attention_mask.bool()] = IGNORE_TOKEN_ID + + inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask) + + indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1) + for i, speaker in enumerate(speakers): + inputs_embeds[i, indices[i]] = torch.nn.functional.normalize( + speaker_embeds[speaker].to(dtype=inputs_embeds.dtype), + p=2.0, + dim=-1, + eps=1e-12, + ).unsqueeze(0) + + outputs = gpt.gpt.forward( + inputs_embeds=inputs_embeds, attention_mask=attention_mask + ) + hidden_states = outputs.last_hidden_state + text_hidden_states = hidden_states[:, : text_len - 1] + audio_hidden_states = hidden_states[:, text_len - 1 : -1] + + audio_logits = torch.stack( + [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)], + dim=2, + ) + audio_loss = loss_fn( + audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2) + ) + loss = audio_loss + + if train_text: + text_logits = gpt.head_text(text_hidden_states) + text_loss = loss_fn( + text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1) + ) + loss += text_loss + logger.meters["text_loss"].update(text_loss.item(), n=batch_size) + + gpt_gen_mel_specs = decoder_decoder( + audio_hidden_states[:, :-1].transpose(1, 2) + ).transpose(1, 2) + mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs) + loss += 0.01 * mse_loss + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(train_params, 1.0) + optimizer.step() + + logger.meters["loss"].update(loss.item(), n=batch_size) + 
logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size) + logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size) + + lr_scheduler.step() + optimizer.zero_grad() + return speaker_embeds + + +# Example usage +def main(): + # Load necessary models and data paths + chat = ChatTTS.Chat() + chat.load_models() + dataset = XzListTar( + root="data/all.list", + tokenizer=chat.pretrain_models["tokenizer"], + vocos_model=chat.pretrain_models["vocos"], + tar_path="data/Xz.tar", + tar_in_memory=True, + process_ahead=True, + ) + + decoder_encoder = DVAEEncoder( + **get_encoder_config(chat.pretrain_models["decoder"].decoder) + ) + dvae_encoder = DVAEEncoder( + **get_encoder_config(chat.pretrain_models["dvae"].decoder) + ) + + # Train GPT with LoRA + speaker_embeds = train_gpt_lora( + chat=chat, + dataset=dataset, + decoder_encoder=decoder_encoder, + dvae_encoder=dvae_encoder, + batch_size=32, + epochs=10, + train_text=True, + lora_r=8, + lora_alpha=16, + ) + + # Save LoRA parameters and embeddings + lora_save_path = "./saved_models/gpt_lora.pth" + peft.save_pretrained(gpt.gpt, lora_save_path) + np.savez( + "./saved_models/speaker_embeds.npz", + **{k: v.cpu().numpy() for k, v in speaker_embeds.items()} + ) + + +if __name__ == "__main__": + main() diff --git a/modules/finetune/train_speaker.py b/modules/finetune/train_speaker.py new file mode 100644 index 0000000000000000000000000000000000000000..343d743d6c2acff51c170fde6fc3ae32bb47c482 --- /dev/null +++ b/modules/finetune/train_speaker.py @@ -0,0 +1,296 @@ +import torch +import torch.nn.functional as F +import transformers + +from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config +from modules.finetune.utils.output import get_ansi_len, output_iter, ansi +from .utils.logger import MetricLogger +from .utils.dataset import AudioCollator, XzListTar +from .utils.model import quantize + +IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index + + +def train_speaker_embeddings( + chat, + dataset, + gpt, + batch_size=16, + epochs=10, + train_text=True, + speaker_embeds=None, +): + tokenizer = chat.pretrain_models["tokenizer"] + + decoder_decoder = chat.pretrain_models["decoder"] + decoder_decoder.eval().requires_grad_(False) + decoder_encoder = DVAEEncoder(**get_encoder_config(decoder_decoder.decoder)).to( + device=dataset.device + ) + decoder_encoder.eval().requires_grad_(False) + + dvae_decoder = chat.pretrain_models["dvae"] + dvae_decoder.eval().requires_grad_(False) + dvae_encoder = DVAEEncoder(**get_encoder_config(dvae_decoder.decoder)).to( + device=dataset.device + ) + dvae_encoder.eval().requires_grad_(False) + + if speaker_embeds is None: + speaker_embeds = { + speaker: torch.randn( + 768, + device=dataset.device, + requires_grad=True, + ) + for speaker in dataset.speakers + } + for speaker_embed in speaker_embeds.values(): + std, mean = chat.pretrain_models["spk_stat"].chunk(2) + speaker_embed.data = speaker_embed.data * std + mean + + SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]") + AUDIO_EOS_TOKEN_ID = 0 + AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID + + optimizer = torch.optim.Adam( + speaker_embeds.values(), lr=1e-2, weight_decay=0, betas=[0.9, 0.95], eps=1e-5 + ) + loss_fn = torch.nn.CrossEntropyLoss() + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7) + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id), + ) + logger = MetricLogger() + 
logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None) + + for _epoch in range(epochs): + _epoch += 1 + logger.reset() + header = "{blue_light}{0}: {1}{reset}".format( + "Epoch", output_iter(_epoch, epochs), **ansi + ) + header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header)) + iterator = logger.log_every(loader, header=header, tqdm_header="Batch") + + for batch in iterator: + speakers = batch["speaker"] + text_input_ids = batch["text_input_ids"] + text_attention_mask = batch["text_attention_mask"] + audio_mel_specs = batch["audio_mel_specs"] + audio_attention_mask = batch["audio_attention_mask"] + + batch_size, text_len = text_attention_mask.size() + + dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask) + _, dvae_audio_input_ids = quantize( + dvae_decoder.vq_layer.quantizer, dvae_audio_latents + ) + dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID + + extended_audio_attention_mask = torch.cat( + [ + audio_attention_mask, + torch.zeros( + (batch_size, 1), + dtype=audio_attention_mask.dtype, + device=audio_attention_mask.device, + ), + ], + dim=1, + ) + extended_audio_input_ids = torch.cat( + [ + dvae_audio_input_ids, + AUDIO_PAD_TOKEN_ID + * torch.ones( + (batch_size, 1, gpt.num_vq), + dtype=dvae_audio_input_ids.dtype, + device=dvae_audio_input_ids.device, + ), + ], + dim=1, + ) + indices = audio_attention_mask.int().sum(dim=1) + for i in range(batch_size): + extended_audio_attention_mask[i, indices[i]] = 1 + extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID + + input_ids = torch.cat( + [ + text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq), + extended_audio_input_ids, + ], + dim=1, + ) + attention_mask = torch.cat( + [text_attention_mask, extended_audio_attention_mask], dim=1 + ) + text_mask = torch.cat( + [ + torch.ones_like(text_attention_mask, dtype=bool), + torch.zeros_like(extended_audio_attention_mask, dtype=bool), + ], + dim=1, + ) + + labels = input_ids.clone() + labels[~attention_mask.bool()] = IGNORE_TOKEN_ID + + inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask) + + indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1) + for i, speaker in enumerate(speakers): + inputs_embeds[i, indices[i]] = F.normalize( + speaker_embeds[speaker].to(dtype=inputs_embeds.dtype), + p=2.0, + dim=-1, + eps=1e-12, + ).unsqueeze(0) + outputs = gpt.gpt.forward( + inputs_embeds=inputs_embeds, attention_mask=attention_mask + ) + hidden_states = outputs.last_hidden_state + text_hidden_states = hidden_states[:, : text_len - 1] + audio_hidden_states = hidden_states[:, text_len - 1 : -1] + + audio_logits = torch.stack( + [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)], + dim=2, + ) + audio_loss = loss_fn( + audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2) + ) + loss = audio_loss + if train_text: + text_logits = gpt.head_text(text_hidden_states) + text_loss = loss_fn( + text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1) + ) + loss += text_loss + logger.meters["text_loss"].update(text_loss.item(), n=batch_size) + + gpt_gen_mel_specs = decoder_decoder( + audio_hidden_states[:, :-1].transpose(1, 2) + ).transpose(1, 2) + mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs) + loss += 0.01 * mse_loss + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0) + optimizer.step() + logger.meters["loss"].update(loss.item(), n=batch_size) + 
logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size) + logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size) + lr_scheduler.step() + optimizer.zero_grad() + return speaker_embeds + + +if __name__ == "__main__": + import argparse + import os + import numpy as np + import pathlib + from modules.models import load_chat_tts + from modules.devices import devices + from modules import config + from modules.speaker import Speaker + + config.runtime_env_vars.no_half = True + devices.reset_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--save_folder", type=str, default="./") + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--train_text", action="store_true", help="train text loss") + # 初始化 speaker + parser.add_argument("--init_speaker", type=str) + parser.add_argument( + "--data_path", + type=str, + default="datasets/data_speaker_a/speaker_a.list", + help="the data_path to json/list file", + ) + parser.add_argument("--tar_path", type=str, help="the tarball path with wavs") + parser.add_argument( + "--tar_in_memory", action="store_true", help="load tarball in memory" + ) + + args = parser.parse_args() + + data_path: str = args.data_path + tar_path: str | None = args.tar_path + tar_in_memory: bool = args.tar_in_memory + train_text: bool = args.train_text + # gpt_lora: bool = args.gpt_lora + # gpt_kbit: int = args.gpt_kbit + save_folder: str = args.save_folder + batch_size: int = args.batch_size + epochs: int = args.epochs + init_speaker: str = args.init_speaker + + speaker_embeds_save_path = os.path.join(save_folder, "speaker_embeds.npz") + + chat = load_chat_tts() + dataset = XzListTar( + root=data_path, + tokenizer=chat.pretrain_models["tokenizer"], + vocos_model=chat.pretrain_models["vocos"], + tar_path=tar_path, + tar_in_memory=tar_in_memory, + device=devices.device, + # speakers=None, # set(['speaker_A', 'speaker_B']) + ) + + print("len(dataset)", len(dataset)) + + speaker_embeds = None + if init_speaker: + spk: Speaker = Speaker.from_file(init_speaker) + speaker_embeds = { + speaker: torch.tensor( + spk.emb, + device=devices.device, + requires_grad=True, + ) + for speaker in dataset.speakers + } + + speaker_embeds = train_speaker_embeddings( + chat, + dataset, + chat.pretrain_models["gpt"], + batch_size=batch_size, + epochs=epochs, + train_text=train_text, + speaker_embeds=speaker_embeds, + ) + speaker_outs = { + speaker: Speaker(speaker_embed.detach().cpu(), f"ep{epochs}_{speaker}") + for speaker, speaker_embed in speaker_embeds.items() + } + time_str = np.datetime_as_string(np.datetime64("now", "s")) + time_str = time_str.replace(":", "_").replace(" ", "_").replace("-", "_") + for speaker, speaker_out in speaker_outs.items(): + torch.save( + speaker_out, + pathlib.Path(save_folder) / f"spk_{speaker}_{time_str}_ep{epochs}.pt", + ) + +# example +""" +python -m modules.finetune.train_speaker \ + --data_path datasets/data_speaker_a/speaker_a.list \ + --save_folder ./data \ + --init_speaker ./data/speakers/Bob.pt \ + --epochs 100 \ + --batch_size 6 \ + --train_text +""" diff --git a/modules/finetune/utils/__init__.py b/modules/finetune/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/finetune/utils/dataset.py b/modules/finetune/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d86dae96de753e9c648b62329844958c3a33eb --- 
/dev/null +++ b/modules/finetune/utils/dataset.py @@ -0,0 +1,487 @@ +import os +import functools +import json +import tarfile +import io +import logging +import abc +import typing + +import torch.utils.data +import torchaudio +from torchvision.datasets.utils import download_url +import transformers +import vocos + +from modules.ChatTTS.ChatTTS.utils.infer_utils import ( + count_invalid_characters, + apply_character_map, +) + + +class LazyDataType(typing.TypedDict): + filepath: str + speaker: str + lang: str + text: str + + +class DataType(LazyDataType): + text_input_ids: torch.Tensor # (batch_size, text_len) + text_attention_mask: torch.Tensor # (batch_size, text_len) + audio_mel_specs: torch.Tensor # (batch_size, audio_len*2, 100) + audio_attention_mask: torch.Tensor # (batch_size, audio_len) + + +class XzListTarKwargsType(typing.TypedDict): + tokenizer: typing.Union[transformers.PreTrainedTokenizer, None] + vocos_model: typing.Union[vocos.Vocos, None] + device: typing.Union[str, torch.device, None] + speakers: typing.Union[typing.Iterable[str], None] + sample_rate: typing.Union[int] + default_speaker: typing.Union[str, None] + default_lang: typing.Union[str, None] + tar_in_memory: typing.Union[bool, None] + process_ahead: typing.Union[bool, None] + + +class AudioFolder(torch.utils.data.Dataset, abc.ABC): + def __init__( + self, + root: str | io.BytesIO, + tokenizer: transformers.PreTrainedTokenizer | None = None, + vocos_model: vocos.Vocos | None = None, + device: str | torch.device | None = None, + speakers: typing.Iterable[str] | None = None, + sample_rate: int = 24_000, + default_speaker: str | None = None, + default_lang: str | None = None, + tar_path: str | None = None, + tar_in_memory: bool = False, + process_ahead: bool = False, + ) -> None: + self.root = root + self.sample_rate = sample_rate + self.default_speaker = default_speaker + self.default_lang = default_lang + + self.logger = logging.getLogger(__name__) + self.normalizer = {} + + self.tokenizer = tokenizer + self.vocos = vocos_model + self.vocos_device = ( + None if self.vocos is None else next(self.vocos.parameters()).device + ) + self.device = device or self.vocos_device + + # tar -cvf ../Xz.tar * + # tar -xf Xz.tar -C ./Xz + self.tar_path = tar_path + self.tar_file = None + self.tar_io = None + if tar_path is not None: + if tar_in_memory: + with open(tar_path, "rb") as f: + self.tar_io = io.BytesIO(f.read()) + self.tar_file = tarfile.open(fileobj=self.tar_io) + else: + self.tar_file = tarfile.open(tar_path) + + self.lazy_data, self.speakers = self.get_lazy_data(root, speakers) + + self.text_input_ids: dict[int, torch.Tensor] = {} + self.audio_mel_specs: dict[int, torch.Tensor] = {} + if process_ahead: + for n, item in enumerate(self.lazy_data): + self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"]) + self.text_input_ids[n] = self.preprocess_text( + item["text"], item["lang"] + ) + if self.tar_file is not None: + self.tar_file.close() + if self.tar_io is not None: + self.tar_io.close() + + @abc.abstractmethod + def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: ... + + @staticmethod + @abc.abstractmethod + def save_config( + save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./" + ) -> None: ... 
+ + def __len__(self): + return len(self.lazy_data) + + def __getitem__(self, n: int) -> DataType: + lazy_data = self.lazy_data[n] + if n in self.audio_mel_specs: + audio_mel_specs = self.audio_mel_specs[n] + text_input_ids = self.text_input_ids[n] + else: + audio_mel_specs = self.preprocess_audio(lazy_data["filepath"]) + text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"]) + self.audio_mel_specs[n] = audio_mel_specs + self.text_input_ids[n] = text_input_ids + if len(self.audio_mel_specs) == len(self.lazy_data): + if self.tar_file is not None: + self.tar_file.close() + if self.tar_io is not None: + self.tar_io.close() + text_attention_mask = torch.ones( + len(text_input_ids), device=text_input_ids.device + ) + audio_attention_mask = torch.ones( + (len(audio_mel_specs) + 1) // 2, + device=audio_mel_specs.device, + ) + return { + "filepath": lazy_data["filepath"], + "speaker": lazy_data["speaker"], + "lang": lazy_data["lang"], + "text": lazy_data["text"], + "text_input_ids": text_input_ids, + "text_attention_mask": text_attention_mask, + "audio_mel_specs": audio_mel_specs, + "audio_attention_mask": audio_attention_mask, + } + + def get_lazy_data( + self, + root: str | io.BytesIO, + speakers: typing.Iterable[str] | None = None, + ) -> tuple[list[LazyDataType], set[str]]: + if speakers is not None: + new_speakers = set(speakers) + else: + new_speakers = set() + lazy_data = [] + + raw_data = self.get_raw_data(root) + folder_path = os.path.dirname(root) if isinstance(root, str) else "" + for item in raw_data: + if "speaker" not in item: + item["speaker"] = self.default_speaker + if "lang" not in item: + item["lang"] = self.default_lang + + if speakers is not None and item["speaker"] not in speakers: + continue + if speakers is None and item["speaker"] not in new_speakers: + new_speakers.add(item["speaker"]) + if self.tar_file is None and isinstance(root, str): + filepath = os.path.join(folder_path, item["filepath"]) + else: + filepath = item["filepath"] + lazy_data.append( + { + "filepath": filepath, + "speaker": item["speaker"], + "lang": item["lang"].lower(), + "text": item["text"], + } + ) + return lazy_data, new_speakers + + def preprocess_text( + self, + text: str, + lang: str, + ) -> torch.Tensor: + invalid_characters = count_invalid_characters(text) + if len(invalid_characters): + # self.logger.log(logging.WARNING, f'Invalid characters found! 
: {invalid_characters}') + text = apply_character_map(text) + + # if not skip_refine_text: + # text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids'] + # text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens] + # text = self.pretrain_models['tokenizer'].batch_decode(text_tokens) + # if refine_text_only: + # return text + + text = f"[Stts][spk_emb]{text}[Ptts]" + # text = f'[Stts][empty_spk]{text}[Ptts]' + + text_token = self.tokenizer( + text, return_tensors="pt", add_special_tokens=False + ).to(device=self.device) + return text_token["input_ids"].squeeze(0) + + def preprocess_audio(self, filepath: str) -> torch.Tensor: + if self.tar_file is not None: + file = self.tar_file.extractfile(filepath) + waveform, sample_rate = torchaudio.load(file) + else: + waveform, sample_rate = torchaudio.load(filepath) + waveform = waveform.to(device=self.vocos_device) + if sample_rate != self.sample_rate: + waveform = torchaudio.functional.resample( + waveform, + orig_freq=sample_rate, + new_freq=self.sample_rate, + ) + mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform) + return ( + mel_spec.to(device=self.device).squeeze(0).transpose(0, 1) + ) # (audio_len*2, 100) + + +class JsonFolder(AudioFolder): + """ + In json file, each item is formatted as following example: + `{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`. + + filepath is relative to the dirname of root json file. + """ + + def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: + with open(root, "r", encoding="utf-8") as f: + raw_data = json.load(f) + return raw_data + + @staticmethod + def save_config( + save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./" + ) -> None: + save_data = [item.copy() for item in lazy_data] + for item in save_data: + item["filepath"] = os.path.relpath(item["filepath"], rel_path) + with open(save_path, "w", encoding="utf-8") as f: + json.dump(save_data, f, ensure_ascii=False, indent=4) + + +class ListFolder(AudioFolder): + """ + In list file, each row is formatted as `filepath|speaker|lang|text` with `|` as separator. + `path/to/file.wav|John|ZH|Hello`. + + filepath is relative to the dirname of root list file. 
+ """ + + def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: + raw_data = [] + with open(root, "r", encoding="utf-8") as f: + for line in f.readlines(): + line = line.strip().removesuffix("\n") + if len(line) == 0: + continue + filepath, speaker, lang, text = line.split(sep="|", maxsplit=3) + raw_data.append( + { + "text": text, + "filepath": filepath, + "speaker": speaker, + "lang": lang, + } + ) + return raw_data + + @staticmethod + def save_config( + save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./" + ) -> None: + save_data = [item.copy() for item in lazy_data] + for item in save_data: + item["filepath"] = os.path.relpath(item["filepath"], rel_path) + with open(save_path, "w", encoding="utf-8") as f: + for item in save_data: + f.write( + f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n" + ) + + +class XzListTar(ListFolder): + def __init__( + self, + *args, + root: str | io.BytesIO, + tar_path: str | None = None, + **kwargs, + ): + if isinstance(root, io.BytesIO): + assert tar_path is not None + else: + # make sure root is a list file + if not root.endswith(".list"): # folder case + if os.path.isfile(root): + raise FileExistsError(f"{root} is a file!") + elif not os.path.exists(root): + os.makedirs(root) + root = os.path.join(root, "all.list") + if isinstance(root, str) and not os.path.isfile(root): + # prepare all.list + self.concat_dataset( + save_folder=os.path.dirname(root), + langs=kwargs.get("langs", ["zh", "en"]), + ) + + super().__init__(root, *args, tar_path=tar_path, **kwargs) + + def concat_dataset( + self, save_folder: str | None = None, langs: list[str] = ["zh", "en"] + ) -> None: + if save_folder is None: + save_folder = os.path.dirname(self.root) + if os.path.isfile(save_folder): + raise FileExistsError(f"{save_folder} already exists as a file!") + elif not os.path.exists(save_folder): + os.makedirs(save_folder) + lazy_data = [] + + for member in self.tar_file.getmembers(): + if not member.isfile(): + continue + if member.name.endswith(".list"): + print(member.name) + root_io = self.tar_file.extractfile(member) + lazy_data += ListFolder(root_io).lazy_data + if member.name.endswith(".json"): + print(member.name) + root_io = self.tar_file.extractfile(member) + lazy_data += JsonFolder(root_io).lazy_data + if langs is not None: + lazy_data = [item for item in lazy_data if item["lang"] in langs] + ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data) + JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data) + print(f"all.list and all.json are saved to {save_folder}") + + +class XzListFolder(ListFolder): + """ + [Xz乔希](https://space.bilibili.com/5859321) + + Only look at the basename of filepath in list file. Previous folder paths are ignored. 
+ Files are organized as `[list basename]/[file basename]` + + Example tree structure: + + [folder] + ├── speaker_A + │ ├── 1.wav + │ └── 2.wav + ├── speaker_A.list + ├── speaker_B + │ ├── 1.wav + │ └── 2.wav + └── speaker_B.list + """ + + def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: + raw_data = super().get_raw_data(root) + for item in raw_data: + item["filepath"] = os.path.join( + os.path.basename(root).removesuffix(".list"), + os.path.basename(item["filepath"]), + ) + return raw_data + + +class AudioCollator: + def __init__(self, text_pad: int = 0, audio_pad: int = 0): + self.text_pad = text_pad + self.audio_pad = audio_pad + + def __call__(self, batch: list[DataType]): + batch = [x for x in batch if x is not None] + + audio_maxlen = max(len(item["audio_attention_mask"]) for item in batch) + text_maxlen = max(len(item["text_attention_mask"]) for item in batch) + + filepath = [] + speaker = [] + lang = [] + text = [] + text_input_ids = [] + text_attention_mask = [] + audio_mel_specs = [] + audio_attention_mask = [] + + for x in batch: + filepath.append(x["filepath"]) + speaker.append(x["speaker"]) + lang.append(x["lang"]) + text.append(x["text"]) + text_input_ids.append( + torch.nn.functional.pad( + x["text_input_ids"], + (text_maxlen - len(x["text_input_ids"]), 0), + value=self.text_pad, + ) + ) + text_attention_mask.append( + torch.nn.functional.pad( + x["text_attention_mask"], + (text_maxlen - len(x["text_attention_mask"]), 0), + value=0, + ) + ) + audio_mel_specs.append( + torch.nn.functional.pad( + x["audio_mel_specs"], + (0, 0, 0, audio_maxlen * 2 - len(x["audio_mel_specs"])), + value=self.audio_pad, + ) + ) + audio_attention_mask.append( + torch.nn.functional.pad( + x["audio_attention_mask"], + (0, audio_maxlen - len(x["audio_attention_mask"])), + value=0, + ) + ) + return { + "filepath": filepath, + "speaker": speaker, + "lang": lang, + "text": text, + "text_input_ids": torch.stack(text_input_ids), + "text_attention_mask": torch.stack(text_attention_mask), + "audio_mel_specs": torch.stack(audio_mel_specs), + "audio_attention_mask": torch.stack(audio_attention_mask), + } + + +def formalize_xz_list(src_folder: str): + for root, _, files in os.walk(src_folder): + for file in files: + if file.endswith(".list"): + filepath = os.path.join(root, file) + print(filepath) + lazy_data = XzListFolder(filepath).lazy_data + XzListFolder.save_config(filepath, lazy_data, rel_path=src_folder) + + +def concat_dataset( + src_folder: str, save_folder: str | None = None, langs: list[str] = ["zh", "en"] +) -> None: + if save_folder is None: + save_folder = src_folder + if os.path.isfile(save_folder): + raise FileExistsError(f"{save_folder} already exists as a file!") + elif not os.path.exists(save_folder): + os.makedirs(save_folder) + lazy_data = [] + same_folder = os.path.samefile(src_folder, save_folder) + for root, _, files in os.walk(src_folder): + for file in files: + filepath = os.path.join(root, file) + if same_folder and file in ("all.list", "all.json"): + continue + if file.endswith(".list"): + print(filepath) + lazy_data += ListFolder(filepath).lazy_data + if file.endswith(".json"): + print(filepath) + lazy_data += JsonFolder(filepath).lazy_data + if langs is not None: + lazy_data = [item for item in lazy_data if item["lang"] in langs] + ListFolder.save_config( + os.path.join(save_folder, "all.list"), lazy_data, rel_path=save_folder + ) + JsonFolder.save_config( + os.path.join(save_folder, "all.json"), lazy_data, rel_path=save_folder + ) + print(f"all.list and 
all.json are saved to {save_folder}") diff --git a/modules/finetune/utils/logger.py b/modules/finetune/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..7e59b5cfd272cfb4294822924aba108e143d7310 --- /dev/null +++ b/modules/finetune/utils/logger.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 + +import statistics +import time +from collections import defaultdict, deque +from tqdm import tqdm as tqdm_class + +from typing import Generator, Iterable, TypeVar +from typing_extensions import Self + +import torch +import torch.distributed as dist + +from .output import ansi, prints, get_ansi_len + +__all__ = ["SmoothedValue", "MetricLogger"] + +MB = 1 << 20 +T = TypeVar("T") + + +class SmoothedValue: + r"""Track a series of values and provide access to smoothed values over a + window or the global series average. + + See Also: + https://github.com/pytorch/vision/blob/main/references/classification/utils.py + + Args: + name (str): Name string. + window_size (int): The :attr:`maxlen` of :class:`~collections.deque`. + fmt (str): The format pattern of ``str(self)``. + + Attributes: + name (str): Name string. + fmt (str): The string pattern. + deque (~collections.deque): The unique data series. + count (int): The amount of data. + total (float): The sum of all data. + + median (float): The median of :attr:`deque`. + avg (float): The avg of :attr:`deque`. + global_avg (float): :math:`\frac{\text{total}}{\text{count}}` + max (float): The max of :attr:`deque`. + min (float): The min of :attr:`deque`. + last_value (float): The last value of :attr:`deque`. + """ + + def __init__( + self, name: str = "", window_size: int = None, fmt: str = "{global_avg:.3f}" + ): + self.name = name + self.deque: deque[float] = deque(maxlen=window_size) + self.count: int = 0 + self.total: float = 0.0 + self.fmt = fmt + + def update(self, value: float, n: int = 1) -> Self: + r"""Update :attr:`n` pieces of data with same :attr:`value`. + + .. code-block:: python + + self.deque.append(value) + self.total += value * n + self.count += n + + Args: + value (float): the value to update. + n (int): the number of data with same :attr:`value`. + + Returns: + SmoothedValue: return ``self`` for stream usage. + """ + self.deque.append(value) + self.total += value * n + self.count += n + return self + + def update_list(self, value_list: list[float]) -> Self: + r"""Update :attr:`value_list`. + + .. code-block:: python + + for value in value_list: + self.deque.append(value) + self.total += value + self.count += len(value_list) + + Args: + value_list (list[float]): the value list to update. + + Returns: + SmoothedValue: return ``self`` for stream usage. + """ + for value in value_list: + self.deque.append(value) + self.total += value + self.count += len(value_list) + return self + + def reset(self) -> Self: + r"""Reset ``deque``, ``count`` and ``total`` to be empty. + + Returns: + SmoothedValue: return ``self`` for stream usage. + """ + self.deque = deque(maxlen=self.deque.maxlen) + self.count = 0 + self.total = 0.0 + return self + + def synchronize_between_processes(self): + r""" + Warning: + Does NOT synchronize the deque! 
+ """ + if not (dist.is_available() and dist.is_initialized()): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = float(t[1]) + + @property + def median(self) -> float: + try: + return statistics.median(self.deque) + except Exception: + return 0.0 + + @property + def avg(self) -> float: + try: + return statistics.mean(self.deque) + except Exception: + return 0.0 + + @property + def global_avg(self) -> float: + try: + return self.total / self.count + except Exception: + return 0.0 + + @property + def max(self) -> float: + try: + return max(self.deque) + except Exception: + return 0.0 + + @property + def min(self) -> float: + try: + return min(self.deque) + except Exception: + return 0.0 + + @property + def last_value(self) -> float: + try: + return self.deque[-1] + except Exception: + return 0.0 + + def __str__(self): + return self.fmt.format( + name=self.name, + count=self.count, + total=self.total, + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + min=self.min, + max=self.max, + last_value=self.last_value, + ) + + def __format__(self, format_spec: str) -> str: + return self.__str__() + + +class MetricLogger: + r""" + See Also: + https://github.com/pytorch/vision/blob/main/references/classification/utils.py + + Args: + delimiter (str): The delimiter to join different meter strings. + Defaults to ``''``. + meter_length (int): The minimum length for each meter. + Defaults to ``20``. + tqdm (bool): Whether to use tqdm to show iteration information. + Defaults to ``env['tqdm']``. + indent (int): The space indent for the entire string. + Defaults to ``0``. + + Attributes: + meters (dict[str, SmoothedValue]): The meter dict. + iter_time (SmoothedValue): Iteration time meter. + data_time (SmoothedValue): Data loading time meter. + memory (SmoothedValue): Memory usage meter. + """ + + def __init__( + self, + delimiter: str = "", + meter_length: int = 20, + tqdm: bool = True, + indent: int = 0, + **kwargs, + ): + self.meters: defaultdict[str, SmoothedValue] = defaultdict(SmoothedValue) + self.create_meters(**kwargs) + self.delimiter = delimiter + self.meter_length = meter_length + self.tqdm = tqdm + self.indent = indent + + self.iter_time = SmoothedValue() + self.data_time = SmoothedValue() + self.memory = SmoothedValue(fmt="{max:.0f}") + + def create_meters(self, **kwargs: str) -> Self: + r"""Create meters with specific ``fmt`` in :attr:`self.meters`. + + ``self.meters[meter_name] = SmoothedValue(fmt=fmt)`` + + Args: + **kwargs: ``(meter_name: fmt)`` + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for k, v in kwargs.items(): + self.meters[k] = SmoothedValue(fmt="{global_avg:.3f}" if v is None else v) + return self + + def update(self, n: int = 1, **kwargs: float) -> Self: + r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`. + + ``self.meters[meter_name].update(float(value), n=n)`` + + Args: + n (int): the number of data with same value. + **kwargs: ``{meter_name: value}``. + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for k, v in kwargs.items(): + if k not in self.meters: + self.meters[k] = SmoothedValue() + self.meters[k].update(float(v), n=n) + return self + + def update_list(self, **kwargs: list) -> Self: + r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`. 
+ + ``self.meters[meter_name].update_list(value_list)`` + + Args: + **kwargs: ``{meter_name: value_list}``. + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for k, v in kwargs.items(): + self.meters[k].update_list(v) + return self + + def reset(self) -> Self: + r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`. + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for meter in self.meters.values(): + meter.reset() + return self + + def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str: + r"""Generate formatted string based on keyword arguments. + + ``key: value`` with max length to be :attr:`self.meter_length`. + + Args: + cut_too_long (bool): Whether to cut too long values to first 5 characters. + Defaults to ``True``. + strip (bool): Whether to strip trailing whitespaces. + Defaults to ``True``. + **kwargs: Keyword arguments to generate string. + """ + str_list: list[str] = [] + for k, v in kwargs.items(): + v_str = str(v) + _str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi) + max_length = self.meter_length + get_ansi_len(_str) + if cut_too_long: + _str = _str[:max_length] + str_list.append(_str.ljust(max_length)) + _str = self.delimiter.join(str_list) + if strip: + _str = _str.rstrip() + return _str + + def __getattr__(self, attr: str) -> float: + if attr in self.meters: + return self.meters[attr] + if attr in vars(self): # TODO: use hasattr + return vars(self)[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self) -> str: + return self.get_str(**self.meters) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def log_every( + self, + iterable: Iterable[T], + header: str = "", + tqdm: bool = None, + tqdm_header: str = "Iter", + indent: int = None, + verbose: int = 1, + ) -> Generator[T, None, None]: + r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs. + + * Middle Output: + ``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}`` + * Final Output + ``{header} str(self) {memory} {iter_time} {data_time} {total_time}`` + + Args: + iterable (~collections.abc.Iterable): The raw iterator. + header (str): The header string for final output. + Defaults to ``''``. + tqdm (bool): Whether to use tqdm to show iteration information. + Defaults to ``self.tqdm``. + tqdm_header (str): The header string for middle output. + Defaults to ``'Iter'``. + indent (int): The space indent for the entire string. + if ``None``, use ``self.indent``. + Defaults to ``None``. + verbose (int): The verbose level of output information. 
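Aside (not part of the patch): a rough sketch of driving the `MetricLogger` defined above by hand, using only the methods shown in this hunk (`create_meters` via constructor kwargs, `update`, and the `__str__`/`get_str` rendering); meter names and values are made up.

```python
# Illustrative sketch only: manual metering with the MetricLogger defined above.
from modules.finetune.utils.logger import MetricLogger  # module path added by this diff

meters = MetricLogger(tqdm=False, loss=None, grad_norm="{last_value:.2f}")
for step in range(3):
    meters.update(loss=1.0 / (step + 1), grad_norm=0.5 * step)
print(str(meters))  # each meter is rendered with its fmt, e.g. loss via "{global_avg:.3f}"
```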
+ """ + tqdm = tqdm if tqdm is not None else self.tqdm + indent = indent if indent is not None else self.indent + iterator = iterable + if len(header) != 0: + header = header.ljust(30 + get_ansi_len(header)) + if tqdm: + length = len(str(len(iterable))) + pattern: str = ( + "{tqdm_header}: {blue_light}" + "[ {red}{{n_fmt:>{length}}}{blue_light} " + "/ {red}{{total_fmt}}{blue_light} ]{reset}" + ).format(tqdm_header=tqdm_header, length=length, **ansi) + offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length + pattern = pattern.ljust(30 + offset + get_ansi_len(pattern)) + time_str = self.get_str(time="{elapsed}<{remaining}", cut_too_long=False) + bar_format = f"{pattern}{{desc}}{time_str}" + iterator = tqdm_class(iterable, leave=False, bar_format=bar_format) + + self.iter_time.reset() + self.data_time.reset() + self.memory.reset() + + end = time.time() + start_time = time.time() + for obj in iterator: + cur_data_time = time.time() - end + self.data_time.update(cur_data_time) + yield obj + cur_iter_time = time.time() - end + self.iter_time.update(cur_iter_time) + if torch.cuda.is_available(): + cur_memory = torch.cuda.max_memory_allocated() / MB + self.memory.update(cur_memory) + if tqdm: + _dict = {k: v for k, v in self.meters.items()} + if verbose > 2 and torch.cuda.is_available(): + _dict.update(memory=f"{cur_memory:.0f} MB") + if verbose > 1: + _dict.update( + iter=f"{cur_iter_time:.3f} s", data=f"{cur_data_time:.3f} s" + ) + iterator.set_description_str(self.get_str(**_dict, strip=False)) + end = time.time() + self.synchronize_between_processes() + total_time = time.time() - start_time + total_time_str = tqdm_class.format_interval(total_time) + + _dict = {k: v for k, v in self.meters.items()} + if verbose > 2 and torch.cuda.is_available(): + _dict.update(memory=f"{str(self.memory)} MB") + if verbose > 1: + _dict.update( + iter=f"{str(self.iter_time)} s", data=f"{str(self.data_time)} s" + ) + _dict.update(time=total_time_str) + prints(self.delimiter.join([header, self.get_str(**_dict)]), indent=indent) diff --git a/modules/finetune/utils/model.py b/modules/finetune/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..416cb1bcc7084c8d0e065e4de36a75e43ab47fa6 --- /dev/null +++ b/modules/finetune/utils/model.py @@ -0,0 +1,19 @@ +import torch +from einops import rearrange +from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ + + +def quantize( + quantizer: GroupedResidualFSQ, + audio_latents: torch.Tensor, # (batch_size, audio_len, audio_dim=1024) +) -> tuple[torch.Tensor, torch.Tensor]: + # feat shape (batch_size, audio_len, audio_dim) + # ind shape (GFSQ.G, batch_size, audio_len, GFSQ.R) + # num_vq=GFSQ.G*GFSQ.R + feat, ind = quantizer(audio_latents) + audio_quantized_latents = feat # (batch_size, audio_len, audio_dim) + audio_input_ids = rearrange( # (batch_size, audio_len, num_vq) + ind, + "g b t r ->b t (g r)", + ) + return audio_quantized_latents, audio_input_ids diff --git a/modules/finetune/utils/output.py b/modules/finetune/utils/output.py new file mode 100644 index 0000000000000000000000000000000000000000..541092ddfa1f33848e1d8ff914ffbeab312db44f --- /dev/null +++ b/modules/finetune/utils/output.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +import re +import sys +from contextlib import contextmanager + + +class ANSI: + ansi_color = { + "black": "\033[30m", + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "purple": "\033[35m", + "blue_light": "\033[36m", + "white": "\033[37m", + "reset": 
"\033[0m", + "upline": "\033[1A", + "clear_line": "\033[2K", + "clear": "\033[2J", + } + ansi_nocolor = { + "black": "", + "red": "", + "green": "", + "yellow": "", + "blue": "", + "purple": "", + "blue_light": "", + "white": "", + "reset": "", + "upline": "\033[1A\033[", + "clear_line": "\033[K", + "clear": "\033[2J", + } + + def __init__(self): + self._dict = ANSI.ansi_color if ("--color" in sys.argv) else ANSI.ansi_nocolor + + def switch(self, color: bool): + self._dict = ANSI.ansi_color if color else ANSI.ansi_nocolor + + def keys(self): + return self._dict.keys() + + def items(self): + return self._dict.items() + + def __getitem__(self, key): + return self._dict[key] + + def __str__(self): + return str(self._dict) + + def __repr__(self): + return repr(self._dict) + + +ansi = ANSI() + + +def remove_ansi(s: str) -> str: + ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]") + return ansi_escape.sub("", s) + + +def get_ansi_len(s: str) -> int: + return len(s) - len(remove_ansi(s)) + + +def prints(*args: str, indent: int = 0, prefix: str = "", **kwargs): + assert indent >= 0 + new_args = [] + for arg in args: + new_args.append(indent_str(str(arg), indent=indent)) + if len(new_args): + new_args[0] = prefix + str(new_args[0]) + print(*new_args, **kwargs) + + +def output_iter(_iter: int, iteration: int = None, iter_len: int = 4) -> str: + if iteration is None: + pattern = "{blue_light}[ {red}{0}{blue_light} ]{reset}" + return pattern.format(str(_iter).rjust(iter_len), **ansi) + else: + iter_str = str(iteration) + length = len(iter_str) + pattern = ( + "{blue_light}[ {red}{0}{blue_light} " "/ {red}{1}{blue_light} ]{reset}" + ) + return pattern.format(str(_iter).rjust(length), iter_str, **ansi) + + +def indent_str(s_: str, indent: int = 0) -> str: + # modified from torch.nn.modules._addindent + if indent > 0 and s_: + s_ = indent * " " + str(s_[:-1]).replace("\n", "\n" + indent * " ") + s_[-1] + return s_ + + +class IndentRedirect: # TODO: inherit TextIOWrapper? 
+ def __init__(self, buffer: bool = True, indent: int = 0): + self.__console__ = sys.stdout + self.indent = indent + self.__buffer: str = None + if buffer: + self.__buffer = "" + + def write(self, text: str, indent: int = None): + indent = indent if indent is not None else self.indent + text = indent_str(text, indent=indent) + if self.__buffer is None: + self.__console__.write(text) + else: + self.__buffer += text + + def flush(self): + if self.__buffer is not None: + self.__console__.write(self.__buffer) + self.__buffer = "" + self.__console__.flush() + + @contextmanager + def __call__(self) -> None: + try: + sys.stdout = self + yield + finally: + sys.stdout = self.__console__ + self.__buffer = "" + + def enable(self): + sys.stdout = self + + def disable(self): + if self.__buffer is not None: + self.__buffer = "" + sys.stdout = self.__console__ + + @property + def buffer(self) -> str: + return self.__buffer + + +redirect = IndentRedirect() diff --git a/modules/generate_audio.py b/modules/generate_audio.py index 9fcabe3954c3912b1f008ed8865ca04c12e18a14..a2e4552b9103d2bb13dd030724f93f740ab7f1b8 100644 --- a/modules/generate_audio.py +++ b/modules/generate_audio.py @@ -76,6 +76,8 @@ def generate_audio_batch( params_infer_code["spk_emb"] = chat_tts.sample_random_speaker() logger.debug(("spk", spk)) elif isinstance(spk, Speaker): + if not isinstance(spk.emb, torch.Tensor): + raise ValueError("spk.pt is broken, please retrain the model.") params_infer_code["spk_emb"] = spk.emb logger.debug(("spk", spk.name)) else: diff --git a/modules/normalization.py b/modules/normalization.py index 1d740e1ca6b914a37deceb515409e088cb5c29d2..cc6e941f78143b2b46b5eb8f886f55f68417c77f 100644 --- a/modules/normalization.py +++ b/modules/normalization.py @@ -120,6 +120,7 @@ character_map = { "~": " ", "~": " ", "/": " ", + "·": " ", } character_to_word = { @@ -282,6 +283,9 @@ def text_normalize(text, is_end=False): if __name__ == "__main__": + from modules.devices import devices + + devices.reset_device() test_cases = [ "ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本.", " [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149", @@ -319,6 +323,7 @@ State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX. 
""" 120米 有12%的概率会下雨 +埃隆·马斯克 """, ] diff --git a/modules/repos_static/resemble_enhance/data/distorter/base.py b/modules/repos_static/resemble_enhance/data/distorter/base.py index d43d84fa840dd25804d9c5e5dc9673f65fdc5b94..f07ef407fa92190234ead9f7de43d7c5ea3c6b4d 100644 --- a/modules/repos_static/resemble_enhance/data/distorter/base.py +++ b/modules/repos_static/resemble_enhance/data/distorter/base.py @@ -2,6 +2,7 @@ import itertools import os import random import time +from typing import Union import warnings import numpy as np @@ -87,7 +88,7 @@ class Choice(Effect): class Permutation(Effect): - def __init__(self, *effects, n: int | None = None): + def __init__(self, *effects, n: Union[int, None] = None): super().__init__() self.effects = effects self.n = n diff --git a/modules/repos_static/resemble_enhance/data/distorter/custom.py b/modules/repos_static/resemble_enhance/data/distorter/custom.py index 28428f7789cebb2d174c581111711f4d73f6565b..fdabed6aac1647de9a7ee887f84308effa71c8da 100644 --- a/modules/repos_static/resemble_enhance/data/distorter/custom.py +++ b/modules/repos_static/resemble_enhance/data/distorter/custom.py @@ -3,6 +3,7 @@ import random from dataclasses import dataclass from functools import cached_property from pathlib import Path +from typing import Union import librosa import numpy as np @@ -16,7 +17,7 @@ _logger = logging.getLogger(__name__) @dataclass class RandomRIR(Effect): - rir_dir: Path | None + rir_dir: Union[Path, None] rir_rate: int = 44_000 rir_suffix: str = ".npy" deterministic: bool = False @@ -49,7 +50,9 @@ class RandomRIR(Effect): length = len(wav) - wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast") + wav = librosa.resample( + wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast" + ) rir = self._sample_rir() wav = signal.convolve(wav, rir, mode="same") @@ -58,7 +61,9 @@ class RandomRIR(Effect): if actlev > 0.99: wav = (wav / actlev) * 0.98 - wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast") + wav = librosa.resample( + wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast" + ) if abs(length - len(wav)) > 10: _logger.warning(f"length mismatch: {length} vs {len(wav)}") diff --git a/modules/repos_static/resemble_enhance/data/distorter/sox.py b/modules/repos_static/resemble_enhance/data/distorter/sox.py index 92a2d74033d33b975141c1ece7ac5619d1dfcc39..3e08376087683222dd5db98f4c4b25ad0e38b847 100644 --- a/modules/repos_static/resemble_enhance/data/distorter/sox.py +++ b/modules/repos_static/resemble_enhance/data/distorter/sox.py @@ -1,6 +1,7 @@ import logging import os import random +from typing import Union import warnings from functools import partial @@ -29,7 +30,9 @@ class AttachableEffect(Effect): chain = augment.EffectChain() chain = self.attach(chain) tensor = torch.from_numpy(wav)[None].float() # (1, T) - tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr}) + tensor = chain.apply( + tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr} + ) wav = tensor.numpy()[0] # (T,) return wav @@ -41,7 +44,9 @@ class SoxEffect(AttachableEffect): self.kwargs = kwargs def attach(self, chain: augment.EffectChain) -> augment.EffectChain: - _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}") + _logger.debug( + f"Attaching {self.effect_name} with {self.args} and {self.kwargs}" + ) if not hasattr(chain, self.effect_name): raise ValueError(f"EffectChain has no attribute {self.effect_name}") 
return getattr(chain, self.effect_name)(*self.args, **self.kwargs) @@ -115,21 +120,30 @@ class Randint(Generator): class Concat(Generator): - def __init__(self, *parts: Generator | str): + def __init__(self, *parts: Union[Generator, str]): self.parts = parts def __call__(self): - return "".join([part if isinstance(part, str) else part() for part in self.parts]) + return "".join( + [part if isinstance(part, str) else part() for part in self.parts] + ) class RandomLowpassDistorter(SoxEffect): def __init__(self, low=2000, high=16000): - super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high))) + super().__init__( + "sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)) + ) class RandomBandpassDistorter(SoxEffect): def __init__(self, low=100, high=1000, min_width=2000, max_width=4000): - super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width)) + super().__init__( + "sinc", + "-n", + Randint(50, 200), + partial(self._fn, low, high, min_width, max_width), + ) @staticmethod def _fn(low, high, min_width, max_width): @@ -139,7 +153,15 @@ class RandomBandpassDistorter(SoxEffect): class RandomEqualizer(SoxEffect): - def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30): + def __init__( + self, + low=100, + high=4000, + q_low=1, + q_high=5, + db_low: int = -30, + db_high: int = 30, + ): super().__init__( "equalizer", Uniform(low, high), @@ -150,7 +172,9 @@ class RandomEqualizer(SoxEffect): class RandomOverdrive(SoxEffect): def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80): - super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high)) + super().__init__( + "overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high) + ) class RandomReverb(Chain): diff --git a/modules/repos_static/resemble_enhance/data/utils.py b/modules/repos_static/resemble_enhance/data/utils.py index 77f59d345b75cac76c6c423c734ae9c70a626590..38ca25fe36074d615962dc229599cf1b3a548aaa 100644 --- a/modules/repos_static/resemble_enhance/data/utils.py +++ b/modules/repos_static/resemble_enhance/data/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Callable +from typing import Callable, Union from torch import Tensor @@ -16,7 +16,9 @@ def rglob_audio_files(path: Path): return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac")) -def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7): +def mix_fg_bg( + fg: Tensor, bg: Tensor, alpha: Union[float, Callable[..., float]] = 0.5, eps=1e-7 +): """ Args: fg: (b, t) diff --git a/modules/repos_static/resemble_enhance/denoiser/denoiser.py b/modules/repos_static/resemble_enhance/denoiser/denoiser.py index c0d9c2b6ffbc471029cee620216a2d080b9dd100..4d672df3431d1877d3c8cb882aa8606d6d8b5d1f 100644 --- a/modules/repos_static/resemble_enhance/denoiser/denoiser.py +++ b/modules/repos_static/resemble_enhance/denoiser/denoiser.py @@ -1,4 +1,5 @@ import logging +from typing import Union import torch import torch.nn.functional as F @@ -154,7 +155,7 @@ class Denoiser(nn.Module): sep_sin = sin * cos_res + cos * sin_res return sep_mag, sep_cos, sep_sin - def forward(self, x: Tensor, y: Tensor | None = None): + def forward(self, x: Tensor, y: Union[Tensor, None] = None): """ Args: x: (b t), a mixed audio diff --git a/modules/repos_static/resemble_enhance/enhancer/download.py b/modules/repos_static/resemble_enhance/enhancer/download.py index 
614b9a4b4f9a1a10b79f12ca1a25821247ea2a16..089181893229ba67c9202e204f994d512975f9fc 100644 --- a/modules/repos_static/resemble_enhance/enhancer/download.py +++ b/modules/repos_static/resemble_enhance/enhancer/download.py @@ -1,5 +1,6 @@ import logging from pathlib import Path +from typing import Union import torch @@ -12,14 +13,18 @@ def get_source_url(relpath): return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true" -def get_target_path(relpath: str | Path, run_dir: str | Path | None = None): +def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None): if run_dir is None: run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME return Path(run_dir) / relpath -def download(run_dir: str | Path | None = None): - relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"] +def download(run_dir: Union[str, Path, None] = None): + relpaths = [ + "hparams.yaml", + "ds/G/latest", + "ds/G/default/mp_rank_00_model_states.pt", + ] for relpath in relpaths: path = get_target_path(relpath, run_dir=run_dir) if path.exists(): diff --git a/modules/repos_static/resemble_enhance/enhancer/enhancer.py b/modules/repos_static/resemble_enhance/enhancer/enhancer.py index c7ab9417deb429855b7fce43962426f6b6c4a9c0..1ea3f351752d8e8e13040fef842372367926c3e4 100644 --- a/modules/repos_static/resemble_enhance/enhancer/enhancer.py +++ b/modules/repos_static/resemble_enhance/enhancer/enhancer.py @@ -1,4 +1,5 @@ import logging +from typing import Union import matplotlib.pyplot as plt import pandas as pd @@ -109,7 +110,7 @@ class Enhancer(nn.Module): return self.mel_fn(x)[..., :-1] # (b d t) return self.mel_fn(x) - def _may_denoise(self, x: Tensor, y: Tensor | None = None): + def _may_denoise(self, x: Tensor, y: Union[Tensor, None] = None): if self.hp.lcfm_training_mode == "cfm": return self.denoiser(x, y) return x @@ -126,7 +127,9 @@ class Enhancer(nn.Module): self.lcfm.eval_tau_(tau) self._eval_lambd = lambd - def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None): + def forward( + self, x: Tensor, y: Union[Tensor, None] = None, z: Union[Tensor, None] = None + ): """ Args: x: (b t), mix wavs (fg + bg) diff --git a/modules/repos_static/resemble_enhance/enhancer/hparams.py b/modules/repos_static/resemble_enhance/enhancer/hparams.py index ca89bea6f5d7d4ec4f543f8bde88b29dcae69f6a..7878e4172b5772c59aea0de54d2537f0523d9437 100644 --- a/modules/repos_static/resemble_enhance/enhancer/hparams.py +++ b/modules/repos_static/resemble_enhance/enhancer/hparams.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from pathlib import Path +from typing import Union from ..hparams import HParams as HParamsBase @@ -17,7 +18,7 @@ class HParams(HParamsBase): vocoder_extra_dim: int = 32 - gan_training_start_step: int | None = 5_000 - enhancer_stage1_run_dir: Path | None = None + gan_training_start_step: Union[int, None] = 5_000 + enhancer_stage1_run_dir: Union[Path, None] = None - denoiser_run_dir: Path | None = None + denoiser_run_dir: Union[Path, None] = None diff --git a/modules/repos_static/resemble_enhance/enhancer/inference.py b/modules/repos_static/resemble_enhance/enhancer/inference.py index af57a2c7d3e5cc7b08b00f85f0135e881e50fcbe..dc7712cb6a4d2126bb4d740d24ed9355312741ef 100644 --- a/modules/repos_static/resemble_enhance/enhancer/inference.py +++ b/modules/repos_static/resemble_enhance/enhancer/inference.py @@ -1,6 +1,7 @@ import logging from functools import cache from pathlib import Path +from 
typing import Union import torch @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) @cache -def load_enhancer(run_dir: str | Path | None, device): +def load_enhancer(run_dir: Union[str, Path, None], device): run_dir = download(run_dir) hp = HParams.load(run_dir) enhancer = Enhancer(hp) diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py index a5125267b7f32e11c58e4b96bffa3ba1e96fdc4f..09b4a3e45ce3b50cca7ce7debe77ddb230ee9783 100644 --- a/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py @@ -1,7 +1,7 @@ import logging from dataclasses import dataclass from functools import partial -from typing import Protocol +from typing import Protocol, Union import matplotlib.pyplot as plt import numpy as np @@ -17,8 +17,7 @@ logger = logging.getLogger(__name__) class VelocityField(Protocol): - def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: - ... + def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ... class Solver: @@ -40,7 +39,9 @@ class Solver: self._camera = None self._mel_fn = mel_fn - self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor) + self._time_mapping = partial( + self.exponential_decay_mapping, n=time_mapping_divisor + ) def configurate_(self, nfe=None, method=None): if nfe is None: @@ -50,7 +51,9 @@ class Solver: method = self.method if nfe == 1 and method in ("midpoint", "rk4"): - logger.warning(f"1 NFE is not supported for {method}, using euler method instead.") + logger.warning( + f"1 NFE is not supported for {method}, using euler method instead." + ) method = "euler" self.nfe = nfe @@ -105,7 +108,9 @@ class Solver: ) else: # Spectrogram, b c t - plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none") + plt.imshow( + ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none" + ) ax = plt.gca() ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center") camera.snap() @@ -271,7 +276,7 @@ class CFM(nn.Module): global_dim=self.time_emb_dim, ) - def _perturb(self, ψ1: Tensor, t: Tensor | None = None): + def _perturb(self, ψ1: Tensor, t: Union[Tensor, None] = None): """ Perturb ψ1 to ψt. 
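Aside (not part of the patch): the recurring edit in these `resemble_enhance` hunks replaces PEP 604 union syntax with `typing.Union`, presumably so the vendored code still imports on Python versions below 3.10. A minimal before/after sketch with a made-up function name:

```python
from pathlib import Path
from typing import Union

# Python >= 3.10 only:
#   def load_run(run_dir: str | Path | None = None) -> Path: ...
# Equivalent spelling that also parses on 3.8 / 3.9:
def load_run(run_dir: Union[str, Path, None] = None) -> Path:
    return Path(run_dir) if run_dir is not None else Path(".")
```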
""" @@ -311,7 +316,7 @@ class CFM(nn.Module): """ return ψ1 - ψ0 - def _to_v(self, *, ψt, x, t: float | Tensor): + def _to_v(self, *, ψt, x, t: Union[float, Tensor]): """ Args: ψt: (b c t) @@ -364,7 +369,13 @@ class CFM(nn.Module): ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0) return ψ1 - def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0): + def forward( + self, + x: Tensor, + y: Union[Tensor, None] = None, + ψ0: Union[Tensor, None] = None, + t0=0.0, + ): if y is None: y = self.sample(x, ψ0=ψ0, t0=t0) else: diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py index 91f5dbb506187271c67c7bbbf55475021854ab27..aa82827c8809b001d31827d76bbee731e11ae2e2 100644 --- a/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py @@ -1,5 +1,6 @@ import logging from dataclasses import dataclass +from typing import Union import torch.nn as nn import torch.nn.functional as F @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) @dataclass class IRMAEOutput: latent: Tensor # latent vector - decoded: Tensor | None # decoder output, include extra dim + decoded: Union[Tensor, None] # decoder output, include extra dim class ResBlock(nn.Sequential): diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py index 4c2f5f88718e2f42f82e2f4714ea510b4677b450..8d1c241312f96525fcf7630e805560cbe9b84406 100644 --- a/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py @@ -1,5 +1,6 @@ import logging from enum import Enum +from typing import Union import matplotlib.pyplot as plt import torch @@ -70,19 +71,34 @@ class LCFM(nn.Module): return plt.subplot(221) - plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.imshow( + y[0].detach().cpu().numpy(), + aspect="auto", + origin="lower", + interpolation="none", + ) plt.title("GT") plt.subplot(222) y_ = y_[:, : y.shape[1]] - plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.imshow( + y_[0].detach().cpu().numpy(), + aspect="auto", + origin="lower", + interpolation="none", + ) plt.title("Posterior") plt.subplot(223) z_ = self.cfm(x) y__ = self.ae.decode(z_) y__ = y__[:, : y.shape[1]] - plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.imshow( + y__[0].detach().cpu().numpy(), + aspect="auto", + origin="lower", + interpolation="none", + ) plt.title("C-Prior") del y__ @@ -90,7 +106,12 @@ class LCFM(nn.Module): z_ = torch.randn_like(z_) y__ = self.ae.decode(z_) y__ = y__[:, : y.shape[1]] - plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.imshow( + y__[0].detach().cpu().numpy(), + aspect="auto", + origin="lower", + interpolation="none", + ) plt.title("Prior") del z_, y__ @@ -109,7 +130,7 @@ class LCFM(nn.Module): def eval_tau_(self, tau): self._eval_tau = tau - def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None): + def forward(self, x, y: Union[Tensor, None] = None, ψ0: Union[Tensor, None] = None): """ Args: x: (b d t), condition mel @@ -139,14 +160,20 @@ class LCFM(nn.Module): h = self.ae.decode(z) else: - ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM) + ae_output: IRMAEOutput = self.ae( + y, 
skip_decoding=self.mode == self.Mode.CFM + ) if self.mode == self.Mode.CFM: _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0) h = ae_output.decoded - if h is not None and self.global_step is not None and self.global_step % 100 == 0: + if ( + h is not None + and self.global_step is not None + and self.global_step % 100 == 0 + ): self._visualize(x[:1], y[:1], h[:1]) return h diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py index bb20217f048f398236698f6a38927310d0c1ba9b..602f08851095b7a25a1bddc8a2daa7e48fc10cb1 100644 --- a/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py @@ -1,3 +1,4 @@ +from typing import Union import numpy as np import torch import torch.nn.functional as F @@ -50,7 +51,9 @@ class UnivNet(nn.Module): ] ) - self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect")) + self.conv_pre = weight_norm( + nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect") + ) self.conv_post = nn.Sequential( nn.LeakyReLU(0.2), @@ -64,7 +67,7 @@ class UnivNet(nn.Module): def eps(self): return 1e-5 - def forward(self, x: Tensor, y: Tensor | None = None, npad=10): + def forward(self, x: Tensor, y: Union[Tensor, None] = None, npad=10): """ Args: x: (b c t), acoustic features @@ -74,7 +77,9 @@ class UnivNet(nn.Module): """ assert x.ndim == 3, "x must be 3D tensor" assert y is None or y.ndim == 2, "y must be 2D tensor" - assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}" + assert ( + x.shape[1] == self.d_input + ), f"x.shape[1] must be {self.d_input}, but got {x.shape}" assert npad >= 0, "npad must be positive or zero" x = F.pad(x, (0, npad), "constant", 0) diff --git a/modules/repos_static/resemble_enhance/hparams.py b/modules/repos_static/resemble_enhance/hparams.py index a8e716175fa962ada1d98cd755430e2ea770278c..9f796e97c3ab1c3d540d9aed14c8bf0796a7d39b 100644 --- a/modules/repos_static/resemble_enhance/hparams.py +++ b/modules/repos_static/resemble_enhance/hparams.py @@ -1,6 +1,7 @@ import logging from dataclasses import asdict, dataclass from pathlib import Path +from typing import Union from omegaconf import OmegaConf from rich.console import Console @@ -102,7 +103,7 @@ class HParams: OmegaConf.save(asdict(self), str(path)) @classmethod - def load(cls, run_dir, yaml: Path | None = None): + def load(cls, run_dir, yaml: Union[Path, None] = None): hps = [] if (run_dir / "hparams.yaml").exists(): @@ -120,7 +121,9 @@ class HParams: for k, v in asdict(hp).items(): if getattr(hps[0], k) != v: errors[k] = f"{getattr(hps[0], k)} != {v}" - raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}") + raise ValueError( + f"Found inconsistent hparams: {errors}, consider deleting {run_dir}" + ) return hps[0] diff --git a/modules/speaker.py b/modules/speaker.py index 46ceaf947a5be1c2915c51d29c5a48707388af82..764cf48839fe8127af63704f9ea947710ea1b3a4 100644 --- a/modules/speaker.py +++ b/modules/speaker.py @@ -29,13 +29,15 @@ class Speaker: speaker.emb = tensor return speaker - def __init__(self, seed, name="", gender="", describe=""): + def __init__( + self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe="" + ): self.id = uuid.uuid4() - self.seed = seed + self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor self.name = name self.gender = gender 
self.describe = describe - self.emb = None + self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor # TODO replace emb => tokens self.tokens = [] diff --git a/modules/ssml_parser/SSMLParser.py b/modules/ssml_parser/SSMLParser.py index 5db290006a51a809178c7c0198a51a9d6324a888..bda224cd2e9846cff34eb21e7430f1a5b7f9e42a 100644 --- a/modules/ssml_parser/SSMLParser.py +++ b/modules/ssml_parser/SSMLParser.py @@ -11,8 +11,8 @@ import copy class SSMLContext(Box): - def __init__(self, parent=None): - self.parent: Union[SSMLContext, None] = parent + def __init__(self, *args, **kwargs): + self.parent: Union[SSMLContext, None] = None self.style = None self.spk = None @@ -29,18 +29,14 @@ class SSMLContext(Box): self.prompt2 = None self.prefix = None - def clone(self): - ctx = SSMLContext() - for k, v in self.items(): - ctx[k] = v - return ctx + super().__init__(*args, **kwargs) class SSMLSegment(Box): - def __init__(self, text: str, attrs=SSMLContext()): - self.attrs = attrs + def __init__(self, text: str, attrs=SSMLContext(), params=None): + self.attrs = SSMLContext(**attrs) self.text = text - self.params = None + self.params = params class SSMLBreak: @@ -68,7 +64,7 @@ class SSMLParser: root = etree.fromstring(ssml) root_ctx = SSMLContext() - segments = [] + segments: List[Union[SSMLSegment, SSMLBreak]] = [] self.resolve(root, root_ctx, segments) return segments @@ -89,8 +85,13 @@ def create_ssml_parser(): parser = SSMLParser() @parser.resolver("speak") - def tag_speak(element, context, segments, parser): - ctx = context.clone() if context is not None else SSMLContext() + def tag_speak( + element: etree.Element, + context: Box, + segments: List[Union[SSMLSegment, SSMLBreak]], + parser: SSMLParser, + ): + ctx = context.copy() if context is not None else SSMLContext() version = element.get("version") if version != "0.1": @@ -100,8 +101,13 @@ def create_ssml_parser(): parser.resolve(child, ctx, segments) @parser.resolver("voice") - def tag_voice(element, context, segments, parser): - ctx = context.clone() if context is not None else SSMLContext() + def tag_voice( + element: etree.Element, + context: Box, + segments: List[Union[SSMLSegment, SSMLBreak]], + parser: SSMLParser, + ): + ctx = context.copy() if context is not None else SSMLContext() ctx.spk = element.get("spk", ctx.spk) ctx.style = element.get("style", ctx.style) @@ -131,13 +137,23 @@ def create_ssml_parser(): segments.append(SSMLSegment(child.tail.strip(), ctx)) @parser.resolver("break") - def tag_break(element, context, segments, parser): + def tag_break( + element: etree.Element, + context: Box, + segments: List[Union[SSMLSegment, SSMLBreak]], + parser: SSMLParser, + ): time_ms = int(element.get("time", "0").replace("ms", "")) segments.append(SSMLBreak(time_ms)) @parser.resolver("prosody") - def tag_prosody(element, context, segments, parser): - ctx = context.clone() if context is not None else SSMLContext() + def tag_prosody( + element: etree.Element, + context: Box, + segments: List[Union[SSMLSegment, SSMLBreak]], + parser: SSMLParser, + ): + ctx = context.copy() if context is not None else SSMLContext() ctx.spk = element.get("spk", ctx.spk) ctx.style = element.get("style", ctx.style) diff --git a/modules/synthesize_audio.py b/modules/synthesize_audio.py index a07c7bc69de6352cfec6a9c01d6c0203ac4f8d94..c032abc2ee99237b729c23b9e4e2cd1b0cf683d9 100644 --- a/modules/synthesize_audio.py +++ b/modules/synthesize_audio.py @@ -7,6 +7,7 @@ from modules import generate_audio as generate from modules.speaker import Speaker +from 
modules.ssml_parser.SSMLParser import SSMLSegment from modules.utils import audio @@ -23,45 +24,33 @@ def synthesize_audio( prefix: str = "", batch_size: int = 1, spliter_threshold: int = 100, + end_of_sentence="", ): - if batch_size == 1: - return generate.generate_audio( - text, - temperature=temperature, - top_P=top_P, - top_K=top_K, - spk=spk, - infer_seed=infer_seed, - use_decoder=use_decoder, - prompt1=prompt1, - prompt2=prompt2, - prefix=prefix, + spliter = SentenceSplitter(spliter_threshold) + sentences = spliter.parse(text) + + text_segments = [ + SSMLSegment( + text=s, + params={ + "temperature": temperature, + "top_P": top_P, + "top_K": top_K, + "spk": spk, + "infer_seed": infer_seed, + "use_decoder": use_decoder, + "prompt1": prompt1, + "prompt2": prompt2, + "prefix": prefix, + }, ) - else: - spliter = SentenceSplitter(spliter_threshold) - sentences = spliter.parse(text) + for s in sentences + ] + synthesizer = SynthesizeSegments( + batch_size=batch_size, eos=end_of_sentence, spliter_thr=spliter_threshold + ) + audio_segments = synthesizer.synthesize_segments(text_segments) - text_segments = [ - { - "text": s, - "params": { - "text": s, - "temperature": temperature, - "top_P": top_P, - "top_K": top_K, - "spk": spk, - "infer_seed": infer_seed, - "use_decoder": use_decoder, - "prompt1": prompt1, - "prompt2": prompt2, - "prefix": prefix, - }, - } - for s in sentences - ] - synthesizer = SynthesizeSegments(batch_size) - audio_segments = synthesizer.synthesize_segments(text_segments) + combined_audio = combine_audio_segments(audio_segments) - combined_audio = combine_audio_segments(audio_segments) - - return audio.pydub_to_np(combined_audio) + return audio.pydub_to_np(combined_audio) diff --git a/modules/utils/audio.py b/modules/utils/audio.py index 48f38c598db590bad30687e519db78f1b0b491af..b1a97eeee49ea8aa9d877fc3e9cdeb6e8f1ea1cf 100644 --- a/modules/utils/audio.py +++ b/modules/utils/audio.py @@ -95,7 +95,11 @@ def pitch_shift( def apply_prosody_to_audio_data( - audio_data: np.ndarray, rate: float, volume: float, pitch: float, sr: int + audio_data: np.ndarray, + rate: float = 1, + volume: float = 0, + pitch: float = 0, + sr: int = 24000, ) -> np.ndarray: if rate != 1: audio_data = pyrb.time_stretch(audio_data, sr=sr, rate=rate) diff --git a/modules/webui/app.py b/modules/webui/app.py index 1f64fa8d33ac67801307735ac3cfe26c63e384fa..b1acbf9e5f1fa6be03466d739b8f8445bbd853f7 100644 --- a/modules/webui/app.py +++ b/modules/webui/app.py @@ -7,6 +7,7 @@ from modules import config from modules.webui import gradio_extensions, webui_config from modules.webui.changelog_tab import create_changelog_tab +from modules.webui.finetune.ft_tab import create_ft_tabs from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars from modules.webui.ssml.podcast_tab import create_ssml_podcast_tab from modules.webui.system_tab import create_system_tab @@ -118,6 +119,8 @@ def create_interface(): gr.Markdown("🚧 Under construction") with gr.TabItem("ASR", visible=webui_config.experimental): gr.Markdown("🚧 Under construction") + with gr.TabItem("Finetune", visible=webui_config.experimental): + create_ft_tabs(demo) with gr.TabItem("System"): create_system_tab() diff --git a/modules/webui/finetune/ProcessMonitor.py b/modules/webui/finetune/ProcessMonitor.py new file mode 100644 index 0000000000000000000000000000000000000000..a92c187ae0de80ec0ad93f56d65e623d4a916c55 --- /dev/null +++ b/modules/webui/finetune/ProcessMonitor.py @@ -0,0 +1,92 @@ +import os +import sys +import subprocess 
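Aside (not part of the patch): a hedged sketch of the unified path that the `synthesize_audio` refactor shown a little earlier now follows for every request, i.e. split the text, wrap each sentence as an `SSMLSegment` carrying the generation params, then batch-synthesize. Import paths mirror the ones visible elsewhere in this patch, except `combine_audio_segments`, whose module is assumed here.

```python
# Illustrative sketch only: sentence split -> SSMLSegment list -> batched synthesis.
from modules.SentenceSplitter import SentenceSplitter
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments  # latter location assumed
from modules.ssml_parser.SSMLParser import SSMLSegment

sentences = SentenceSplitter(100).parse("你好。今天我们聊聊中华料理。")
segments = [SSMLSegment(text=s, params={"temperature": 0.3, "infer_seed": 42}) for s in sentences]
synthesizer = SynthesizeSegments(batch_size=4, eos="[uv_break]", spliter_thr=100)
combined = combine_audio_segments(synthesizer.synthesize_segments(segments))
```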
+import threading + + +class ProcessMonitor: + def __init__(self): + self.process = None + self.stdout = "" + self.stderr = "" + self.lock = threading.Lock() + + def start_process(self, command): + self.process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=1, + universal_newlines=True, + ) + + # Set pipes to non-blocking mode + fd_out = self.process.stdout.fileno() + fd_err = self.process.stderr.fileno() + + if sys.platform != "win32": + import fcntl + + fl_out = fcntl.fcntl(fd_out, fcntl.F_GETFL) + fl_err = fcntl.fcntl(fd_err, fcntl.F_GETFL) + fcntl.fcntl(fd_out, fcntl.F_SETFL, fl_out | os.O_NONBLOCK) + fcntl.fcntl(fd_err, fcntl.F_SETFL, fl_err | os.O_NONBLOCK) + + # Start threads to read stdout and stderr + threading.Thread(target=self._read_stdout).start() + threading.Thread(target=self._read_stderr).start() + + def _read_stdout(self): + while self.process is not None and self.process.poll() is None: + try: + output = self.process.stdout.read() + if output: + with self.lock: + self.stdout += output + except: + pass + + def _read_stderr(self): + while self.process is not None and self.process.poll() is None: + try: + error = self.process.stderr.read() + if error: + with self.lock: + self.stderr += error + except: + pass + + def get_output(self): + with self.lock: + return self.stdout, self.stderr + + def stop_process(self): + if self.process: + self.process.terminate() + self.process = None + + +if __name__ == "__main__": + import time + + pm = ProcessMonitor() + pm.start_process( + [ + "python", + "-u", + "-c", + "import time; [print(i) or time.sleep(1) for i in range(5)]", + ] + ) + + while pm.process and pm.process.poll() is None: + stdout, stderr = pm.get_output() + if stdout: + print("STDOUT:", stdout) + if stderr: + print("STDERR:", stderr) + time.sleep(1) + + stdout, stderr = pm.get_output() + print("Final STDOUT:", stdout) + print("Final STDERR:", stderr) diff --git a/modules/webui/finetune/ft_tab.py b/modules/webui/finetune/ft_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7147f40e598abf846ef5ee2d6e3e8a6bf005b9 --- /dev/null +++ b/modules/webui/finetune/ft_tab.py @@ -0,0 +1,13 @@ +import gradio as gr + +from modules.webui.finetune.speaker_ft_tab import create_speaker_ft_tab + + +def create_ft_tabs(demo): + with gr.Tabs(): + with gr.TabItem("Speaker"): + create_speaker_ft_tab(demo) + with gr.TabItem("GPT"): + gr.Markdown("🚧 Under construction") + with gr.TabItem("AE"): + gr.Markdown("🚧 Under construction") diff --git a/modules/webui/finetune/ft_ui_utils.py b/modules/webui/finetune/ft_ui_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a8e8ca09ebe5a8c9c6df6f6e336a77de10fb17 --- /dev/null +++ b/modules/webui/finetune/ft_ui_utils.py @@ -0,0 +1,49 @@ +import os +from typing import IO, Union +from modules.speaker import Speaker, speaker_mgr +import subprocess + + +def get_datasets_dir(): + """ + 列出 ./datasets/data_* 文件夹 + """ + dataset_path = "./datasets" + dataset_list = os.listdir(dataset_path) + dataset_list = [ + d for d in dataset_list if os.path.isdir(os.path.join(dataset_path, d)) + ] + dataset_list = [d for d in dataset_list if d.startswith("data_")] + return dataset_list + + +def get_datasets_listfile(): + datasets = get_datasets_dir() + listfiles = [] + for d in datasets: + dir_path = os.path.join("./datasets", d) + files = os.listdir(dir_path) + for f in files: + if f.endswith(".list"): + listfiles.append(os.path.join(dir_path, f)) + return listfiles + + +def 
run_speaker_ft( + batch_size: int, epochs: int, train_text: bool, data_path: str, init_speaker: str +): + command = ["python3", "-m", "modules.finetune.train_speaker"] + command += [ + f"--batch_size={batch_size}", + f"--epochs={epochs}", + f"--data_path={data_path}", + ] + if train_text: + command.append("--train_text") + if init_speaker: + command.append(f"--init_speaker={init_speaker}") + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1 + ) + + return process diff --git a/modules/webui/finetune/speaker_ft_tab.py b/modules/webui/finetune/speaker_ft_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..f652ec422a5b71acfaec98aa74f93c72b61ae32b --- /dev/null +++ b/modules/webui/finetune/speaker_ft_tab.py @@ -0,0 +1,130 @@ +import gradio as gr + +from modules.Enhancer.ResembleEnhance import unload_enhancer +from modules.webui import webui_config +from modules.webui.webui_utils import get_speaker_names +from .ft_ui_utils import get_datasets_listfile, run_speaker_ft +from .ProcessMonitor import ProcessMonitor +from modules.speaker import speaker_mgr +from modules.models import unload_chat_tts + + +class SpeakerFt: + def __init__(self): + self.process_monitor = ProcessMonitor() + self.status_str = "idle" + + def unload_main_thread_models(self): + unload_chat_tts() + unload_enhancer() + + def run( + self, + batch_size: int, + epochs: int, + lr: str, + train_text: bool, + data_path: str, + select_speaker: str = "", + ): + if self.process_monitor.process: + return + self.unload_main_thread_models() + spk_path = None + if select_speaker != "" and select_speaker != "none": + select_speaker = select_speaker.split(" : ")[1].strip() + spk = speaker_mgr.get_speaker(select_speaker) + if spk is None: + return ["Speaker not found"] + spk_filename = speaker_mgr.get_speaker_filename(spk.id) + spk_path = f"./data/speakers/{spk_filename}" + + command = ["python3", "-m", "modules.finetune.train_speaker"] + command += [ + f"--batch_size={batch_size}", + f"--epochs={epochs}", + f"--data_path={data_path}", + ] + if train_text: + command.append("--train_text") + if spk_path: + command.append(f"--init_speaker={spk_path}") + + self.status("Training process starting") + + self.process_monitor.start_process(command) + + self.status("Training started") + + def status(self, text: str): + self.status_str = text + + def flush(self): + stdout, stderr = self.process_monitor.get_output() + return f"{self.status_str}\n{stdout}\n{stderr}" + + def clear(self): + self.process_monitor.stdout = "" + self.process_monitor.stderr = "" + self.status("Logs cleared") + + def stop(self): + self.process_monitor.stop_process() + self.status("Training stopped") + + +def create_speaker_ft_tab(demo: gr.Blocks): + spk_ft = SpeakerFt() + speakers, speaker_names = get_speaker_names() + speaker_names = ["none"] + speaker_names + + with gr.Row(): + with gr.Column(scale=2): + with gr.Group(): + gr.Markdown("🎛️hparams") + dataset_input = gr.Dropdown( + label="Dataset", choices=get_datasets_listfile() + ) + lr_input = gr.Textbox(label="Learning Rate", value="1e-2") + epochs_input = gr.Slider( + label="Epochs", value=10, minimum=1, maximum=100, step=1 + ) + batch_size_input = gr.Slider( + label="Batch Size", value=4, minimum=1, maximum=64, step=1 + ) + train_text_checkbox = gr.Checkbox(label="Train text_loss", value=True) + init_spk_dropdown = gr.Dropdown( + label="Initial Speaker", + choices=speaker_names, + value="none", + ) + + with gr.Group(): + start_train_btn = 
gr.Button("Start Training") + stop_train_btn = gr.Button("Stop Training") + clear_train_btn = gr.Button("Clear logs") + with gr.Column(scale=5): + with gr.Group(): + # log + gr.Markdown("📜logs") + log_output = gr.Textbox( + show_label=False, label="Log", value="", lines=20, interactive=True + ) + + start_train_btn.click( + spk_ft.run, + inputs=[ + batch_size_input, + epochs_input, + lr_input, + train_text_checkbox, + dataset_input, + init_spk_dropdown, + ], + outputs=[], + ) + stop_train_btn.click(spk_ft.stop) + clear_train_btn.click(spk_ft.clear) + + if webui_config.experimental: + demo.load(spk_ft.flush, every=1, outputs=[log_output]) diff --git a/modules/webui/localization_runtime.py b/modules/webui/localization_runtime.py index 9689c960e900ad1ec93c3e85e5c09d8bb5a54626..273eb05c66676f525fb484b6a6a64a1129091462 100644 --- a/modules/webui/localization_runtime.py +++ b/modules/webui/localization_runtime.py @@ -7,6 +7,7 @@ class LocalizationVars: self.ssml_examples = [] self.tts_examples = [] + self.podcast_default = [] class ZHLocalizationVars(LocalizationVars): @@ -167,6 +168,69 @@ class ZHLocalizationVars(LocalizationVars): }, ] + self.podcast_default = [ + [ + 1, + "female2", + "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。", + "podcast", + ], + [ + 2, + "Alice", + "嗨,我特别期待这个话题!中华料理真的是博大精深。", + "podcast", + ], + [ + 3, + "Bob", + "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。", + "podcast", + ], + [ + 4, + "female2", + "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。", + "podcast", + ], + [ + 5, + "Alice", + "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。", + "podcast", + ], + [ + 6, + "Bob", + "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。", + "podcast", + ], + [ + 7, + "female2", + "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。", + "podcast", + ], + [ + 8, + "Alice", + "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。", + "podcast", + ], + [ + 9, + "Bob", + "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。", + "podcast", + ], + [ + 10, + "female2", + "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。", + "podcast", + ], + ] + class ENLocalizationVars(LocalizationVars): def __init__(self): @@ -224,3 +288,65 @@ class ENLocalizationVars(LocalizationVars): "text": "Don't ever let somebody tell you you can't do something. Not even me. Alright? You got a dream, you gotta protect it. When people can't do something themselves, they're gonna tell you that you can't do it. You want something, go get it. Period.", }, ] + self.podcast_default = [ + [ + 1, + "female2", + "Hello, welcome to today's podcast. Today, we're going to talk about global cuisine.", + "podcast", + ], + [ + 2, + "Alice", + "Hi, I'm really excited about this topic! Global cuisine is incredibly diverse and fascinating.", + "podcast", + ], + [ + 3, + "Bob", + "Absolutely, every country has its own unique culinary traditions and specialties.", + "podcast", + ], + [ + 4, + "female2", + "Let's start with Italian cuisine. Italian food is loved worldwide, especially for its pasta and pizza.", + "podcast", + ], + [ + 5, + "Alice", + "Yes, I especially love a good Margherita pizza and a hearty plate of spaghetti carbonara. The flavors are simply amazing.", + "podcast", + ], + [ + 6, + "Bob", + "Besides Italian cuisine, Japanese cuisine is also very popular. 
Dishes like sushi and ramen have become global favorites.", + "podcast", + ], + [ + 7, + "female2", + "Exactly, Japanese cuisine is known for its emphasis on fresh ingredients and delicate presentation.", + "podcast", + ], + [ + 8, + "Alice", + "And then there's Mexican cuisine, with its bold flavors and colorful dishes like tacos and guacamole.", + "podcast", + ], + [ + 9, + "Bob", + "Not to mention, there's also Indian cuisine, Thai cuisine, French cuisine, and so many more, each with its own distinctive flavors and techniques.", + "podcast", + ], + [ + 10, + "female2", + "Yes, like Indian curry, Thai tom yum soup, and French croissants, these are all mouth-watering dishes that are loved by people all over the world.", + "podcast", + ], + ] diff --git a/modules/webui/ssml/podcast_tab.py b/modules/webui/ssml/podcast_tab.py index 440e0b60eb5f8b394a86cc6dac7266ec018fb6ec..32e732d34fb9c4b4e86cf88b5e549312c88174b4 100644 --- a/modules/webui/ssml/podcast_tab.py +++ b/modules/webui/ssml/podcast_tab.py @@ -3,72 +3,9 @@ import pandas as pd import torch from modules.normalization import text_normalize -from modules.webui import webui_utils +from modules.webui import webui_config, webui_utils from modules.utils.hf import spaces -podcast_default_case = [ - [ - 1, - "female2", - "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]", - "podcast", - ], - [ - 2, - "Alice", - "嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]", - "podcast", - ], - [ - 3, - "Bob", - "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]", - "podcast", - ], - [ - 4, - "female2", - "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]", - "podcast", - ], - [ - 5, - "Alice", - "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]", - "podcast", - ], - [ - 6, - "Bob", - "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]", - "podcast", - ], - [ - 7, - "female2", - "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]", - "podcast", - ], - [ - 8, - "Alice", - "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]", - "podcast", - ], - [ - 9, - "Bob", - "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]", - "podcast", - ], - [ - 10, - "female2", - "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]", - "podcast", - ], -] - # NOTE: 因为 text_normalize 需要使用 tokenizer @torch.inference_mode() @@ -133,7 +70,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta datatype=["number", "str", "str", "str"], interactive=True, wrap=True, - value=podcast_default_case, + value=webui_config.localization.podcast_default, row_count=(0, "dynamic"), col_count=(4, "fixed"), ) diff --git a/modules/webui/ssml/ssml_tab.py b/modules/webui/ssml/ssml_tab.py index f2de84c2e28bcd35e7c14611cc9ff3fdb57bdfc7..6fa6dd861daa3fd246d56aec4ded84ef068537d3 100644 --- a/modules/webui/ssml/ssml_tab.py +++ b/modules/webui/ssml/ssml_tab.py @@ -22,7 +22,6 @@ def create_ssml_interface(): ssml_button = gr.Button("🔊Synthesize SSML", variant="primary") with gr.Column(scale=1): with gr.Group(): - # 参数 gr.Markdown("🎛️Parameters") # batch size batch_size_input = gr.Slider( @@ -32,6 +31,19 @@ def create_ssml_interface(): maximum=webui_config.max_batch_size, step=1, ) + with gr.Group(): + gr.Markdown("🎛️Spliter") + eos_input = gr.Textbox( + label="eos", + value="[uv_break]", + ) + spliter_thr_input = gr.Slider( + label="Spliter Threshold", + value=100, + minimum=50, + maximum=1000, + step=1, + ) with gr.Group(): gr.Markdown("💪🏼Enhance") @@ -49,7 +61,14 @@ def create_ssml_interface(): ssml_button.click( synthesize_ssml, - inputs=[ssml_input, batch_size_input, enable_enhance, enable_de_noise], + inputs=[ + ssml_input, + batch_size_input, + enable_enhance, + 
enable_de_noise, + eos_input, + spliter_thr_input, + ], outputs=ssml_output, ) diff --git a/modules/webui/tts_tab.py b/modules/webui/tts_tab.py index c39e7ed284cca32eae897cfcfaf80559a1c3d49b..ab81e12243bbd83eccf0a0f026cb9650f86df690 100644 --- a/modules/webui/tts_tab.py +++ b/modules/webui/tts_tab.py @@ -29,32 +29,6 @@ def create_tts_interface(): with gr.Row(): with gr.Column(scale=1): - with gr.Group(): - gr.Markdown("🎛️Sampling") - temperature_input = gr.Slider( - 0.01, 2.0, value=0.3, step=0.01, label="Temperature" - ) - top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P") - top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K") - batch_size_input = gr.Slider( - 1, - webui_config.max_batch_size, - value=4, - step=1, - label="Batch Size", - ) - - with gr.Row(): - with gr.Group(): - gr.Markdown("🎭Style") - gr.Markdown("TTS_STYLE_GUIDE") - style_input_dropdown = gr.Dropdown( - choices=styles, - # label="Choose Style", - interactive=True, - show_label=False, - value="*auto", - ) with gr.Row(): with gr.Group(): gr.Markdown("🗣️Speaker") @@ -102,7 +76,47 @@ def create_tts_interface(): fn=load_spk_info, inputs=[spk_file_upload], outputs=[infos], - ), + ) + + with gr.Row(): + with gr.Group(): + gr.Markdown("🎭Style") + gr.Markdown("TTS_STYLE_GUIDE") + style_input_dropdown = gr.Dropdown( + choices=styles, + # label="Choose Style", + interactive=True, + show_label=False, + value="*auto", + ) + + with gr.Group(): + gr.Markdown("🎛️Sampling") + temperature_input = gr.Slider( + 0.01, 2.0, value=0.3, step=0.01, label="Temperature" + ) + top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P") + top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K") + batch_size_input = gr.Slider( + 1, + webui_config.max_batch_size, + value=4, + step=1, + label="Batch Size", + ) + with gr.Group(): + gr.Markdown("🎛️Spliter") + eos_input = gr.Textbox( + label="eos", + value="[uv_break]", + ) + spliter_thr_input = gr.Slider( + label="Spliter Threshold", + value=100, + minimum=50, + maximum=1000, + step=1, + ) with gr.Group(): gr.Markdown("💃Inference Seed") @@ -202,7 +216,8 @@ def create_tts_interface(): ) refine_button = gr.Button("✍️Refine Text") - with gr.Group(): + # 由于使用不是很方便,所以列为实验性功能 + with gr.Group(visible=webui_config.experimental): gr.Markdown("🔧Prompt engineering") prompt1_input = gr.Textbox(label="Prompt 1") prompt2_input = gr.Textbox(label="Prompt 2") @@ -253,6 +268,8 @@ def create_tts_interface(): enable_enhance, enable_de_noise, spk_file_upload, + spliter_thr_input, + eos_input, ], outputs=tts_output, ) diff --git a/modules/webui/webui_utils.py b/modules/webui/webui_utils.py index 4a6d6dedf17c6a37c16a3546afab089a69599d38..cf57c80c597d20508a7e70a3d544e90f9bafe92a 100644 --- a/modules/webui/webui_utils.py +++ b/modules/webui/webui_utils.py @@ -95,6 +95,8 @@ def synthesize_ssml( batch_size=4, enable_enhance=False, enable_denoise=False, + eos: str = "[uv_break]", + spliter_thr: int = 100, ): try: batch_size = int(batch_size) @@ -114,7 +116,9 @@ def synthesize_ssml( if len(segments) == 0: return None - synthesize = SynthesizeSegments(batch_size=batch_size) + synthesize = SynthesizeSegments( + batch_size=batch_size, eos=eos, spliter_thr=spliter_thr + ) audio_segments = synthesize.synthesize_segments(segments) combined_audio = combine_audio_segments(audio_segments) @@ -151,6 +155,8 @@ def tts_generate( enable_enhance=False, enable_denoise=False, spk_file=None, + spliter_thr: int = 100, + eos: str = "[uv_break]", ): try: batch_size = int(batch_size) @@ -199,6 +205,8 @@ 
def tts_generate( prompt2=prompt2, prefix=prefix, batch_size=batch_size, + end_of_sentence=eos, + spliter_threshold=spliter_thr, ) audio_data, sample_rate = apply_audio_enhance( diff --git a/webui.py b/webui.py index 055874043e85b90485fee64bc4982cd4d3373486..d2b7fdb075b9eb61ca9f0e37c63cb9e4c7469ea2 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,5 @@ import os +import sys import logging from modules.api.api_setup import ( @@ -106,6 +107,7 @@ def process_webui_args(args): auth=auth, show_api=False, prevent_thread_lock=True, + inbrowser=sys.platform == "win32", app_kwargs={ "title": app_title, "description": app_description,
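Aside (not part of the patch): a hedged sketch of how the new `eos` and `spliter_thr` controls added in these WebUI hunks reach the synthesis layer when the handler is called programmatically. The first positional argument and the return shape of `synthesize_ssml` are assumptions based on how the Gradio wiring above uses it.

```python
# Illustrative only: driving the SSML path with the new splitter controls.
from modules.webui.webui_utils import synthesize_ssml

ssml = '<speak version="0.1"><voice spk="female2">你好，欢迎收听今天的播客。</voice></speak>'
result = synthesize_ssml(
    ssml,
    batch_size=4,
    eos="[uv_break]",   # appended to each batched text segment before generation
    spliter_thr=120,    # sentence-splitter threshold forwarded to SynthesizeSegments
)
```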