ChatTTS-Forge / modules /SynthesizeSegments.py
zhzluke96
update
da8d589
from box import Box
from pydub import AudioSegment
from typing import List, Union
from scipy.io.wavfile import write
import io
from modules.api.utils import calc_spk_style
from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
from modules.utils import rng
from modules.utils.audio import time_stretch, pitch_shift
from modules import generate_audio
from modules.normalization import text_normalize
import logging
import json
from modules.speaker import Speaker, speaker_mgr
logger = logging.getLogger(__name__)
def audio_data_to_segment(audio_data, sr):
    """Wrap raw sample data in a pydub ``AudioSegment``.

    The samples are encoded as a WAV file in an in-memory buffer with
    ``scipy.io.wavfile.write`` and then decoded again by pydub, so no
    temporary file touches the disk.

    Args:
        audio_data: Sample array accepted by ``scipy.io.wavfile.write``.
        sr: Sample rate passed to the WAV writer.

    Returns:
        The ``AudioSegment`` decoded from the in-memory WAV data.
    """
    wav_buffer = io.BytesIO()
    write(wav_buffer, sr, audio_data)
    # Rewind so the decoder starts reading at the WAV header.
    wav_buffer.seek(0)
    return AudioSegment.from_file(wav_buffer, format="wav")
def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
    """Concatenate *audio_segments* in order into a single ``AudioSegment``.

    An empty list yields ``AudioSegment.empty()``.
    """
    # Starting from an empty segment keeps the result an AudioSegment even
    # when there is nothing to join.
    return sum(audio_segments, AudioSegment.empty())
def apply_prosody(
    audio_segment: AudioSegment, rate: float, volume: float, pitch: float
) -> AudioSegment:
    """Apply rate, volume and pitch adjustments to *audio_segment*.

    Each adjustment is skipped when its value is the neutral one (``1`` for
    rate, ``0`` for volume and pitch). The adjustments run in a fixed order:
    rate, then volume, then pitch.

    Args:
        audio_segment: Audio to adjust.
        rate: Factor handed to ``time_stretch``; ``1`` means unchanged.
        volume: Amount added to the segment with ``+=``; pydub documents
            ``segment + number`` as a gain change in dB.
        pitch: Amount handed to ``pitch_shift``; ``0`` means unchanged.

    Returns:
        The adjusted audio (the input itself when every value is neutral).
    """
    shaped = audio_segment

    if rate != 1:
        shaped = time_stretch(shaped, rate)

    if volume != 0:
        shaped += volume

    if pitch != 0:
        shaped = pitch_shift(shaped, pitch)

    return shaped
def to_number(value, t, default=0):
    """Convert *value* with the numeric constructor *t*, or fall back.

    Args:
        value: Anything *t* may accept (typically a string attribute or
            ``None``).
        t: Callable performing the conversion, e.g. ``int`` or ``float``.
        default: Returned unchanged when the conversion fails.

    Returns:
        ``t(value)`` on success, otherwise *default*.
    """
    try:
        return t(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers e.g. int(float("inf")), which is just another
        # "not convertible" case for callers relying on the default.
        return default
class TTSAudioSegment(Box):
    """Generation parameters for one segment, stored as a ``Box``.

    Construction is inherited unchanged from ``Box``; the annotations below
    only name the keys this module reads and writes.
    """

    # Text to synthesize.
    text: str
    # Generation settings forwarded to the audio generator.
    temperature: float
    top_P: float
    top_K: int
    # Speaker selector. Annotated as int, but this module also handles
    # ``Speaker`` instances stored under this key.
    spk: int
    infer_seed: int
    prompt1: str
    prompt2: str
    prefix: str
    # Segment kind; this module sets it to "break" or "voice".
    _type: str
class SynthesizeSegments:
    """Render a list of SSML segments to audio, batching compatible segments.

    Voice segments whose generation parameters (everything except the text)
    are identical are grouped into a bucket and generated ``batch_size`` at a
    time. Break segments are rendered as silence. Results are written back
    into the position each segment had in the input list.
    """

    def __init__(self, batch_size: int = 8):
        """Store the batch size and draw the per-instance default seeds.

        Args:
            batch_size: Maximum number of texts sent to the generator at once.
        """
        self.batch_size = batch_size
        # Drawn once so that every segment lacking an explicit speaker / seed
        # shares the same fallback values for the lifetime of this instance.
        self.batch_default_spk_seed = rng.np_rng()
        self.batch_default_infer_seed = rng.np_rng()

    @staticmethod
    def _pending_slots(src_segments):
        """Map ``id(segment)`` to its not-yet-claimed positions in *src_segments*.

        Each position list is stored in descending order so that ``pop()``
        hands out the smallest remaining position first.
        """
        slots = {}
        for index, item in enumerate(src_segments):
            slots.setdefault(id(item), []).append(index)
        for positions in slots.values():
            positions.reverse()
        return slots

    @staticmethod
    def _claim_slot(slots, src_segments, segment):
        """Return the output position for *segment*, consuming it from *slots*.

        Lookup is by object identity, so segments that merely compare equal
        (same text and attributes) still receive distinct positions. When no
        identity match is left, fall back to the equality search
        ``src_segments.index(segment)``, which raises ``ValueError`` if the
        segment is not present at all.
        """
        pending = slots.get(id(segment))
        if pending:
            return pending.pop()
        return src_segments.index(segment)

    def segment_to_generate_params(
        self, segment: Union[SSMLSegment, SSMLBreak]
    ) -> TTSAudioSegment:
        """Build the generation parameters for a single segment.

        Args:
            segment: A voice segment or a break.

        Returns:
            A ``TTSAudioSegment`` with ``_type="break"`` for breaks, the
            segment's own ``params`` when it carries them, and otherwise
            parameters derived from the segment attributes with the defaults
            temperature 0.3, top_P 0.5 and top_K 20.
        """
        if isinstance(segment, SSMLBreak):
            return TTSAudioSegment(_type="break")

        explicit_params = segment.get("params", None)
        if explicit_params is not None:
            return TTSAudioSegment(**explicit_params)

        text = str(segment.get("text", "")).strip()
        is_end = segment.get("is_end", False)

        attrs = segment.attrs
        spk = attrs.spk
        style = attrs.style

        # Values resolved from the speaker/style pair act as fallbacks for
        # anything the segment attributes do not set themselves.
        ss_params = calc_spk_style(spk, style)
        if "spk" in ss_params:
            spk = ss_params["spk"]

        seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
        top_k = to_number(attrs.top_k, int, None)
        top_p = to_number(attrs.top_p, float, None)
        temp = to_number(attrs.temp, float, None)

        prompt1 = attrs.prompt1 or ss_params.get("prompt1")
        prompt2 = attrs.prompt2 or ss_params.get("prompt2")
        prefix = attrs.prefix or ss_params.get("prefix")
        # Normalization is on unless the attribute is the literal string "False".
        disable_normalize = attrs.get("normalize", "") == "False"

        # NOTE(review): a spk or seed of 0 is falsy and is therefore treated
        # as "unset" (-1) here; confirm that 0 is never a meaningful value.
        seg = TTSAudioSegment(
            _type="voice",
            text=text,
            temperature=temp if temp is not None else 0.3,
            top_P=top_p if top_p is not None else 0.5,
            top_K=top_k if top_k is not None else 20,
            spk=spk if spk else -1,
            infer_seed=seed if seed else -1,
            prompt1=prompt1 if prompt1 else "",
            prompt2=prompt2 if prompt2 else "",
            prefix=prefix if prefix else "",
        )

        if not disable_normalize:
            seg.text = text_normalize(text, is_end=is_end)

        # NOTE: the per-instance default seeds keep results consistent from
        # one segment to the next even when no spk / seed was specified.
        if seg.spk == -1:
            seg.spk = self.batch_default_spk_seed
        if seg.infer_seed == -1:
            seg.infer_seed = self.batch_default_infer_seed

        return seg

    def process_break_segments(
        self,
        src_segments: List[SSMLBreak],
        bucket_segments: List[SSMLBreak],
        audio_segments: List[AudioSegment],
    ):
        """Fill *audio_segments* with silence for every break in the bucket.

        Args:
            src_segments: The full, ordered input list (used to find positions).
            bucket_segments: The break segments to render.
            audio_segments: Output list, modified in place; same length and
                order as *src_segments*.
        """
        slots = self._pending_slots(src_segments)
        for segment in bucket_segments:
            index = self._claim_slot(slots, src_segments, segment)
            audio_segments[index] = AudioSegment.silent(
                duration=int(segment.attrs.duration)
            )

    def process_voice_segments(
        self,
        src_segments: List[SSMLSegment],
        bucket: List[SSMLSegment],
        audio_segments: List[AudioSegment],
    ):
        """Generate audio for one bucket of voice segments, batch by batch.

        Args:
            src_segments: The full, ordered input list (used to find positions).
            bucket: Voice segments that share the same generation parameters.
            audio_segments: Output list, modified in place; same length and
                order as *src_segments*.
        """
        slots = self._pending_slots(src_segments)

        for start in range(0, len(bucket), self.batch_size):
            batch = bucket[start : start + self.batch_size]
            param_arr = [self.segment_to_generate_params(segment) for segment in batch]
            texts = [params.text for params in param_arr]

            # Only the text differs inside a bucket, so the first entry's
            # parameters stand in for the whole batch.
            params = param_arr[0]
            audio_datas = generate_audio.generate_audio_batch(
                texts=texts,
                temperature=params.temperature,
                top_P=params.top_P,
                top_K=params.top_K,
                spk=params.spk,
                infer_seed=params.infer_seed,
                prompt1=params.prompt1,
                prompt2=params.prompt2,
                prefix=params.prefix,
            )

            for idx, segment in enumerate(batch):
                # Each generator result is a (sample_rate, samples) pair.
                sr, audio_data = audio_datas[idx]
                rate = float(segment.get("rate", "1.0"))
                volume = float(segment.get("volume", "0"))
                pitch = float(segment.get("pitch", "0"))

                audio_segment = audio_data_to_segment(audio_data, sr)
                audio_segment = apply_prosody(audio_segment, rate, volume, pitch)

                original_index = self._claim_slot(slots, src_segments, segment)
                audio_segments[original_index] = audio_segment

    def bucket_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> dict[str, List[Union[SSMLSegment, SSMLBreak]]]:
        """Group segments by their generation parameters.

        Args:
            segments: The ordered input segments.

        Returns:
            A dict of segment lists. Breaks are collected under the key
            ``"<break>"`` (always present); voice segments are keyed by the
            JSON dump of their parameters with ``text`` left out.
        """
        buckets = {"<break>": []}
        for segment in segments:
            if isinstance(segment, SSMLBreak):
                buckets["<break>"].append(segment)
                continue

            params = self.segment_to_generate_params(segment)
            # A Speaker object is replaced by its id string before dumping.
            if isinstance(params.spk, Speaker):
                params.spk = str(params.spk.id)

            key = json.dumps(
                {k: v for k, v in params.items() if k != "text"}, sort_keys=True
            )
            buckets.setdefault(key, []).append(segment)

        return buckets

    def synthesize_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> List[AudioSegment]:
        """Render every segment and return the audio in input order.

        Args:
            segments: Ordered voice and break segments.

        Returns:
            One ``AudioSegment`` per input segment, at the same position.
        """
        audio_segments = [None] * len(segments)
        buckets = self.bucket_segments(segments)

        break_segments = buckets.pop("<break>")
        self.process_break_segments(segments, break_segments, audio_segments)

        for bucket in buckets.values():
            self.process_voice_segments(segments, bucket, audio_segments)

        return audio_segments
# Example usage
if __name__ == "__main__":
    # Two contexts that differ only in the speaker.
    ctx1 = SSMLContext()
    ctx1.spk, ctx1.seed, ctx1.temp = 1, 42, 0.1

    ctx2 = SSMLContext()
    ctx2.spk, ctx2.seed, ctx2.temp = 2, 42, 0.1

    demo_segments = [
        SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLBreak(duration_ms=1000),
        SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()),
    ]

    rendered = SynthesizeSegments(batch_size=2).synthesize_segments(demo_segments)
    print(rendered)

    # Join the per-segment audio and write it out as one WAV file.
    combine_audio_segments(rendered).export("output.wav", format="wav")