Spaces:
Sleeping
Sleeping
import base64
import io
import logging
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, FileResponse, Response
from pydantic import BaseModel
import numpy as np
import soundfile as sf
import torch
import librosa
import magic  # For MIME type detection
from pydub import AudioSegment

# Import functions from other modules
from asr import transcribe, ASR_LANGUAGES
from tts import synthesize, TTS_LANGUAGES
from lid import identify
from asr import ASR_SAMPLING_RATE
# Module logger; the request handlers report failures through it.
logger = logging.getLogger(__name__)
# Root logging at INFO and above (a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)

# ASGI application object.
# NOTE(review): none of the handler functions in this file carry a visible
# `@app.<method>(...)` route decorator in this copy; confirm the routes are
# registered somewhere, otherwise the app exposes no endpoints.
app = FastAPI(
    title="MMS: Scaling Speech Technology to 1000+ languages",
)
# Request body models
class AudioRequest(BaseModel):
    """JSON body for endpoints that receive a media clip.

    Used by both the transcription and the language-identification handlers.
    """

    # Base64-encoded audio (or WebM video) file contents.
    audio: str
    # Language forwarded to `transcribe`. It is required by the schema even
    # though the language-identification handler never reads it.
    language: str
class TTSRequest(BaseModel):
    """JSON body for the text-to-speech handler.

    All three fields are forwarded unchanged to `tts.synthesize`.
    """

    # Text to be spoken.
    text: str
    # Target language for synthesis.
    language: str
    # Passed through to `synthesize`. Presumably a speaking-rate multiplier;
    # confirm against tts.synthesize.
    speed: float
def detect_mime_type(input_bytes):
    """Return the MIME type libmagic reports for an in-memory payload.

    Args:
        input_bytes: Raw file contents to sniff (e.g. a decoded upload).

    Returns:
        A MIME type string, e.g. one starting with ``audio/`` or ``video/webm``.
    """
    # The module-level helper reuses a cached magic.Magic instance per mode.
    # The libmagic database is therefore loaded once per process instead of
    # once per request, as happened when a fresh magic.Magic was built each call.
    return magic.from_buffer(input_bytes, mime=True)
def extract_audio(input_bytes):
    """Decode an in-memory audio file (or WebM container) into samples.

    Both branches return the same layout that ``soundfile.read`` produces:
    a floating-point array shaped ``(frames,)`` for mono input or
    ``(frames, channels)`` for multi-channel input.

    Args:
        input_bytes: Raw file contents.

    Returns:
        ``(samples, sample_rate)``.

    Raises:
        ValueError: If the detected MIME type is not supported.
    """
    mime_type = detect_mime_type(input_bytes)

    # WebM must be decoded through pydub (ffmpeg); libsndfile cannot read it.
    # Check it before the generic audio/ branch so audio/webm lands here too.
    if mime_type.startswith(("video/webm", "audio/webm")):
        segment = AudioSegment.from_file(io.BytesIO(input_bytes), format="webm")
        samples = np.array(segment.get_array_of_samples(), dtype=np.float64)
        # pydub yields integer PCM with channels interleaved (L, R, L, R, ...).
        # Rescale to [-1, 1) and de-interleave so callers can downmix along
        # axis 1, exactly as they do for soundfile output.
        samples /= segment.max_possible_amplitude
        if segment.channels > 1:
            samples = samples.reshape(-1, segment.channels)
        return samples, segment.frame_rate

    if mime_type.startswith('audio/'):
        return sf.read(io.BytesIO(input_bytes))

    raise ValueError(f"Unsupported MIME type: {mime_type}")
async def transcribe_audio(request: AudioRequest):
    """Transcribe a base64-encoded clip in ``request.language``.

    The clip is decoded, downmixed to mono, cast to float32 and resampled to
    ``ASR_SAMPLING_RATE`` before being passed to ``transcribe``.

    NOTE(review): no route decorator is visible on this handler in this copy
    of the file; confirm it is registered on ``app``.

    Returns:
        JSONResponse ``{"transcription": <result of transcribe()>}``.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or its format
            is unsupported; 500 for any other failure.
    """
    try:
        input_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = extract_audio(input_bytes)
    except ValueError as e:
        # Bad base64 (binascii.Error subclasses ValueError) and the
        # unsupported-MIME error from extract_audio are client mistakes,
        # not server faults.
        logger.warning("Rejected input in transcribe_audio: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid audio input: {str(e)}") from e
    except Exception as e:
        logger.exception("Error in transcribe_audio: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

    try:
        # Convert to mono if stereo
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # Ensure audio_array is float32
        audio_array = audio_array.astype(np.float32)
        # Resample if necessary
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
        result = transcribe(audio_array, request.language)
        return JSONResponse(content={"transcription": result})
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error in transcribe_audio: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
async def synthesize_speech(request: TTSRequest):
    """Synthesize ``request.text`` and return it as a 16-bit WAV attachment.

    The waveform from ``synthesize`` is peak-normalized, converted to int16
    PCM, and written to an in-memory WAV file.

    NOTE(review): no route decorator is visible on this handler in this copy
    of the file; confirm it is registered on ``app``.

    Returns:
        Response with an ``audio/wav`` body and an attachment
        Content-Disposition named ``synthesized_audio.wav``.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        logger.info(
            "Synthesizing speech for text: %s, language: %s, speed: %s",
            request.text, request.language, request.speed,
        )
        result, filtered_text = synthesize(request.text, request.language, request.speed)
        logger.info("Synthesis complete. Filtered text: %s", filtered_text)
        sample_rate, audio = result

        # Take an array view before inspecting .shape/.dtype, so the log line
        # also works if synthesize() returns a plain sequence.
        audio = np.asarray(audio)
        logger.info("Sample rate: %s, Audio shape: %s, Audio dtype: %s", sample_rate, audio.shape, audio.dtype)
        audio = audio.astype(np.float32)

        # Peak-normalize to [-1, 1]. Skip silent or empty output: dividing by
        # a zero peak yields NaNs, and np.max of an empty array raises.
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio / peak

        # Convert to int16 for the WAV file
        pcm = (audio * 32767).astype(np.int16)

        buffer = io.BytesIO()
        sf.write(buffer, pcm, sample_rate, format='wav')

        # FileResponse requires a filesystem path and fails when asked to send
        # a BytesIO, so return the in-memory bytes directly.
        return Response(
            content=buffer.getvalue(),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"},
        )
    except Exception as e:
        logger.exception("Error in synthesize_speech: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
async def identify_language(request: AudioRequest):
    """Run language identification on a base64-encoded clip.

    ``request.language`` is not read by this handler.

    NOTE(review): no route decorator is visible on this handler in this copy
    of the file; confirm it is registered on ``app``.

    Returns:
        JSONResponse ``{"language_identification": <result of identify()>}``.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or its format
            is unsupported; 500 for any other failure.
    """
    try:
        input_bytes = base64.b64decode(request.audio)
        audio_array, _sample_rate = extract_audio(input_bytes)
    except ValueError as e:
        # Bad base64 (binascii.Error subclasses ValueError) and unsupported
        # formats are client errors, not server faults.
        logger.warning("Rejected input in identify_language: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid audio input: {str(e)}") from e
    except Exception as e:
        logger.exception("Error in identify_language: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

    try:
        # NOTE(review): the sample rate is discarded. The samples reach
        # identify() at the file's native rate and channel layout, unlike
        # transcribe_audio, which downmixes and resamples to
        # ASR_SAMPLING_RATE. Confirm lid.identify accepts arbitrary
        # rates/shapes.
        result = identify(audio_array)
        return JSONResponse(content={"language_identification": result})
    except Exception as e:
        logger.exception("Error in identify_language: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
async def get_asr_languages():
    """Return the ``ASR_LANGUAGES`` table imported from ``asr`` as JSON.

    NOTE(review): no route decorator is visible on this handler in this copy
    of the file; confirm it is registered on ``app``.

    Raises:
        HTTPException: 500 if building the JSON response fails (e.g. the
            table is not JSON-serializable).
    """
    try:
        return JSONResponse(content=ASR_LANGUAGES)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error in get_asr_languages: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
async def get_tts_languages():
    """Return the ``TTS_LANGUAGES`` table imported from ``tts`` as JSON.

    NOTE(review): no route decorator is visible on this handler in this copy
    of the file; confirm it is registered on ``app``.

    Raises:
        HTTPException: 500 if building the JSON response fails (e.g. the
            table is not JSON-serializable).
    """
    try:
        return JSONResponse(content=TTS_LANGUAGES)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error in get_tts_languages: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e