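"""Gradio app for CPU-optimized voice cloning with OuteTTS-0.1-350M and Faster-Whisper."""
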
import gradio as gr
import torch
import torch.nn as nn
import os
from outetts.v0_1.interface import InterfaceHF
import soundfile as sf
import tempfile
from faster_whisper import WhisperModel
from pathlib import Path

# Configure PyTorch for CPU efficiency
torch.set_num_threads(4)  # Limit CPU threads
torch.set_grad_enabled(False)  # Disable gradient computation

class OptimizedTTSInterface:
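    """Wrap OuteTTS's InterfaceHF and apply dynamic INT8 quantization for CPU inference."""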
    def __init__(self, model_name="OuteAI/OuteTTS-0.1-350M"):
        self.interface = InterfaceHF(model_name)
        # Dynamically quantize the model's Linear layers to INT8 for faster CPU inference
        self.interface.model = torch.quantization.quantize_dynamic(
            self.interface.model, {nn.Linear}, dtype=torch.qint8
        )
        # Move model to CPU and put it in evaluation mode
        self.interface.model.cpu()
        self.interface.model.eval()
        
    def create_speaker(self, *args, **kwargs):
        with torch.inference_mode():
            return self.interface.create_speaker(*args, **kwargs)
            
    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.interface.generate(*args, **kwargs)

def initialize_models():
    """Initialize the OptimizedTTS and Faster-Whisper models"""
    # Use cached models if available
    cache_dir = Path("model_cache")
    cache_dir.mkdir(exist_ok=True)
    
    tts_interface = OptimizedTTSInterface()
    
    # Initialize Whisper tiny with INT8 compute and minimal CPU threads
    asr_model = WhisperModel("tiny", 
                            device="cpu",
                            compute_type="int8",
                            num_workers=1,
                            cpu_threads=2,
                            download_root=str(cache_dir))
    return tts_interface, asr_model

def transcribe_audio(audio_path):
    """Transcribe audio using Faster-Whisper tiny"""
    try:
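        # Single-candidate decoding (beam_size=1, best_of=1) keeps CPU transcription fast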
        segments, _ = ASR_MODEL.transcribe(audio_path,
                                         beam_size=1,
                                         best_of=1,
                                         temperature=1.0,
                                         condition_on_previous_text=False,
                                         compression_ratio_threshold=2.4,
                                         log_prob_threshold=-1.0,
                                         no_speech_threshold=0.6)
        
        text = " ".join([segment.text for segment in segments]).strip()
        return text
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"

def preprocess_audio(audio_path):
    """Preprocess audio to reduce memory usage"""
    try:
        # Load and resample audio to 16kHz if needed
        data, sr = sf.read(audio_path)
        if sr != 16000:
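            # Lazy import: resampy is only needed when the input is not already 16 kHz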
            import resampy
            data = resampy.resample(data, sr, 16000)
            sr = 16000
        
        # Convert to mono if stereo
        if len(data.shape) > 1:
            data = data.mean(axis=1)
        
        # Save preprocessed audio
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        sf.write(temp_file.name, data, sr)
        return temp_file.name
    except Exception:
        return audio_path  # Return original if preprocessing fails

def process_audio_file(audio_path, reference_text, text_to_speak, temperature=0.1, repetition_penalty=1.1):
    """Process the audio file and generate speech with the cloned voice"""
    processed_audio = audio_path  # Pre-assign so the error handler below can always reference it
    try:
        # Preprocess audio
        processed_audio = preprocess_audio(audio_path)
        
        # If no reference text provided, transcribe the audio
        if not reference_text.strip():
            reference_text = transcribe_audio(processed_audio)
            if reference_text.startswith("Error"):
                return None, reference_text
        
        # Create speaker from reference audio
        speaker = TTS_INTERFACE.create_speaker(
            processed_audio,
            reference_text
        )
        
        # Generate speech with cloned voice
        output = TTS_INTERFACE.generate(
            text=text_to_speak,
            speaker=speaker,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            max_lenght=4096  # (sic) spelling matches the parameter name in the outetts v0.1 API
        )
        
        # Clean up preprocessed audio if it was created
        if processed_audio != audio_path:
            try:
                os.unlink(processed_audio)
            except OSError:
                pass
        
        # Save output to temporary file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        output.save(temp_file.name)
        return temp_file.name, f"Voice cloning successful!\nReference text used: {reference_text}"
        
    except Exception as e:
        if processed_audio != audio_path:
            try:
                os.unlink(processed_audio)
            except OSError:
                pass
        return None, f"Error: {str(e)}"

print("Initializing models...")
# Initialize models globally
TTS_INTERFACE, ASR_MODEL = initialize_models()
print("Models initialized!")

# Create Gradio interface
with gr.Blocks(title="Voice Cloning with OuteTTS") as demo:
    gr.Markdown("# 🎙️ Optimized Voice Cloning with OuteTTS")
    gr.Markdown("""
    This app uses optimized versions of OuteTTS and Whisper for efficient voice cloning on CPU. 
    Upload a reference audio file, provide the text being spoken in that audio (or leave blank for automatic transcription),
    and enter the new text you want to be spoken in the cloned voice.
    
    Note: For best results, use clear audio with minimal background noise.
    """)
    
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Reference Audio", 
                type="filepath"
            )
            reference_text = gr.Textbox(
                label="Reference Text (leave blank for auto-transcription)",
                placeholder="Leave empty to auto-transcribe or enter the exact text from the reference audio"
            )
            text_to_speak = gr.Textbox(
                label="Text to Speak",
                placeholder="Enter the text you want the cloned voice to speak"
            )
            
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.1,
                    step=0.1,
                    label="Temperature"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty"
                )
            
            submit_btn = gr.Button("Generate Voice", variant="primary")
        
        with gr.Column():
            output_audio = gr.Audio(label="Generated Speech")
            output_message = gr.Textbox(label="Status", max_lines=3)
    
    submit_btn.click(
        fn=process_audio_file,
        inputs=[audio_input, reference_text, text_to_speak, temperature, repetition_penalty],
        outputs=[output_audio, output_message]
    )
    
    gr.Markdown("""
    ### Optimization Notes:
    - Using INT8 quantization for efficient CPU usage
    - Optimized audio preprocessing
    - Cached model loading
    - Memory-efficient inference
    
    ### Tips for best results:
    1. Use clear, high-quality reference audio
    2. Keep reference audio short (5-10 seconds)
    3. Verify auto-transcription accuracy
    4. For best quality, manually input exact reference text
    5. Keep generated text concise
    """)

if __name__ == "__main__":
    demo.launch()