DHEIVER committed on
Commit
23dd469
·
verified ·
1 Parent(s): 9167858

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pathlib import Path
from typing import Optional, Union

import numpy
import scipy.io.wavfile
import torch
import torchaudio
from transformers import AutoProcessor, SeamlessM4Tv2Model
7
+
8
class SeamlessTranslator:
    """
    A wrapper class for Facebook's SeamlessM4T translation model.
    Handles both text-to-speech and speech-to-speech translation.

    Attributes:
        processor: Object returned by ``AutoProcessor.from_pretrained``.
        model: Object returned by ``SeamlessM4Tv2Model.from_pretrained``.
        sample_rate: ``model.config.sampling_rate``; used as the sample
            rate of every WAV file written by ``save_audio``.
    """

    # Rate (Hz) that input audio is resampled to before it is handed to
    # the processor in ``translate_audio``.
    INPUT_SAMPLE_RATE = 16_000

    def __init__(self, model_name: str = "facebook/seamless-m4t-v2-large"):
        """
        Initialize the translator with the specified model.

        Args:
            model_name (str): Name of the model to use

        Raises:
            RuntimeError: If the processor or the model cannot be loaded.
                The original exception is attached as ``__cause__``.
        """
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = SeamlessM4Tv2Model.from_pretrained(model_name)
            self.sample_rate = self.model.config.sampling_rate
        except Exception as e:
            raise RuntimeError(f"Failed to initialize model: {str(e)}") from e

    def _generate_waveform(self, inputs, tgt_lang: str) -> numpy.ndarray:
        """Run generation on processed inputs and return the waveform as a squeezed array."""
        generated = self.model.generate(**inputs, tgt_lang=tgt_lang)
        # The first element of the generation output is the waveform tensor.
        return generated[0].cpu().numpy().squeeze()

    def translate_text(self, text: str, src_lang: str, tgt_lang: str) -> numpy.ndarray:
        """
        Translate text to speech in the target language.

        Args:
            text (str): Input text to translate
            src_lang (str): Source language code (e.g., 'eng')
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If processing or generation fails. The original
                exception is attached as ``__cause__``.
        """
        try:
            inputs = self.processor(text=text, src_lang=src_lang, return_tensors="pt")
            return self._generate_waveform(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Text translation failed: {str(e)}") from e

    def translate_audio(self, audio_path: Union[str, Path], tgt_lang: str) -> numpy.ndarray:
        """
        Translate audio to speech in the target language.

        Args:
            audio_path (str or Path): Path to input audio file
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If loading, resampling, processing or generation
                fails. The original exception is attached as ``__cause__``.
        """
        try:
            # Load the file and bring it to the rate fed to the processor.
            audio, orig_freq = torchaudio.load(audio_path)
            audio = torchaudio.functional.resample(
                audio,
                orig_freq=orig_freq,
                new_freq=self.INPUT_SAMPLE_RATE,
            )

            # Process and generate translation
            inputs = self.processor(audios=audio, return_tensors="pt")
            return self._generate_waveform(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Audio translation failed: {str(e)}") from e

    def save_audio(self, audio_array: numpy.ndarray, output_path: Union[str, Path]) -> None:
        """
        Save an audio array to a WAV file.

        The file is written with ``self.sample_rate`` as its sample rate.

        Args:
            audio_array (numpy.ndarray): Audio data to save
            output_path (str or Path): Path where to save the WAV file

        Raises:
            RuntimeError: If the file cannot be written. The original
                exception is attached as ``__cause__``.
        """
        try:
            scipy.io.wavfile.write(
                output_path,
                rate=self.sample_rate,
                data=audio_array,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to save audio: {str(e)}") from e
90
+
91
def main():
    """Demo entry point: run one text translation and one audio translation.

    Each resulting waveform is written to a WAV file in the current
    working directory. Any exception raised along the way is caught and
    reported on standard output instead of propagating.
    """
    try:
        seamless = SeamlessTranslator()

        # Text in English -> speech in Russian.
        speech_from_text = seamless.translate_text("Hello, my dog is cute", "eng", "rus")
        seamless.save_audio(speech_from_text, "output_from_text.wav")

        # Speech from a file on disk -> speech in Russian.
        speech_from_audio = seamless.translate_audio("input_audio.wav", "rus")
        seamless.save_audio(speech_from_audio, "output_from_audio.wav")
    except Exception as exc:
        print(f"Translation failed: {exc}")


if __name__ == "__main__":
    main()