|
import os |
|
from typing import Union |
|
from box import Box |
|
import torch |
|
|
|
from modules import models |
|
from modules.utils.SeedContext import SeedContext |
|
|
|
import uuid |
|
|
|
|
|
def create_speaker_from_seed(seed):
    """Sample a speaker embedding from the ChatTTS model under ``seed``.

    The model is obtained from ``models.load_chat_tts()`` and its
    ``sample_random_speaker()`` is called inside ``SeedContext(seed)``
    (presumably so the same seed yields the same sample).

    Args:
        seed: Value passed to ``SeedContext``.

    Returns:
        The result of ``sample_random_speaker()``.
    """
    tts_model = models.load_chat_tts()
    with SeedContext(seed):
        embedding = tts_model.sample_random_speaker()
    return embedding
|
|
|
|
|
class Speaker:
    """Speaker metadata (id, seed, name, gender, description) plus an embedding.

    Attributes set by ``__init__``:
        id: Random ``uuid.UUID`` identifying this speaker.
        seed: Seed the speaker was created from (``-2`` is used by
            ``from_tensor`` and by ``fix`` when no seed is known).
        name, gender, describe: Free-form metadata strings.
        emb: Embedding; ``None`` until assigned by the caller.
        tokens: List, initialised empty; nothing in this class fills it.
    """

    # Defaults applied by ``fix`` to objects that lack an attribute.
    # Note that the gender default here ("*") differs from the ``__init__``
    # default ("").
    _LEGACY_DEFAULTS = (
        ("seed", -2),
        ("name", ""),
        ("gender", "*"),
        ("describe", ""),
    )

    @staticmethod
    def from_file(file_like):
        """Load a saved speaker object and backfill any missing attributes.

        Args:
            file_like: Path or file object accepted by ``torch.load``.

        Returns:
            The loaded object after ``fix()`` has been called on it.
        """
        # SECURITY: torch.load unpickles arbitrary Python objects, so only
        # load speaker files that come from a trusted source.
        # NOTE(review): recent PyTorch releases default ``weights_only`` to
        # True, which rejects pickled custom classes -- confirm against the
        # torch version this project pins.
        speaker = torch.load(file_like, map_location=torch.device("cpu"))
        speaker.fix()
        return speaker

    @staticmethod
    def from_tensor(tensor):
        """Wrap an existing embedding in a new ``Speaker`` with seed ``-2``."""
        speaker = Speaker(seed=-2)
        speaker.emb = tensor
        return speaker

    def __init__(self, seed, name="", gender="", describe=""):
        """Create a speaker with a fresh random id and no embedding."""
        self.id = uuid.uuid4()
        self.seed = seed
        self.name = name
        self.gender = gender
        self.describe = describe
        self.emb = None

        self.tokens = []

    def _emb_as_list(self):
        """Return ``emb`` converted with ``tolist()``, or ``None`` if unset."""
        emb = getattr(self, "emb", None)
        if emb is None:
            return None
        return emb.tolist()

    def to_json(self, with_emb=False):
        """Return the speaker's fields as a ``Box``.

        Args:
            with_emb: When true, include the embedding as nested lists.

        Returns:
            ``Box`` with keys ``id`` (string form of the UUID), ``seed``,
            ``name``, ``gender``, ``describe`` and ``emb``. ``emb`` is
            ``None`` when ``with_emb`` is false or no embedding has been
            assigned yet (previously the latter raised ``AttributeError``).
        """
        return Box(
            **{
                "id": str(self.id),
                "seed": self.seed,
                "name": self.name,
                "gender": self.gender,
                "describe": self.describe,
                "emb": self._emb_as_list() if with_emb else None,
            }
        )

    def fix(self):
        """Backfill attributes missing from this instance's ``__dict__``.

        Covers ``id``, ``seed``, ``name``, ``gender`` and ``describe``;
        ``emb`` and ``tokens`` are not backfilled.

        Returns:
            ``True`` if at least one attribute had to be added.
        """
        is_update = False
        if "id" not in self.__dict__:
            # Generated lazily so an existing id is never replaced.
            self.id = uuid.uuid4()
            is_update = True
        for attr, default in self._LEGACY_DEFAULTS:
            if attr not in self.__dict__:
                setattr(self, attr, default)
                is_update = True
        return is_update

    def __hash__(self):
        """Hash on the string form of ``id``, matching ``__eq__``."""
        return hash(str(self.id))

    def __eq__(self, other):
        """Two speakers are equal when the string forms of their ids match."""
        if not isinstance(other, Speaker):
            return False
        return str(self.id) == str(other.id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeakerManager:
    """Registry of the speaker files stored under ``speaker_dir``.

    ``speakers`` maps each ``*.pt`` file name in ``speaker_dir`` to the
    object loaded from it. Every create/update call writes the file and then
    reloads the whole directory.
    """

    def __init__(self):
        """Point at ``./data/speakers/`` and load whatever is there."""
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        self.refresh_speakers()

    def _path_for(self, filename):
        """Return the path of ``filename`` inside ``speaker_dir``."""
        return os.path.join(self.speaker_dir, filename)

    def _save(self, speaker, filename):
        """Write ``speaker`` to ``filename``, creating ``speaker_dir`` if needed."""
        os.makedirs(self.speaker_dir, exist_ok=True)
        torch.save(speaker, self._path_for(filename))

    def refresh_speakers(self):
        """Rebuild ``speakers`` from the ``*.pt`` files in ``speaker_dir``.

        A missing directory yields an empty registry instead of raising.
        Files are visited in sorted order so lookups that return the first
        match do not depend on the platform's directory ordering.
        """
        self.speakers = {}
        if not os.path.isdir(self.speaker_dir):
            return
        for speaker_file in sorted(os.listdir(self.speaker_dir)):
            if speaker_file.endswith(".pt"):
                self.speakers[speaker_file] = Speaker.from_file(
                    self._path_for(speaker_file)
                )

    def list_speakers(self):
        """Return the loaded speakers as a new list."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Create a speaker from ``seed``, save it as ``<name>.pt`` and reload.

        Args:
            seed: Seed passed to ``Speaker`` and to the module-level
                ``create_speaker_from_seed``.
            name: Speaker name and file stem; defaults to ``str(seed)``.
            gender: Stored on the speaker as-is.
            describe: Stored on the speaker as-is.

        Returns:
            The newly created ``Speaker`` (not the copy reloaded from disk).
        """
        if name == "":
            # str() so that an integer seed can be used as the file stem;
            # the bare seed used to raise TypeError on concatenation.
            name = str(seed)
        filename = f"{name}.pt"
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        speaker.emb = create_speaker_from_seed(seed)
        self._save(speaker, filename)
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Create a speaker from an existing embedding, save it and reload.

        Args:
            tensor: ``torch.Tensor`` used as-is, or a ``list`` converted with
                ``torch.tensor``. Any other type leaves ``emb`` as ``None``.
            filename: File stem (``.pt`` is appended); defaults to ``name``.
            name: Stored on the speaker as-is.
            gender: Stored on the speaker as-is.
            describe: Stored on the speaker as-is.

        Returns:
            The newly created ``Speaker`` with seed ``-2``.
        """
        if filename == "":
            filename = name
        speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
        if isinstance(tensor, torch.Tensor):
            speaker.emb = tensor
        elif isinstance(tensor, list):
            speaker.emb = torch.tensor(tensor)
        self._save(speaker, filename + ".pt")
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> Union["Speaker", None]:
        """Return the first loaded speaker whose ``name`` equals ``name``."""
        return next(
            (spk for spk in self.speakers.values() if spk.name == name), None
        )

    def get_speaker_by_id(self, id) -> Union["Speaker", None]:
        """Return the first loaded speaker whose id matches ``id`` as a string."""
        wanted = str(id)
        return next(
            (spk for spk in self.speakers.values() if str(spk.id) == wanted),
            None,
        )

    def get_speaker_filename(self, id: str):
        """Return the file name holding the speaker with ``id``, or ``None``."""
        wanted = str(id)
        return next(
            (
                fname
                for fname, spk in self.speakers.items()
                if str(spk.id) == wanted
            ),
            None,
        )

    def update_speaker(self, speaker: "Speaker"):
        """Overwrite the stored file of the speaker sharing ``speaker.id``.

        Returns:
            The ``speaker`` that was passed in.

        Raises:
            ValueError: If no loaded speaker has that id.
        """
        filename = self.get_speaker_filename(speaker.id)
        if not filename:
            raise ValueError("Speaker not found for update")
        self._save(speaker, filename)
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Write every loaded speaker back to the file it was loaded from."""
        # Iterate the mapping directly: resolving the file name by id would
        # send two entries that share an id to the same file.
        for filename, speaker in self.speakers.items():
            self._save(speaker, filename)

    def __len__(self):
        """Return the number of loaded speakers."""
        return len(self.speakers)
|
|
|
|
|
# Shared module-level manager; constructing it loads the speaker directory
# at import time.
speaker_mgr: SpeakerManager = SpeakerManager()
|
|