"""Shared models used across the project (BERT variants, emotion, CLAP, G2PW, SSL)."""
import gc
import logging
import os

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, MegatronBertModel

from contants import config
from utils.download import download_file
from bert_vits2.text.chinese_bert import get_bert_feature as zh_bert
from bert_vits2.text.english_bert_mock import get_bert_feature as en_bert
from bert_vits2.text.japanese_bert import get_bert_feature as ja_bert
from bert_vits2.text.japanese_bert_v111 import get_bert_feature as ja_bert_v111
from bert_vits2.text.japanese_bert_v200 import get_bert_feature as ja_bert_v200
from bert_vits2.text.english_bert_mock_v200 import get_bert_feature as en_bert_v200
from bert_vits2.text.chinese_bert_extra import get_bert_feature as zh_bert_extra
from bert_vits2.text.japanese_bert_extra import get_bert_feature as ja_bert_extra


class ModelHandler:
    """Load, share and release the auxiliary models with simple reference counting.

    Every ``load_*`` call either loads the model (reference count 1) or bumps
    the reference count of an already loaded model; the matching ``release_*``
    call decrements it and frees the model once the count drops to zero.
    When loading from the local path fails, the weights are downloaded via
    ``download_file`` and loading is attempted again, up to ``max_retries``
    times.
    """

    def __init__(self, device=config.system.device):
        """Set up download URLs, checksums, local paths and empty model slots.

        Args:
            device: Device the models are moved to after loading.
        """
        # Candidate download URLs per model, passed as a list to download_file.
        self.DOWNLOAD_PATHS = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": [
                "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
            ],
            "BERT_BASE_JAPANESE_V3": [
                "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
            ],
            "BERT_LARGE_JAPANESE_V2": [
                "https://huggingface.co/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
            ],
            "DEBERTA_V2_LARGE_JAPANESE": [
                "https://huggingface.co/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
            ],
            "DEBERTA_V3_LARGE": [
                "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
            ],
            "SPM": [
                "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/spm.model",
                "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/spm.model",
            ],
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": [
                "https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
            ],
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": [
                "https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
            ],
            "CLAP_HTSAT_FUSED": [
                "https://huggingface.co/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
                "https://hf-mirror.com/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
            ],
            "Erlangshen_MegatronBert_1.3B_Chinese": [
                "https://huggingface.co/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
            ],
            "G2PWModel": [
                # "https://storage.googleapis.com/esun-ai/g2pW/G2PWModel-v2-onnx.zip",
                "https://huggingface.co/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
                "https://hf-mirror.com/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
            ],
            "CHINESE_HUBERT_BASE": [
                "https://huggingface.co/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
            ],
        }

        # Expected SHA-256 per model, forwarded to download_file as expected_sha256.
        self.SHA256 = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": "4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd",
            "BERT_BASE_JAPANESE_V3": "e172862e0674054d65e0ba40d67df2a4687982f589db44aa27091c386e5450a4",
            "BERT_LARGE_JAPANESE_V2": "50212d714f79af45d3e47205faa356d0e5030e1c9a37138eadda544180f9e7c9",
            "DEBERTA_V2_LARGE_JAPANESE": "a6c15feac0dea77ab8835c70e1befa4cf4c2137862c6fb2443b1553f70840047",
            "DEBERTA_V3_LARGE": "dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5",
            "SPM": "c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd",
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": "bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201",
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": "176d9d1ce29a8bddbab44068b9c1c194c51624c7f1812905e01355da58b18816",
            "CLAP_HTSAT_FUSED": "1ed5d0215d887551ddd0a49ce7311b21429ebdf1e6a129d4e68f743357225253",
            "Erlangshen_MegatronBert_1.3B_Chinese": "3456bb8f2c7157985688a4cb5cecdb9e229cb1dcf785b01545c611462ffe3579",
            # "G2PWModel": "bb40c8c7b5baa755b2acd317c6bc5a65e4af7b80c40a569247fbd76989299999",
            "G2PWModel": "",
            "CHINESE_HUBERT_BASE": "2fefccd26c2794a583b80f6f7210c721873cb7ebae2c1cde3baf9b27855e24d8",
        }

        def _data_path(relative):
            # Every model directory lives under <abs_path>/<data_path>/.
            return os.path.join(config.abs_path, config.system.data_path, relative)

        model_config = config.model_config
        # Local directory of each model.
        self.model_path = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": _data_path(model_config.chinese_roberta_wwm_ext_large),
            "BERT_BASE_JAPANESE_V3": _data_path(model_config.bert_base_japanese_v3),
            "BERT_LARGE_JAPANESE_V2": _data_path(model_config.bert_large_japanese_v2),
            "DEBERTA_V2_LARGE_JAPANESE": _data_path(model_config.deberta_v2_large_japanese),
            "DEBERTA_V3_LARGE": _data_path(model_config.deberta_v3_large),
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": _data_path(model_config.deberta_v2_large_japanese_char_wwm),
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": _data_path(
                model_config.wav2vec2_large_robust_12_ft_emotion_msp_dim),
            "CLAP_HTSAT_FUSED": _data_path(model_config.clap_htsat_fused),
            "Erlangshen_MegatronBert_1.3B_Chinese": _data_path(model_config.erlangshen_MegatronBert_1_3B_Chinese),
            "G2PWModel": _data_path(model_config.g2pw_model),
            "CHINESE_HUBERT_BASE": _data_path(model_config.chinese_hubert_base),
        }

        # Language tag -> BERT feature extraction function.
        self.lang_bert_func_map = {
            "zh": zh_bert,
            "en": en_bert,
            "ja": ja_bert,
            "ja_v111": ja_bert_v111,
            "ja_v200": ja_bert_v200,
            "en_v200": en_bert_v200,
            "zh_extra": zh_bert_extra,
            "ja_extra": ja_bert_extra,
        }

        self.bert_models = {}  # Value: (tokenizer, model, reference_count)
        # Each of the slots below is None while unloaded, otherwise a dict
        # holding the loaded objects plus a "reference_count" entry.
        self.emotion = None
        self.clap = None
        self.pinyinPlus = None
        self.device = device
        self.ssl_model = None

        dtype_name = config.bert_vits2_config.torch_data_type.lower()
        # Only half precision is selected explicitly; anything else passes
        # torch_dtype=None to from_pretrained.
        self.torch_dtype = torch.float16 if dtype_name in ("float16", "fp16") else None

    @property
    def emotion_model(self):
        """The loaded emotion model; fails if ``load_emotion`` has not succeeded."""
        return self.emotion["model"]

    @property
    def emotion_processor(self):
        """The loaded emotion processor; fails if ``load_emotion`` has not succeeded."""
        return self.emotion["processor"]

    @property
    def clap_model(self):
        """The loaded CLAP model; fails if ``load_clap`` has not succeeded."""
        return self.clap["model"]

    @property
    def clap_processor(self):
        """The loaded CLAP processor; fails if ``load_clap`` has not succeeded."""
        return self.clap["processor"]

    def _download_model(self, model_name, target_path=None):
        """Download the file registered for ``model_name`` and log the outcome.

        Args:
            model_name: Key into ``DOWNLOAD_PATHS`` / ``SHA256``.
            target_path: Destination file. Defaults to ``pytorch_model.bin``
                inside the model's directory from ``model_path``.
        """
        urls = self.DOWNLOAD_PATHS[model_name]
        if target_path is None:
            target_path = os.path.join(self.model_path[model_name], "pytorch_model.bin")

        success, message = download_file(urls, target_path, expected_sha256=self.SHA256[model_name])
        if success:
            logging.info(f"{message}")
        else:
            logging.error(f"Failed to download {model_name}: {message}")

    def load_bert(self, bert_model_name, max_retries=3):
        """Load a BERT model, or increase its reference count if already loaded.

        Args:
            bert_model_name: Key into ``model_path``.
            max_retries: Maximum number of load attempts; a download is
                triggered after each failed attempt.
        """
        if bert_model_name in self.bert_models:
            tokenizer, model, count = self.bert_models[bert_model_name]
            self.bert_models[bert_model_name] = (tokenizer, model, count + 1)
            return

        retries = 0
        model_path = ""
        while retries < max_retries:
            model_path = self.model_path[bert_model_name]
            logging.info(f"Loading BERT model: {model_path}")
            try:
                if bert_model_name == "Erlangshen_MegatronBert_1.3B_Chinese":
                    tokenizer = BertTokenizer.from_pretrained(model_path, torch_dtype=self.torch_dtype)
                    model = MegatronBertModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
                        self.device)
                else:
                    tokenizer = AutoTokenizer.from_pretrained(model_path, torch_dtype=self.torch_dtype)
                    model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
                        self.device)
                self.bert_models[bert_model_name] = (tokenizer, model, 1)  # reference count starts at 1
                logging.info(f"Success loading: {model_path}")
                break
            except Exception as e:
                logging.error(f"Failed loading {model_path}. {e}")
                logging.info("Trying to download.")
                # DeBERTa v3 additionally needs its sentencepiece model file.
                spm_path = os.path.join(model_path, "spm.model")
                if bert_model_name == "DEBERTA_V3_LARGE" and not os.path.exists(spm_path):
                    self._download_model("SPM", spm_path)
                self._download_model(bert_model_name)
                retries += 1

        if retries == max_retries:
            logging.error(f"Failed to load {model_path} after {max_retries} retries.")

    def load_emotion(self, max_retries=3):
        """Bert-VITS2 v2.1 EmotionModel

        Loads the model and processor, or increases the reference count when
        they are already loaded. ``self.emotion`` stays ``None`` when every
        attempt fails, so a later call can try again.
        """
        if self.emotion is not None:
            self.emotion["reference_count"] += 1
            return

        from transformers import Wav2Vec2Processor
        from bert_vits2.get_emo import EmotionModel

        retries = 0
        model_path = self.model_path["WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"]
        while retries < max_retries:
            logging.info(f"Loading WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM: {model_path}")
            try:
                # Build in locals first: a failure must not leave a
                # half-initialised dict behind in self.emotion.
                model = EmotionModel.from_pretrained(model_path).to(self.device)
                processor = Wav2Vec2Processor.from_pretrained(model_path)
            except Exception as e:
                logging.error(f"Failed loading {model_path}. {e}")
                self._download_model("WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM")
                retries += 1
            else:
                self.emotion = {"model": model, "processor": processor, "reference_count": 1}
                logging.info(f"Success loading: {model_path}")
                break

        if retries == max_retries:
            logging.error(f"Failed to load {model_path} after {max_retries} retries.")

    def release_emotion(self):
        """Drop one reference to the emotion model and free it at zero."""
        if self.emotion is None:
            return

        self.emotion["reference_count"] -= 1
        if self.emotion["reference_count"] <= 0:
            self.emotion = None
            gc.collect()
            torch.cuda.empty_cache()
            logging.info("Emotion model has been released.")

    def load_clap(self, max_retries=3):
        """Bert-VITS2 v2.2 ClapModel

        Loads the model and processor, or increases the reference count when
        they are already loaded. ``self.clap`` stays ``None`` when every
        attempt fails, so a later call can try again.
        """
        if self.clap is not None:
            self.clap["reference_count"] += 1
            return

        from transformers import ClapModel, ClapProcessor

        retries = 0
        model_path = self.model_path["CLAP_HTSAT_FUSED"]
        while retries < max_retries:
            logging.info(f"Loading CLAP_HTSAT_FUSED: {model_path}")
            try:
                # Build in locals first: a failure must not leave a
                # half-initialised dict behind in self.clap.
                model = ClapModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
                    self.device)
                processor = ClapProcessor.from_pretrained(model_path, torch_dtype=self.torch_dtype)
            except Exception as e:
                logging.error(f"Failed loading {model_path}. {e}")
                self._download_model("CLAP_HTSAT_FUSED")
                retries += 1
            else:
                self.clap = {"model": model, "processor": processor, "reference_count": 1}
                logging.info(f"Success loading: {model_path}")
                break

        if retries == max_retries:
            logging.error(f"Failed to load {model_path} after {max_retries} retries.")

    def release_clap(self):
        """Drop one reference to the CLAP model and free it at zero."""
        if self.clap is None:
            return

        self.clap["reference_count"] -= 1
        if self.clap["reference_count"] <= 0:
            self.clap = None
            gc.collect()
            torch.cuda.empty_cache()
            logging.info("Clap model has been released.")

    def get_bert_model(self, bert_model_name):
        """Return ``(tokenizer, model)`` for a BERT model, loading it on demand.

        Raises:
            KeyError: If the model is unknown or could not be loaded.
        """
        if bert_model_name not in self.bert_models:
            self.load_bert(bert_model_name)

        tokenizer, model, _ = self.bert_models[bert_model_name]
        return tokenizer, model

    def get_bert_feature(self, norm_text, word2ph, language, bert_model_name, style_text=None, style_weight=0.7):
        """Compute BERT features with the extractor registered for ``language``.

        Args:
            norm_text: Text forwarded to the extractor.
            word2ph: Forwarded to the extractor unchanged.
            language: Key into ``lang_bert_func_map``.
            bert_model_name: BERT model to use (loaded on demand).
            style_text: Optional style text forwarded to the extractor.
            style_weight: Style weight forwarded to the extractor.

        Returns:
            Whatever the language specific extractor returns.
        """
        tokenizer, model = self.get_bert_model(bert_model_name)
        extract = self.lang_bert_func_map[language]
        return extract(norm_text, word2ph, tokenizer, model, self.device,
                       style_text=style_text, style_weight=style_weight)

    def get_pinyinPlus(self):
        """Return the G2PW pinyin converter, creating it on first use."""
        if self.pinyinPlus is None:
            from bert_vits2.g2pW.pypinyin_G2pW_bv2 import G2PWPinyin

            logging.info(f"Loading G2PWModel: {self.model_path['G2PWModel']}")
            self.pinyinPlus = G2PWPinyin(
                model_dir=self.model_path["G2PWModel"],
                model_source=self.model_path["Erlangshen_MegatronBert_1.3B_Chinese"],
                v_to_u=False,
                neutral_tone_with_five=True,
            )
            logging.info("Success loading G2PWModel")
        return self.pinyinPlus

    def release_bert(self, bert_model_name):
        """Drop one reference to a BERT model and free it at zero."""
        if bert_model_name not in self.bert_models:
            return

        # Read only the counter here: binding the model to a local name would
        # keep it alive while gc.collect()/empty_cache() run below.
        count = self.bert_models[bert_model_name][2] - 1
        if count <= 0:
            # When the reference count reaches 0, delete the model and release its resources.
            del self.bert_models[bert_model_name]
            gc.collect()
            torch.cuda.empty_cache()
            logging.info(f"BERT model {bert_model_name} has been released.")
        else:
            self.bert_models[bert_model_name] = self.bert_models[bert_model_name][:2] + (count,)

    def load_ssl(self, max_retries=3):
        """GPT-SoVITS

        Loads the CNHubert SSL model, or increases the reference count when it
        is already loaded. ``self.ssl_model`` stays ``None`` when every
        attempt fails, so a later call can try again.
        """
        if self.ssl_model is not None:
            self.ssl_model["reference_count"] += 1
            return

        retries = 0
        model_path = self.model_path["CHINESE_HUBERT_BASE"]
        while retries < max_retries:
            logging.info(f"Loading CHINESE_HUBERT_BASE: {model_path}")
            try:
                from gpt_sovits.feature_extractor.cnhubert import CNHubert

                # Build in a local first: a failure must not leave a
                # half-initialised dict behind in self.ssl_model.
                model = CNHubert(model_path)
                model.eval()
                if config.gpt_sovits_config.is_half:
                    model = model.half()
                model = model.to(self.device)
            except Exception as e:
                logging.error(f"Failed loading {model_path}. {e}")
                self._download_model("CHINESE_HUBERT_BASE")
                retries += 1
            else:
                self.ssl_model = {"model": model, "reference_count": 1}
                logging.info(f"Success loading: {model_path}")
                break

        if retries == max_retries:
            logging.error(f"Failed to load {model_path} after {max_retries} retries.")

    def get_ssl_model(self):
        """Return the SSL model, loading it on demand; ``None`` if loading failed."""
        if self.ssl_model is None:
            self.load_ssl()
        if self.ssl_model is None:
            return None
        return self.ssl_model.get("model")

    def release_ssl_model(self):
        """Drop one reference to the SSL model and free it at zero."""
        if self.ssl_model is None:
            return

        self.ssl_model["reference_count"] -= 1
        if self.ssl_model["reference_count"] <= 0:
            self.ssl_model = None
            gc.collect()
            torch.cuda.empty_cache()
            logging.info("SSL model has been released.")

    def is_model_loaded(self, bert_model_name):
        """Return True if the named BERT model is currently loaded."""
        return bert_model_name in self.bert_models

    def reference_count(self, bert_model_name):
        """Return the reference count of a BERT model (0 when not loaded)."""
        entry = self.bert_models.get(bert_model_name)
        return entry[2] if entry is not None else 0