import logging
import math
import os.path
import re
from time import time as ttime
from typing import List

import librosa
import numpy as np
import torch

from contants import config
from gpt_sovits.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from gpt_sovits.module.mel_processing import spectrogram_torch
from gpt_sovits.module.models import SynthesizerTrn
from gpt_sovits.utils import DictToAttrRecursive
from gpt_sovits.text import cleaned_text_to_sequence
from gpt_sovits.text.cleaner import clean_text
from utils.classify_language import classify_language
from utils.data_utils import check_is_none
from utils.sentence import split_languages, sentence_split

splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}
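
# GPT_SoVITS wraps the two-stage GPT-SoVITS pipeline: the GPT text-to-semantic model
# (Text2SemanticLightningModule) predicts semantic tokens from phonemes, BERT features and a
# reference prompt, and the SoVITS synthesizer (SynthesizerTrn) decodes those tokens into a
# waveform in the voice of the reference audio.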
class GPT_SoVITS:
    def __init__(self, sovits_path, gpt_path, device, **kwargs):
        self.sovits_path = sovits_path
        self.gpt_path = gpt_path
        self.hz = config.gpt_sovits_config.hz
        self.sampling_rate = None
        self.device = device
        self.model_handler = None
        self.is_half = config.gpt_sovits_config.is_half
        self.np_dtype = np.float16 if self.is_half else np.float32
        self.torch_dtype = torch.float16 if self.is_half else torch.float32
        self.speakers = None
        self.lang = ["zh", "ja", "en"]
        self.flash_attn_enabled = True
        self.prompt_cache: dict = {
            "ref_audio_path": None,
            "prompt_semantic": None,
            "refer_spepc": None,
            "prompt_text": None,
            "prompt_lang": None,
            "phones": None,
            "bert_features": None,
            "norm_text": None,
        }

    def load_model(self, model_handler):
        self.model_handler = model_handler
        self.load_sovits(self.sovits_path)
        self.load_gpt(self.gpt_path)
        self.tokenizer, self.bert_model = self.model_handler.get_bert_model("CHINESE_ROBERTA_WWM_EXT_LARGE")
        self.ssl_model = self.model_handler.get_ssl_model()

    def load_weight(self, saved_state_dict, model):
        if hasattr(model, 'module'):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():
            try:
                new_state_dict[k] = saved_state_dict[k]
            except KeyError:
                # logging.info(f"{k} is not in the checkpoint")
                new_state_dict[k] = v
        if hasattr(model, 'module'):
            model.module.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(new_state_dict)

    def load_sovits(self, sovits_path):
        # self.n_semantic = 1024
        logging.info(f"Loaded checkpoint '{sovits_path}'")
        dict_s2 = torch.load(sovits_path, map_location=self.device)
        self.hps = dict_s2["config"]
        self.hps = DictToAttrRecursive(self.hps)
        self.hps.model.semantic_frame_rate = "25hz"
        # self.speakers = [self.hps.get("name")]  # take the speaker name from the model config
        self.speakers = [os.path.basename(os.path.dirname(self.sovits_path))]  # use the model folder name as the speaker name
        self.vq_model = SynthesizerTrn(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model).to(self.device)
        if config.gpt_sovits_config.is_half:
            self.vq_model = self.vq_model.half()
        self.vq_model.eval()
        self.sampling_rate = self.hps.data.sampling_rate
        self.load_weight(dict_s2['weight'], self.vq_model)

    def load_gpt(self, gpt_path):
        logging.info(f"Loaded checkpoint '{gpt_path}'")
        dict_s1 = torch.load(gpt_path, map_location=self.device)
        self.gpt_config = dict_s1["config"]
        self.max_sec = self.gpt_config.get("data").get("max_sec")
        self.t2s_model = Text2SemanticLightningModule(self.gpt_config, "****", is_train=False,
                                                      flash_attn_enabled=self.flash_attn_enabled).to(self.device)
        self.load_weight(dict_s1['weight'], self.t2s_model)
        if config.gpt_sovits_config.is_half:
            self.t2s_model = self.t2s_model.half()
        self.t2s_model.eval()
        total = sum(param.nelement() for param in self.t2s_model.parameters())
        logging.info(f"Number of parameters: {total / 1e6:.2f}M")

    def get_speakers(self):
        return self.speakers

    def get_cleaned_text(self, text, language):
        phones, word2ph, norm_text = clean_text(text, language.replace("all_", ""))
        phones = cleaned_text_to_sequence(phones)
        return phones, word2ph, norm_text

    def get_cleaned_text_multilang(self, text):
        sentences = split_languages(text, expand_abbreviations=True, expand_hyphens=True)
        phones, word2ph, norm_text = [], [], []
        for sentence, lang in sentences:
            lang = classify_language(sentence)
            _phones, _word2ph, _norm_text = self.get_cleaned_text(sentence, lang)
            phones.extend(_phones)
            word2ph.extend(_word2ph)
            norm_text.extend(_norm_text)
        return phones, word2ph, norm_text
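
    # BERT features are only computed for Chinese text: the per-character hidden states of
    # CHINESE_ROBERTA_WWM_EXT_LARGE (third-to-last layer, CLS/SEP stripped) are repeated per
    # phone according to word2ph, giving a (1024, n_phones) matrix; other languages get zeros.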

    def get_bert_feature(self, text, phones, word2ph, language):
        if language == "zh":
            with torch.no_grad():
                inputs = self.tokenizer(text, return_tensors="pt")
                for i in inputs:
                    inputs[i] = inputs[i].to(self.device)  # token ids are int64, so precision simply follows bert_model
                res = self.bert_model(**inputs, output_hidden_states=True)
                res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
            assert len(word2ph) == len(text)
            phone_level_feature = []
            for i in range(len(word2ph)):
                repeat_feature = res[i].repeat(word2ph[i], 1)
                phone_level_feature.append(repeat_feature)
            phone_level_feature = torch.cat(phone_level_feature, dim=0)
            # if config.gpt_sovits_config.is_half: phone_level_feature = phone_level_feature.half()
            bert = phone_level_feature.T
            torch.cuda.empty_cache()
        else:
            bert = torch.zeros((1024, len(phones)), dtype=self.torch_dtype)
        return bert

    def get_bert_and_cleaned_text_multilang(self, text: str):
        sentences = split_languages(text, expand_abbreviations=True, expand_hyphens=True)
        phones, word2ph, norm_text, bert = [], [], [], []
        for sentence, lang in sentences:
            _phones, _word2ph, _norm_text = self.get_cleaned_text(sentence, lang)
            _bert = self.get_bert_feature(_norm_text, _phones, _word2ph, lang)
            phones.extend(_phones)
            if _word2ph is not None:
                word2ph.extend(_word2ph)
            norm_text.extend(_norm_text)
            bert.append(_bert)
        bert = torch.cat(bert, dim=1).to(self.device, dtype=self.torch_dtype)
        return phones, word2ph, norm_text, bert
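
    # The reference spectrogram passed to the SoVITS decoder is a plain linear STFT
    # spectrogram computed at the model's own sampling rate.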

    def get_spepc(self, audio, orig_sr):
        """Resample the audio to the model's sampling rate and return its linear spectrogram."""
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=int(self.hps.data.sampling_rate))
        audio = torch.FloatTensor(audio)
        audio_norm = audio.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm,
            self.hps.data.filter_length,
            self.hps.data.sampling_rate,
            self.hps.data.hop_length,
            self.hps.data.win_length,
            center=False,
        )
        return spec
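
    # The prompt semantics are obtained by resampling the reference audio to 16 kHz, padding it
    # with a short stretch of silence, running it through the SSL speech model and quantizing
    # the hidden states with vq_model.extract_latent(); the reference must be 3-10 s long.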

    def _set_prompt_semantic(self, reference_audio, reference_audio_sr):
        zero_wav = np.zeros(
            int(self.sampling_rate * 0.3),
            dtype=np.float16 if self.is_half else np.float32,
        )
        wav16k = librosa.resample(reference_audio, orig_sr=reference_audio_sr, target_sr=16000)
        with torch.no_grad():
            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
                raise OSError("Reference audio must be between 3 and 10 seconds long, please use a different clip!")
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            if self.is_half:
                wav16k = wav16k.half()
                zero_wav_torch = zero_wav_torch.half()
            wav16k = wav16k.to(self.device)
            zero_wav_torch = zero_wav_torch.to(self.device)
            wav16k = torch.cat([wav16k, zero_wav_torch]).unsqueeze(0)
            ssl_content = self.ssl_model.model(wav16k)["last_hidden_state"].transpose(1, 2)  # .float()
            codes = self.vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0].to(self.device)
            # prompt_semantic = prompt_semantic.unsqueeze(0).to(self.device)
            self.prompt_cache["prompt_semantic"] = prompt_semantic
        torch.cuda.empty_cache()

    def get_first(self, text):
        pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
        text = re.split(pattern, text)[0].strip()
        return text

    def preprocess_text(self, text: str, lang: str, segment_size: int):
        texts = sentence_split(text, segment_size)
        result = []
        for text in texts:
            phones, word2ph, norm_text, bert_features = self.get_bert_and_cleaned_text_multilang(text)
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result
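
    # The text-derived prompt features (phones, BERT features, normalized text) are cached and
    # only recomputed when prompt_text changes; the reference spectrogram and prompt semantics
    # are rebuilt from the reference audio on every call.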

    def preprocess_prompt(self, reference_audio, reference_audio_sr, prompt_text: str, prompt_lang: str):
        if self.prompt_cache.get("prompt_text") != prompt_text:
            if prompt_lang.lower() == "auto":
                prompt_lang = classify_language(prompt_text)
            if prompt_text[-1] not in splits:
                prompt_text += "。" if prompt_lang != "en" else "."
            phones, word2ph, norm_text = self.get_cleaned_text(prompt_text, prompt_lang)
            bert_features = self.get_bert_feature(norm_text, phones, word2ph, prompt_lang).to(self.device,
                                                                                              dtype=self.torch_dtype)
            self.prompt_cache["prompt_text"] = prompt_text
            self.prompt_cache["prompt_lang"] = prompt_lang
            self.prompt_cache["phones"] = phones
            self.prompt_cache["bert_features"] = bert_features
            self.prompt_cache["norm_text"] = norm_text
        self.prompt_cache["refer_spepc"] = self.get_spepc(reference_audio, orig_sr=reference_audio_sr)
        self._set_prompt_semantic(reference_audio, reference_audio_sr)
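
    # batch_sequences pads each tensor along `axis` with pad_value up to a common length and
    # stacks the results into a single batch tensor.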

    def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0,
                        max_length: int = None):
        seq = sequences[0]
        ndim = seq.dim()
        if axis < 0:
            axis += ndim
        dtype: torch.dtype = seq.dtype
        pad_value = torch.tensor(pad_value, dtype=dtype)
        seq_lengths = [seq.shape[axis] for seq in sequences]
        if max_length is None:
            max_length = max(seq_lengths)
        else:
            max_length = max(max_length, max(seq_lengths))
        padded_sequences = []
        for seq, length in zip(sequences, seq_lengths):
            padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
            padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
            padded_sequences.append(padded_seq)
        batch = torch.stack(padded_sequences)
        return batch
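
    # to_batch groups the preprocessed segments into batches. With split_bucket=True the
    # segments are sorted by normalized-text length and greedily bucketed so that the ratio of
    # a bucket's median length to its mean length stays above `threshold`, which keeps padding
    # waste low; batch_index_list records the original positions so order can be restored later.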

    def to_batch(self, data: list, prompt_data: dict = None, batch_size: int = 5, threshold: float = 0.75,
                 split_bucket: bool = True):
        _data: list = []
        index_and_len_list = []
        for idx, item in enumerate(data):
            norm_text_len = len(item["norm_text"])
            index_and_len_list.append([idx, norm_text_len])

        batch_index_list = []
        if split_bucket:
            index_and_len_list.sort(key=lambda x: x[1])
            index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
            batch_index_list_len = 0
            pos = 0
            while pos < index_and_len_list.shape[0]:
                # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
                pos_end = min(pos + batch_size, index_and_len_list.shape[0])
                while pos < pos_end:
                    batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
                    score = batch[(pos_end - pos) // 2] / batch.mean()
                    if (score >= threshold) or (pos_end - pos == 1):
                        batch_index = index_and_len_list[pos:pos_end, 0].tolist()
                        batch_index_list_len += len(batch_index)
                        batch_index_list.append(batch_index)
                        pos = pos_end
                        break
                    pos_end = pos_end - 1
            assert batch_index_list_len == len(data)
        else:
            for i in range(len(data)):
                if i % batch_size == 0:
                    batch_index_list.append([])
                batch_index_list[-1].append(i)

        for batch_idx, index_list in enumerate(batch_index_list):
            item_list = [data[idx] for idx in index_list]
            phones_list = []
            phones_len_list = []
            # bert_features_list = []
            all_phones_list = []
            all_phones_len_list = []
            all_bert_features_list = []
            norm_text_batch = []
            bert_max_len = 0
            phones_max_len = 0
            for item in item_list:
                if prompt_data is not None:
                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
                    all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"])
                    phones = torch.LongTensor(item["phones"])
                    # norm_text = prompt_data["norm_text"] + item["norm_text"]
                else:
                    all_bert_features = item["bert_features"]
                    phones = torch.LongTensor(item["phones"])
                    all_phones = phones
                    # norm_text = item["norm_text"]

                bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
                phones_max_len = max(phones_max_len, phones.shape[-1])

                phones_list.append(phones)
                phones_len_list.append(phones.shape[-1])
                all_phones_list.append(all_phones)
                all_phones_len_list.append(all_phones.shape[-1])
                all_bert_features_list.append(all_bert_features)
                norm_text_batch.append(item["norm_text"])

            phones_batch = phones_list
            max_len = max(bert_max_len, phones_max_len)
            # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
            all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
            all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
            all_bert_features_batch.zero_()
            for idx, item in enumerate(all_bert_features_list):
                if item is not None:
                    all_bert_features_batch[idx, :, : item.shape[-1]] = item

            batch = {
                "phones": phones_batch,
                "phones_len": torch.LongTensor(phones_len_list),
                "all_phones": all_phones_batch,
                "all_phones_len": torch.LongTensor(all_phones_len_list),
                "all_bert_features": all_bert_features_batch,
                "norm_text": norm_text_batch
            }
            _data.append(batch)

        return _data, batch_index_list
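
    # Example: with batch_index_list = [[2, 0], [1]], recovery_order places data[0][0] at
    # position 2, data[0][1] at position 0 and data[1][0] at position 1.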

    def recovery_order(self, data: list, batch_index_list: list) -> list:
        '''
        Recover the original order of the audio fragments according to batch_index_list.

        Args:
            data (List[List[np.ndarray]]): the out-of-order audio fragments.
            batch_index_list (List[List[int]]): the batch index list.

        Returns:
            List[np.ndarray]: the data in the original order.
        '''
        length = len(sum(batch_index_list, []))
        _data = [None] * length
        for i, index_list in enumerate(batch_index_list):
            for j, index in enumerate(index_list):
                _data[index] = data[i][j]
        return _data
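
    # audio_postprocess rescales any fragment that clips, appends a short silence after each
    # fragment, restores the original segment order when bucketing was used, concatenates
    # everything into one waveform and optionally applies time-stretching.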

    def audio_postprocess(self, audio: List[torch.Tensor], sr: int, batch_index_list: list = None,
                          speed_factor: float = 1.0, split_bucket: bool = True) -> np.ndarray:
        zero_wav = torch.zeros(
            int(self.sampling_rate * 0.3),
            dtype=torch.float16 if self.is_half else torch.float32,
            device=self.device
        )
        for i, batch in enumerate(audio):
            for j, audio_fragment in enumerate(batch):
                max_audio = torch.abs(audio_fragment).max()  # simple guard against clipping when converting to 16-bit
                if max_audio > 1:
                    audio_fragment /= max_audio
                audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
                audio[i][j] = audio_fragment.cpu().numpy()

        if split_bucket:
            audio = self.recovery_order(audio, batch_index_list)
        else:
            # audio = [item for batch in audio for item in batch]
            audio = sum(audio, [])

        audio = np.concatenate(audio, 0)

        try:
            if speed_factor != 1.0:
                audio = self.speed_change(audio, speed_factor=speed_factor, sr=int(sr))
        except Exception as e:
            logging.error(f"Failed to change speed of audio: \n{e}")

        return audio

    def speed_change(self, input_audio: np.ndarray, speed_factor: float, sr: int):
        # time-stretch without changing pitch
        processed_audio = librosa.effects.time_stretch(input_audio, rate=speed_factor)
        return processed_audio
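
    # infer() is a generator: it cleans and batches the text, optionally conditions on the
    # cached reference prompt, has the GPT model predict semantic tokens (infer_panel) and
    # decodes them with the SoVITS vq_model. With return_fragment=True each decoded batch is
    # yielded as soon as it is ready; otherwise a single concatenated waveform is yielded.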

    def infer(self, text, lang, reference_audio, reference_audio_sr, prompt_text, prompt_lang, top_k, top_p,
              temperature, batch_size: int = 5, batch_threshold: float = 0.75, split_bucket: bool = True,
              return_fragment: bool = False, speed_factor: float = 1.0,
              segment_size: int = config.gpt_sovits_config.segment_size, **kwargs):
        if return_fragment:
            split_bucket = False

        data = self.preprocess_text(text, lang, segment_size)

        no_prompt_text = False
        if check_is_none(prompt_text):
            no_prompt_text = True
        else:
            self.preprocess_prompt(reference_audio, reference_audio_sr, prompt_text, prompt_lang)

        data, batch_index_list = self.to_batch(data,
                                               prompt_data=self.prompt_cache if not no_prompt_text else None,
                                               batch_size=batch_size,
                                               threshold=batch_threshold,
                                               split_bucket=split_bucket)

        audio = []
        for item in data:
            batch_phones = item["phones"]
            batch_phones_len = item["phones_len"]
            all_phoneme_ids = item["all_phones"]
            all_phoneme_lens = item["all_phones_len"]
            all_bert_features = item["all_bert_features"]
            norm_text = item["norm_text"]

            # batch_phones = batch_phones.to(self.device)
            batch_phones_len = batch_phones_len.to(self.device)
            all_phoneme_ids = all_phoneme_ids.to(self.device)
            all_phoneme_lens = all_phoneme_lens.to(self.device)
            all_bert_features = all_bert_features.to(self.device)
            if self.is_half:
                all_bert_features = all_bert_features.half()

            logging.debug(f"Infer text: {[''.join(text) for text in norm_text]}")

            if no_prompt_text:
                prompt = None
            else:
                prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.device)

            with torch.no_grad():
                pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                    all_phoneme_ids,
                    all_phoneme_lens,
                    prompt,
                    all_bert_features,
                    # prompt_phone_len=ph_offset,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    early_stop_num=self.hz * self.max_sec,
                )

                refer_audio_spepc: torch.Tensor = self.prompt_cache["refer_spepc"].to(self.device)
                if self.is_half:
                    refer_audio_spepc = refer_audio_spepc.half()

                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                upsample_rate = math.prod(self.vq_model.upsample_rates)
                audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate
                                  for i in range(0, len(pred_semantic_list))]
                audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
                all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.device)
                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.device)
                _batch_audio_fragment = (self.vq_model.decode(
                    all_pred_semantic, _batch_phones, refer_audio_spepc
                ).detach()[0, 0, :])
                audio_frag_end_idx.insert(0, 0)
                batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]]
                                        for i in range(1, len(audio_frag_end_idx))]

            torch.cuda.empty_cache()

            if return_fragment:
                yield self.audio_postprocess([batch_audio_fragment],
                                             reference_audio_sr,
                                             batch_index_list,
                                             speed_factor,
                                             split_bucket)
            else:
                audio.append(batch_audio_fragment)

        if not return_fragment:
            yield self.audio_postprocess(audio,
                                         reference_audio_sr,
                                         batch_index_list,
                                         speed_factor,
                                         split_bucket)
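
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The checkpoint paths and the reference audio are placeholders, and `model_handler`
# must be supplied by the surrounding application with the get_bert_model() /
# get_ssl_model() interface used by load_model() above.
#
#   tts = GPT_SoVITS("models/speaker/sovits.pth", "models/speaker/gpt.ckpt", device="cuda")
#   tts.load_model(model_handler)
#   ref_audio, ref_sr = librosa.load("reference.wav", sr=None)
#   for audio in tts.infer(text="你好,世界。", lang="auto",
#                          reference_audio=ref_audio, reference_audio_sr=ref_sr,
#                          prompt_text="参考音频对应的文本。", prompt_lang="zh",
#                          top_k=5, top_p=1.0, temperature=1.0):
#       pass  # `audio` is a float waveform (np.ndarray) at tts.sampling_rate
# ---------------------------------------------------------------------------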