# Author: Artrajz — commit 960cd20 ("init")
import logging
import numpy as np
import torch
from bert_vits2 import commons
from bert_vits2 import utils as bert_vits2_utils
from bert_vits2.clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from bert_vits2.get_emo import get_emo
from bert_vits2.models import SynthesizerTrn
from bert_vits2.models_v230 import SynthesizerTrn as SynthesizerTrn_v230
from bert_vits2.models_ja_extra import SynthesizerTrn as SynthesizerTrn_ja_extra
from bert_vits2.text import *
from bert_vits2.text.cleaner import clean_text
from bert_vits2.utils import process_legacy_versions
from contants import config
from utils import get_hparams_from_file
from utils.sentence import split_languages
class Bert_VITS2:
    """Wrapper around a Bert-VITS2 synthesizer.

    ``__init__`` reads the model config and derives version-specific settings
    (symbol table, tone count, flow depth, emotion-embedding mode, BERT model
    names and text-cleaner / BERT-feature suffixes). ``load_model`` builds the
    matching synthesizer class and loads the checkpoint. ``infer`` and
    ``infer_multilang`` turn text into a numpy audio array.
    """

    def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
        """Read the model config and resolve version-specific settings.

        Args:
            model_path: Checkpoint path, loaded later by ``load_model``.
            config: Path to a config file (parsed with ``get_hparams_from_file``)
                or an already-parsed hparams object.
            device: Torch device the synthesizer and its inputs are moved to.
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.model_path = model_path
        self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
        self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
        spk2id = getattr(self.hps_ms.data, 'spk2id', {'0': 0})
        # Speaker names ordered by their numeric id.
        self.speakers = [name for name, _ in sorted(spk2id.items(), key=lambda item: item[1])]
        self.symbols = symbols
        self.sampling_rate = self.hps_ms.data.sampling_rate
        self.bert_model_names = {}
        self.zh_bert_extra = False
        self.ja_bert_extra = False
        self.ja_bert_dim = 1024
        self.num_tones = num_tones
        self.pinyinPlus = None

        # Compatible with legacy versions. The normalized form is lower-case
        # with "-" replaced by "_" (e.g. "extra-fix" -> "extra_fix"), or None
        # when the config carries no version at all.
        self.version = self._normalize_version(process_legacy_versions(self.hps_ms))
        # Suffixes appended to the language code to select the version-specific
        # text cleaner (text_extra_str_map) and BERT feature extractor
        # (bert_extra_str_map).
        self.text_extra_str_map = {"zh": "", "ja": "", "en": ""}
        self.bert_extra_str_map = {"zh": "", "ja": "", "en": ""}
        # None: no emotion input; 1: features from _get_emo; 2: features from _get_clap.
        self.hps_ms.model.emotion_embedding = None

        if self.version in ["1.0", "1.0.0", "1.0.1"]:
            # BERT: chinese-roberta-wwm-ext-large
            self.version = "1.0"
            self.symbols = symbols_legacy
            self.hps_ms.model.n_layers_trans_flow = 3
            self.lang = getattr(self.hps_ms.data, "lang", ["zh"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            self.text_extra_str_map.update({"zh": "_v100"})
        elif self.version == "1.1.0_transition":
            # Matches "1.1.0-transition" after normalization ("-" -> "_"); the
            # previous hyphenated comparison could never match.
            # BERT: chinese-roberta-wwm-ext-large
            self.version = "1.1.0-transition"
            self.hps_ms.model.n_layers_trans_flow = 3
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            if "ja" in self.lang:
                self.bert_model_names.update({"ja": "BERT_BASE_JAPANESE_V3"})
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v111"})
            self.bert_extra_str_map.update({"ja": "_v111"})
        elif self.version in ["1.1", "1.1.0", "1.1.1"]:
            # BERT: chinese-roberta-wwm-ext-large, bert-base-japanese-v3
            self.version = "1.1"
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            if "ja" in self.lang:
                self.bert_model_names.update({"ja": "BERT_BASE_JAPANESE_V3"})
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v111"})
            self.bert_extra_str_map.update({"ja": "_v111"})
        elif self.version in ["2.0", "2.0.0", "2.0.1", "2.0.2"]:
            # BERT: chinese-roberta-wwm-ext-large, deberta-v2-large-japanese, deberta-v3-large
            self.version = "2.0"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.num_tones = num_tones
            if "ja" in self.lang:
                self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE"})
            if "en" in self.lang:
                self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"})
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v200", "en": "_v200"})
            self.bert_extra_str_map.update({"ja": "_v200", "en": "_v200"})
        elif self.version in ["2.1", "2.1.0"]:
            # BERT: chinese-roberta-wwm-ext-large, deberta-v2-large-japanese-char-wwm,
            # deberta-v3-large; emotion: wav2vec2-large-robust-12-ft-emotion-msp-dim
            self.version = "2.1"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.hps_ms.model.emotion_embedding = 1
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.num_tones = num_tones
            if "ja" in self.lang:
                self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
            if "en" in self.lang:
                self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"})
        elif self.version in ["2.2", "2.2.0"]:
            # BERT: chinese-roberta-wwm-ext-large, deberta-v2-large-japanese-char-wwm,
            # deberta-v3-large; emotion: clap-htsat-fused
            self.version = "2.2"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.hps_ms.model.emotion_embedding = 2
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.num_tones = num_tones
            if "ja" in self.lang:
                self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
            if "en" in self.lang:
                self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"})
        elif self.version in ["extra", "zh_clap"]:
            # BERT: Erlangshen-MegatronBert-1.3B-Chinese; emotion: clap-htsat-fused
            self.version = "extra"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["zh"]
            self.num_tones = num_tones
            self.zh_bert_extra = True
            self.bert_model_names.update({"zh": "Erlangshen_MegatronBert_1.3B_Chinese"})
            self.bert_extra_str_map.update({"zh": "_extra"})
        elif self.version in ["extra_fix", "2.4", "2.4.0"]:
            # BERT: Erlangshen-MegatronBert-1.3B-Chinese; emotion: clap-htsat-fused
            self.version = "2.4"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["zh"]
            self.num_tones = num_tones
            self.zh_bert_extra = True
            self.bert_model_names.update({"zh": "Erlangshen_MegatronBert_1.3B_Chinese"})
            self.bert_extra_str_map.update({"zh": "_extra"})
            self.text_extra_str_map.update({"zh": "_v240"})
        elif self.version in ["ja_extra"]:
            # BERT: deberta-v2-large-japanese-char-wwm
            self.version = "ja_extra"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["ja"]
            self.num_tones = num_tones
            self.ja_bert_extra = True
            self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
            self.bert_extra_str_map.update({"ja": "_extra"})
            self.text_extra_str_map.update({"ja": "_extra"})
        else:
            # Explicit v2.3 and unknown/missing versions share the v2.3 setup.
            # BERT: chinese-roberta-wwm-ext-large, deberta-v2-large-japanese-char-wwm, deberta-v3-large
            if self.version not in ["2.3", "2.3.0"]:
                logging.debug("Version information not found. Loaded as the newest version: v2.3.")
            self.version = "2.3"
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.num_tones = num_tones
            self.text_extra_str_map.update({"en": "_v230"})
            if "ja" in self.lang:
                self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
            if "en" in self.lang:
                self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"})

        # Chinese falls back to the default RoBERTa unless a branch chose another model.
        if "zh" in self.lang and "zh" not in self.bert_model_names:
            self.bert_model_names.update({"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE"})
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self.device = device

    @staticmethod
    def _normalize_version(raw_version):
        """Lower-case ``raw_version`` and replace "-" with "_"; None stays None."""
        if raw_version is None:
            return None
        return str(raw_version).lower().replace("-", "_")

    def load_model(self, model_handler):
        """Build the version-appropriate synthesizer and load the checkpoint.

        Args:
            model_handler: Provider of BERT/emotion/CLAP models; kept on
                ``self`` for later feature extraction.
        """
        self.model_handler = model_handler
        if self.version in ["2.3", "extra", "2.4"]:
            Synthesizer = SynthesizerTrn_v230
        elif self.version == "ja_extra":
            Synthesizer = SynthesizerTrn_ja_extra
        else:
            Synthesizer = SynthesizerTrn

        if self.version == "2.4":
            self.pinyinPlus = self.model_handler.get_pinyinPlus()

        self.net_g = Synthesizer(
            len(self.symbols),
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            n_speakers=self.hps_ms.data.n_speakers,
            symbols=self.symbols,
            ja_bert_dim=self.ja_bert_dim,
            num_tones=self.num_tones,
            zh_bert_extra=self.zh_bert_extra,
            **self.hps_ms.model).to(self.device)
        _ = self.net_g.eval()
        bert_vits2_utils.load_checkpoint(self.model_path, self.net_g, None, skip_optimizer=True, version=self.version)

    def get_speakers(self):
        """Return speaker names ordered by speaker id."""
        return self.speakers

    def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7):
        """Convert one single-language text segment into model inputs.

        Args:
            text: Text to synthesize.
            language_str: Language code ("zh", "ja", "en", ...).
            hps: Hparams; ``hps.data.add_blank`` controls blank interspersing.
            style_text: Optional style text for BERT features ("" means none;
                ignored for zh_bert_extra models).
            style_weight: Weight passed to the BERT feature extractor.

        Returns:
            ``(zh_bert, ja_bert, en_bert, phone, tone, language)``. The BERT
            stream for ``language_str`` carries real features; the others are
            zero tensors of shape ``(dim, len(phone))``, or None for
            zh/ja "extra" models. ``phone``/``tone``/``language`` are LongTensors.
        """
        clean_text_lang_str = language_str + self.text_extra_str_map.get(language_str, "")
        bert_feature_lang_str = language_str + self.bert_extra_str_map.get(language_str, "")
        bert_model_name = self.bert_model_names[language_str]

        tokenizer, _ = self.model_handler.get_bert_model(bert_model_name)
        norm_text, phone, tone, word2ph = clean_text(text, clean_text_lang_str, tokenizer, self.pinyinPlus)
        phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)

        if hps.data.add_blank:
            phone = commons.intersperse(phone, 0)
            tone = commons.intersperse(tone, 0)
            language = commons.intersperse(language, 0)
            # Interspersing blanks doubles each word's phone count, plus one
            # leading blank attributed to the first word.
            word2ph = [count * 2 for count in word2ph]
            word2ph[0] += 1

        if style_text == "" or self.zh_bert_extra:
            style_text = None

        bert = self.model_handler.get_bert_feature(norm_text, word2ph, bert_feature_lang_str,
                                                   bert_model_name, style_text, style_weight)
        n_phones = len(phone)
        assert bert.shape[-1] == n_phones, f"Bert seq len {bert.shape[-1]} != {n_phones}"

        if self.zh_bert_extra:
            zh_bert, ja_bert, en_bert = bert, None, None
        elif self.ja_bert_extra:
            zh_bert, ja_bert, en_bert = None, bert, None
        else:
            zh_bert = bert if language_str == "zh" else torch.zeros(1024, n_phones)
            ja_bert = bert if language_str == "ja" else torch.zeros(self.ja_bert_dim, n_phones)
            en_bert = bert if language_str == "en" else torch.zeros(1024, n_phones)

        phone = torch.LongTensor(phone)
        tone = torch.LongTensor(tone)
        language = torch.LongTensor(language)
        return zh_bert, ja_bert, en_bert, phone, tone, language

    @staticmethod
    def _has_reference_audio(reference_audio):
        """Return True when ``reference_audio`` holds usable audio.

        numpy arrays are checked by size, because ``bool(array)`` raises
        ValueError for arrays with more than one element.
        """
        if reference_audio is None:
            return False
        if isinstance(reference_audio, np.ndarray):
            return reference_audio.size > 0
        return bool(reference_audio)

    def _get_emo(self, reference_audio, emotion):
        """Return emotion features for emotion_embedding mode 1.

        Uses ``get_emo`` on ``reference_audio`` when provided, otherwise wraps
        ``emotion`` (default 0) in a one-element tensor.
        """
        if self._has_reference_audio(reference_audio):
            return torch.from_numpy(
                get_emo(reference_audio, self.model_handler.emotion_model,
                        self.model_handler.emotion_processor))
        return torch.Tensor([0 if emotion is None else emotion])

    def _get_clap(self, reference_audio, text_prompt):
        """Return CLAP features (emotion_embedding mode 2).

        Audio features are used when ``reference_audio`` is a numpy array.
        Otherwise text features are computed from ``text_prompt``, falling back
        to ``config.bert_vits2_config.text_prompt``.
        """
        if isinstance(reference_audio, np.ndarray):
            emo = get_clap_audio_feature(reference_audio, self.model_handler.clap_model,
                                         self.model_handler.clap_processor, self.device)
        else:
            if text_prompt is None:
                text_prompt = config.bert_vits2_config.text_prompt
            emo = get_clap_text_feature(text_prompt, self.model_handler.clap_model,
                                        self.model_handler.clap_processor, self.device)
        emo = torch.squeeze(emo, dim=1).unsqueeze(0)
        return emo

    def _get_emotion_feature(self, reference_audio, emotion, text_prompt):
        """Return the emotion input matching ``hps_ms.model.emotion_embedding`` (None if unused)."""
        mode = self.hps_ms.model.emotion_embedding
        if mode == 1:
            return self._get_emo(reference_audio, emotion).to(self.device).unsqueeze(0)
        if mode == 2:
            return self._get_clap(reference_audio, text_prompt)
        return None

    def _infer(self, id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise, noisew, length,
               emo=None):
        """Run the synthesizer on prepared inputs and return audio as a numpy float array.

        Sequence tensors get a leading batch dimension of 1; BERT tensors that
        are None (zh/ja "extra" models) are passed through unchanged.
        """
        with torch.no_grad():
            x_tst = phones.to(self.device).unsqueeze(0)
            tones = tones.to(self.device).unsqueeze(0)
            lang_ids = lang_ids.to(self.device).unsqueeze(0)
            if self.zh_bert_extra:
                zh_bert = zh_bert.to(self.device).unsqueeze(0)
            elif self.ja_bert_extra:
                ja_bert = ja_bert.to(self.device).unsqueeze(0)
            else:
                zh_bert = zh_bert.to(self.device).unsqueeze(0)
                ja_bert = ja_bert.to(self.device).unsqueeze(0)
                en_bert = en_bert.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
            speakers = torch.LongTensor([int(id)]).to(self.device)
            audio = self.net_g.infer(x_tst,
                                     x_tst_lengths,
                                     speakers,
                                     tones,
                                     lang_ids,
                                     zh_bert=zh_bert,
                                     ja_bert=ja_bert,
                                     en_bert=en_bert,
                                     sdp_ratio=sdp_ratio,
                                     noise_scale=noise,
                                     noise_scale_w=noisew,
                                     length_scale=length,
                                     emo=emo
                                     )[0][0, 0].data.cpu().float().numpy()

        torch.cuda.empty_cache()
        return audio

    def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
              text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
        """Synthesize ``text`` in a single language ``lang`` for speaker ``id``.

        ``style_weigth`` (sic) is kept for backward compatibility with callers.
        """
        zh_bert, ja_bert, en_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms, style_text,
                                                                            style_weigth)
        emo = self._get_emotion_feature(reference_audio, emotion, text_prompt)
        return self._infer(id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise, noisew, length,
                           emo)

    @staticmethod
    def _trim_segment(tensor, skip_start, skip_end):
        """Drop boundary tokens along the last axis; None passes through.

        Removes 3 leading and/or 2 trailing positions, the amounts used when
        stitching adjacent language segments (presumably the per-segment
        start/end padding plus interspersed blanks; TODO confirm against the
        cleaner).
        """
        if tensor is None:
            return None
        if skip_start:
            tensor = tensor[..., 3:]
        if skip_end:
            tensor = tensor[..., :-2]
        return tensor

    @staticmethod
    def _concat_optional(tensors, dim):
        """Concatenate ``tensors`` along ``dim``; return None when every entry is None."""
        if all(t is None for t in tensors):
            return None
        return torch.cat(tensors, dim=dim)

    def infer_multilang(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
                        text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
        """Synthesize mixed-language ``text`` by splitting it into per-language segments.

        Each segment goes through ``get_text``. Boundary tokens between
        adjacent segments are trimmed, and the results are concatenated into
        one sequence. The ``lang`` argument is accepted for interface
        compatibility; segment languages come from ``split_languages``.
        """
        sentences_list = split_languages(text, self.lang, expand_abbreviations=True, expand_hyphens=True)
        emo = self._get_emotion_feature(reference_audio, emotion, text_prompt)

        phones, tones, lang_ids, zh_bert, ja_bert, en_bert = [], [], [], [], [], []
        last_idx = len(sentences_list) - 1
        for idx, (segment_text, segment_lang) in enumerate(sentences_list):
            features = self.get_text(segment_text, segment_lang, self.hps_ms, style_text, style_weigth)
            skip_start = idx != 0
            skip_end = idx != last_idx
            seg_zh, seg_ja, seg_en, seg_phones, seg_tones, seg_lang_ids = (
                self._trim_segment(t, skip_start, skip_end) for t in features)
            phones.append(seg_phones)
            tones.append(seg_tones)
            lang_ids.append(seg_lang_ids)
            zh_bert.append(seg_zh)
            ja_bert.append(seg_ja)
            en_bert.append(seg_en)

        # BERT streams are None for the unused languages of zh/ja "extra" models.
        zh_bert = self._concat_optional(zh_bert, dim=1)
        ja_bert = self._concat_optional(ja_bert, dim=1)
        en_bert = self._concat_optional(en_bert, dim=1)
        phones = torch.cat(phones, dim=0)
        tones = torch.cat(tones, dim=0)
        lang_ids = torch.cat(lang_ids, dim=0)

        audio = self._infer(id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise,
                            noisew, length, emo)
        return audio