# Source: abhishekrajpurohit's Hugging Face Space — "Upload 39 files", commit 195bb33 (verified)
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
class TTSModel:
    """Thin wrapper around Parler-TTS for text-to-speech generation.

    Loads the model and its two tokenizers once at construction, then
    synthesizes audio from a transcript plus a natural-language voice
    description via `generate_audio`.
    """

    def __init__(self, model_name="ai4bharat/indic-parler-tts", device=None):
        """Load the Parler-TTS model and tokenizers.

        Args:
            model_name: Hugging Face model id to load. Defaults to the
                original hard-coded Indic Parler-TTS checkpoint.
            device: torch device string ("cuda"/"cpu"). When None,
                auto-selects CUDA if available (original behavior).
        """
        self.device = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model_name = model_name
        print(f"Loading model on device: {self.device}")
        # Initialize model and tokenizers exactly as in the documentation.
        # Two tokenizers are needed: one for the transcript prompt and one
        # for the voice description, which goes through the model's separate
        # text encoder (hence tokenizing with that encoder's own vocab).
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            self.model_name
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.description_tokenizer = AutoTokenizer.from_pretrained(
            self.model.config.text_encoder._name_or_path
        )
        print("Model loaded successfully")

    def generate_audio(self, text, description):
        """Synthesize speech for `text` in the style given by `description`.

        Args:
            text: Transcript to be spoken.
            description: Natural-language description of the desired voice
                (speaker, tone, pace, etc.).

        Returns:
            A numpy float array of audio samples with the batch dimension
            squeezed out (the model's native sampling rate applies —
            presumably `self.model.config.sampling_rate`; confirm at caller).

        Raises:
            Re-raises any exception from tokenization or generation after
            printing a diagnostic message.
        """
        try:
            # Tokenize description and prompt separately, as documented.
            description_inputs = self.description_tokenizer(
                description, return_tensors="pt"
            ).to(self.device)
            prompt_inputs = self.tokenizer(
                text, return_tensors="pt"
            ).to(self.device)

            # Inference only: no_grad avoids building the autograd graph.
            with torch.no_grad():
                generation = self.model.generate(
                    input_ids=description_inputs.input_ids,
                    attention_mask=description_inputs.attention_mask,
                    prompt_input_ids=prompt_inputs.input_ids,
                    prompt_attention_mask=prompt_inputs.attention_mask,
                )

            # Move to CPU and drop the batch dimension -> 1-D waveform.
            return generation.cpu().numpy().squeeze()
        except Exception as e:
            print(f"Error in speech generation: {str(e)}")
            raise