# NOTE(review): page-scrape residue, apparently a hosting-page status banner — "Spaces: Runtime error" (shown twice in the capture).
import torch | |
from parler_tts import ParlerTTSForConditionalGeneration | |
from transformers import AutoTokenizer | |
import soundfile as sf | |
class TTSModel:
    """Wrapper around the ``ai4bharat/indic-parler-tts`` checkpoint.

    Construction loads the model and two tokenizers; ``generate_audio``
    turns a text prompt plus a description string into a NumPy array.
    """

    def __init__(self):
        """Pick a device, then load the model and both tokenizers.

        Progress is reported with ``print``. The model is moved to CUDA when
        ``torch.cuda.is_available()`` is true, otherwise it stays on the CPU.
        """
        use_gpu = torch.cuda.is_available()
        self.device = "cuda" if use_gpu else "cpu"
        self.model_name = "ai4bharat/indic-parler-tts"

        # Report which device the weights are about to be placed on.
        print(f"Loading model on device: {self.device}")

        # Load the checkpoint first, then move it to the selected device.
        loaded_model = ParlerTTSForConditionalGeneration.from_pretrained(
            self.model_name
        )
        self.model = loaded_model.to(self.device)

        # The prompt tokenizer comes from the checkpoint itself; the
        # description tokenizer is looked up from the name recorded in the
        # model config's text-encoder entry.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        text_encoder_name = self.model.config.text_encoder._name_or_path
        self.description_tokenizer = AutoTokenizer.from_pretrained(
            text_encoder_name
        )

        print("Model loaded successfully")

    def generate_audio(self, text, description):
        """Generate audio for ``text`` conditioned on ``description``.

        Args:
            text: Prompt string; tokenized with ``self.tokenizer`` and passed
                to the model as ``prompt_input_ids``.
            description: Description string; tokenized with
                ``self.description_tokenizer`` and passed as ``input_ids``.

        Returns:
            The output of ``self.model.generate`` moved to the CPU, converted
            to a NumPy array and squeezed.

        Raises:
            Exception: Any error raised while tokenizing or generating is
                printed and then re-raised unchanged.
        """
        try:
            # Both encodings are requested as PyTorch tensors and moved to the
            # same device as the model.
            description_encoding = self.description_tokenizer(
                description, return_tensors="pt"
            )
            description_encoding = description_encoding.to(self.device)

            prompt_encoding = self.tokenizer(text, return_tensors="pt")
            prompt_encoding = prompt_encoding.to(self.device)

            generate_kwargs = {
                "input_ids": description_encoding.input_ids,
                "attention_mask": description_encoding.attention_mask,
                "prompt_input_ids": prompt_encoding.input_ids,
                "prompt_attention_mask": prompt_encoding.attention_mask,
            }

            # Gradient tracking is disabled for the generation call.
            with torch.no_grad():
                generated = self.model.generate(**generate_kwargs)

            # Bring the result back to host memory and drop size-1 axes.
            return generated.cpu().numpy().squeeze()
        except Exception as e:
            print(f"Error in speech generation: {str(e)}")
            raise