# Spaces: Running on Zero
# Copied from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re | |
from dataclasses import asdict, dataclass | |
from typing import Any, Dict, List, Optional, Pattern, Union | |
import numpy as np | |
import torch | |
import torchaudio | |
# from lhotse.features import FeatureExtractor | |
# from lhotse.utils import Seconds, compute_num_frames | |
from phonemizer.backend import EspeakBackend | |
from phonemizer.backend.espeak.language_switch import LanguageSwitch | |
from phonemizer.backend.espeak.words_mismatch import WordMismatch | |
from phonemizer.punctuation import Punctuation | |
from phonemizer.separator import Separator | |
class TextTokenizer:
    """Phonemize text with espeak and split the result into token lists.

    Each input string is phonemized by an ``EspeakBackend``. The phonemized
    string is then broken into a flat list of tokens, with
    ``separator.word`` inserted between words.
    """

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        """Create the espeak phonemizer backend.

        Args:
            language: Language code handed to ``EspeakBackend``.
            backend: Accepted but not read; an ``EspeakBackend`` is always
                constructed.
            separator: Separator used both when phonemizing (``__call__``)
                and when splitting the output (``to_list``).
            preserve_punctuation, punctuation_marks, with_stress, tie,
            language_switch, words_mismatch: Forwarded unchanged to
                ``EspeakBackend``.
        """
        self.backend = EspeakBackend(
            language,
            punctuation_marks=punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        """Split one phonemized string into a flat token list.

        Words are delimited by ``separator.word``. Inside a word, each run of
        word characters and each single non-word, non-space character becomes
        a token, and phone separators are discarded. The word separator
        itself is emitted between consecutive words.

        Raises:
            AssertionError: if the joined tokens are not exactly as long as
                the input minus its phone separators (e.g. the input contained
                whitespace, which the token pattern drops).
        """
        word_sep = self.separator.word
        phone_sep = self.separator.phone
        tokens: List[str] = []
        for chunk in phonemized.split(word_sep):
            # Example chunk shape: "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            tokens += [
                piece
                for piece in re.findall(r"\w+|[^\w\s]", chunk, re.UNICODE)
                if piece != phone_sep
            ]
            tokens.append(word_sep)
        # Drop the word separator appended after the last word.
        result = tokens[:-1]
        assert len("".join(result)) == len(phonemized) - phonemized.count(
            phone_sep
        )
        return result

    def __call__(self, text, strip=True) -> List[List[str]]:
        """Phonemize ``text`` and return one token list per input string.

        Args:
            text: A single string or a list of strings.
            strip: Forwarded to the backend's ``phonemize`` call.
        """
        batch = [text] if isinstance(text, str) else text
        outputs = self.backend.phonemize(
            batch, separator=self.separator, strip=strip, njobs=1
        )
        return list(map(self.to_list, outputs))
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    """Phonemize a single string and return its token list.

    Leading and trailing whitespace is stripped. The string is then sent to
    ``tokenizer`` as a one-element batch, and the first result is returned.
    """
    stripped = text.strip()
    results = tokenizer([stripped])
    return results[0]  # tokens (k2 symbols) for the single input
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Convert a waveform to ``target_channels`` channels at ``target_sr`` Hz.

    Args:
        wav: Waveform, expected as (channels, time); must be mono (1) or
            stereo (2).
        sr: Sample rate of ``wav``.
        target_sr: Desired sample rate.
        target_channels: Desired channel count. 1 averages all channels and
            2 duplicates a mono signal (stereo is kept). Counts above 2 are
            supported only for mono input.

    Returns:
        The converted waveform. Channel expansion uses ``Tensor.expand``, so
        the result may be a view sharing storage with ``wav``.

    Raises:
        AssertionError: If ``wav`` is neither mono nor stereo.
        ValueError: If ``target_channels`` is below 1, or if stereo input is
            asked to become more than 2 channels. Previously these cases
            silently returned a waveform with the wrong channel count.
    """
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
    if target_channels < 1:
        raise ValueError(f"target_channels must be >= 1, got {target_channels}.")
    if target_channels == 1:
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *shape, _, length = wav.shape
        wav = wav.expand(*shape, target_channels, length)
    elif wav.shape[0] == 1:
        wav = wav.expand(target_channels, -1)
    else:
        # Stereo -> N>2 channels has no well-defined upmix here.
        raise ValueError(
            f"Cannot convert stereo audio to {target_channels} channels."
        )
    # Resample is the identity when the rates match; skip building it then.
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav
class AudioTokenizer:
    """EnCodec-style audio codec wrapper built from an audiocraft checkpoint.

    Exposes the codec's ``sample_rate`` and ``channels``, plus
    ``encode``/``decode`` helpers that run the codec on ``device``.
    """

    def __init__(
        self,
        device: Any = None,
        signature=None,
    ) -> None:
        """Load the compression model and move it to ``device``.

        Args:
            device: Target device. When falsy, ``cuda:0`` is used if CUDA is
                available, otherwise the CPU.
            signature: Checkpoint identifier forwarded to
                ``CompressionSolver.model_from_checkpoint``.
        """
        # Imported here rather than at module level, so audiocraft is only
        # required once an AudioTokenizer is actually constructed.
        from audiocraft.solvers import CompressionSolver

        model = CompressionSolver.model_from_checkpoint(signature)
        self.sample_rate = model.sample_rate
        self.channels = model.channels
        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
        self._device = device
        self.codec = model.to(device)

    @property
    def device(self):
        """Device the codec lives on; inputs are moved here before encoding."""
        # Must be a property: ``encode`` passes ``self.device`` straight to
        # ``Tensor.to``, which rejects a bound method with a TypeError.
        return self._device

    def encode(self, wav: torch.Tensor) -> List[Any]:
        """Encode ``wav`` with the codec after moving it to ``self.device``.

        Returns:
            A one-element list ``[(codes, None)]``, where ``codes`` is the
            first item of the codec's ``encode`` result. ``decode`` accepts
            this structure directly.
        """
        codes = self.codec.encode(wav.to(self.device))
        return [(codes[0], None)]

    def decode(self, frames: List[Any]) -> torch.Tensor:
        """Decode the ``[(codes, None)]`` structure produced by ``encode``."""
        codes = frames[0][0]  # expected shape [1, 4, T] — TODO confirm per codec
        return self.codec.decode(codes)
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
    """Load an audio file, adapt it to the codec, and encode it.

    Args:
        tokenizer: Supplies ``sample_rate``, ``channels`` and ``encode``.
        audio_path: File passed to ``torchaudio.load``.
        offset: First frame to read; ``-1`` means "not set" (start of file).
        num_frames: Number of frames to read; ``-1`` means "not set" (read to
            the end).

    Returns:
        Whatever ``tokenizer.encode`` returns for the batched waveform.
    """
    # Each bound is applied on its own. Previously a window was only used
    # when BOTH were set, so a lone offset or frame count was ignored.
    if offset == -1 and num_frames == -1:
        wav, sr = torchaudio.load(audio_path)
    else:
        wav, sr = torchaudio.load(
            audio_path,
            frame_offset=0 if offset == -1 else offset,
            num_frames=num_frames,  # -1 is torchaudio's "read the rest"
        )
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)  # add a leading batch dimension
    # Extract discrete codes without tracking gradients.
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames