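"""Text and audio tokenization utilities: phonemize text with the espeak
backend and encode audio into discrete codes with an EnCodec-style
compression model loaded through audiocraft."""
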
import re
from typing import Any, List, Optional, Pattern, Tuple, Union

import torch
import torchaudio
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator


class TextTokenizer:
    """Phonemize text with the espeak backend."""

    def __init__(
        self,
        language: str = "en-us",
        backend: str = "espeak",  # only the espeak backend is wired up below
        separator: Separator = Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation: bool = True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        phonemizer = EspeakBackend(
            language,
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split a phonemized string into a flat list of tokens, keeping
        word boundaries as separator tokens."""
        fields = []
        for word in phonemized.split(self.separator.word):
            # Split each word into phones and punctuation, drop the phone
            # separator itself, and close the word with a boundary token.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        # Sanity check: nothing but phone separators was dropped.
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]  # drop the trailing word boundary

    def __call__(self, text, strip: bool = True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Phonemize a single string and return its token list."""
    phonemes = tokenizer([text.strip()])
    return phonemes[0]


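# A minimal usage sketch (assumes espeak-ng is installed and the "en-us"
# voice is available):
#
#   tokenizer = TextTokenizer()
#   tokenize_text(tokenizer, "Hello world.")
#   # -> a flat list of phone tokens, with "_" marking word boundaries

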
def convert_audio(
    wav: torch.Tensor, sr: int, target_sr: int, target_channels: int
) -> torch.Tensor:
    """Remix `wav` to `target_channels` and resample it to `target_sr`."""
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
    if target_channels == 1:
        # Downmix to mono by averaging channels.
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        # Broadcast mono to stereo (a no-op for stereo input).
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        # Broadcast mono to an arbitrary channel count.
        wav = wav.expand(target_channels, -1)
    else:
        raise ValueError("Cannot remix stereo audio to more than two channels.")
    wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav


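# For example, a stereo 44.1 kHz clip downmixed to mono at 16 kHz
# (hypothetical shapes, for illustration only):
#
#   wav = torch.randn(2, 44100)
#   wav = convert_audio(wav, 44100, 16000, 1)  # -> shape (1, 16000)

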
class AudioTokenizer:
    """Tokenize audio into EnCodec codes."""

    def __init__(
        self,
        device: Any = None,
        signature: Optional[str] = None,
    ) -> None:
        # Deferred import: audiocraft is a heavy dependency only needed here.
        from audiocraft.solvers import CompressionSolver

        model = CompressionSolver.model_from_checkpoint(signature)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

        # Default to CUDA when available, otherwise fall back to CPU.
        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device

        self.codec = model.to(device)

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        # CompressionModel.encode returns (codes, scale); wrap the codes in
        # the EnCodec-style list of (codes, scale) frames expected downstream.
        codes = self.codec.encode(wav.to(self.device))
        return [(codes[0], None)]

    def decode(self, frames: List[Tuple[torch.Tensor, Optional[torch.Tensor]]]) -> torch.Tensor:
        codes = frames[0][0]
        return self.codec.decode(codes)


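# `encode` mirrors EnCodec's frame format: a list of (codes, scale) tuples,
# where codes has shape (batch, n_codebooks, n_frames) and scale is unused
# here (always None). `decode` inverts it:
#
#   frames = audio_tokenizer.encode(wav)    # [(codes, None)]
#   recon = audio_tokenizer.decode(frames)  # (batch, channels, samples)

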
def tokenize_audio(
    tokenizer: AudioTokenizer, audio_path: str, offset: int = -1, num_frames: int = -1
):
    # Load the requested slice when both offset and num_frames are given,
    # otherwise load the whole file.
    if offset != -1 and num_frames != -1:
        wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
    else:
        wav, sr = torchaudio.load(audio_path)
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)  # add a batch dimension: (1, channels, samples)

    # Extract discrete codes from the codec.
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames
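

if __name__ == "__main__":
    # Minimal end-to-end sketch. The audio path is a placeholder, and the
    # checkpoint signature is an assumption: "//pretrained/facebook/encodec_32khz"
    # is one name audiocraft's CompressionSolver can resolve; substitute your own.
    text_tokenizer = TextTokenizer()
    print(tokenize_text(text_tokenizer, "Hello world."))

    audio_tokenizer = AudioTokenizer(signature="//pretrained/facebook/encodec_32khz")
    frames = tokenize_audio(audio_tokenizer, "example.wav")
    print(frames[0][0].shape)  # (batch, n_codebooks, n_frames)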