File size: 3,645 Bytes
a853668
1f89c40
a853668
 
 
1f89c40
 
 
a853668
 
1f89c40
 
 
 
 
 
 
a853668
 
1f89c40
 
 
a853668
 
1f89c40
 
 
 
 
a853668
1f89c40
 
 
 
 
 
 
 
a853668
1f89c40
 
a853668
 
1f89c40
 
 
 
 
a853668
1f89c40
 
a853668
 
1f89c40
 
 
 
 
 
 
 
 
a853668
1f89c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a853668
1f89c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import nemo.collections.asr as nemo_asr
import shutil
import os
from tempfile import NamedTemporaryFile
from typing import Dict
from pydantic import BaseModel
import uvicorn

# Registry of supported languages: language code (e.g. "hi" for Hindi,
# "bn" for Bengali) -> pretrained model name. get_model() validates requests
# against these keys and passes the value to ASRModel.from_pretrained();
# the /languages/ endpoint returns this mapping to clients.
LANGUAGE_MODELS = {
    "hi": "ai4bharat/indicconformer_stt_hi_hybrid_ctc_rnnt_large",
    "bn": "ai4bharat/indicconformer_stt_bn_hybrid_ctc_rnnt_large",
    "ta": "ai4bharat/indicconformer_stt_ta_hybrid_ctc_rnnt_large",
    # Add more languages and their corresponding models as needed
}


class TranscriptionResponse(BaseModel):
    """Response body returned by the ``POST /transcribe/`` endpoint."""

    # Transcript text produced by the ASR model for the uploaded audio.
    text: str
    # The language code the caller requested, echoed back unchanged.
    language: str


# Application instance; title/description/version feed the OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Indian Languages ASR API",
    description="API for automatic speech recognition in Indian languages",
)

# CORS: any origin may call the API, using any method and any header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests, so any site
# can issue credentialed cross-origin calls. The endpoints visible here read
# no cookies or auth headers, so allow_credentials=False may be intended --
# confirm before changing.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# Lazily-populated cache of loaded ASR models, keyed by language code
# (filled by get_model(); entries are never evicted in this module).
model_cache = {}


def get_model(language: str):
    """
    Get or load the ASR model for the specified language.

    The model is loaded with ``ASRModel.from_pretrained`` on first use and
    stored in ``model_cache``; later calls for the same language reuse it.
    Only successful loads are cached, so a failed load is retried on the
    next request.

    Raises:
    - HTTPException(400) if ``language`` is not a key of LANGUAGE_MODELS.
    - HTTPException(500) if loading the model fails; the original error is
      chained as the cause.
    """
    if language not in LANGUAGE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {language}. Supported languages are: {list(LANGUAGE_MODELS.keys())}",
        )

    model = model_cache.get(language)
    if model is None:
        # NOTE(review): this load is synchronous; when reached from an
        # ``async def`` endpoint it blocks the event loop until it finishes.
        try:
            model = nemo_asr.models.ASRModel.from_pretrained(LANGUAGE_MODELS[language])
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error loading model for language {language}: {e}",
            ) from e
        model_cache[language] = model

    return model


def _transcription_text(transcriptions):
    """Return the first transcript string from a ``model.transcribe`` result, or None if empty."""
    # NOTE(review): NeMo 1.x RNNT / hybrid RNNT-CTC models return a
    # (best_hypotheses, all_hypotheses) tuple instead of a flat list of
    # strings -- unwrap it defensively. Confirm against the pinned NeMo version.
    if isinstance(transcriptions, tuple) and transcriptions:
        transcriptions = transcriptions[0]
    if not transcriptions:
        return None
    first = transcriptions[0]
    # Some NeMo versions yield Hypothesis objects exposing ``.text`` rather
    # than plain strings; plain strings pass through unchanged.
    return getattr(first, "text", first)


@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(
    language: str,
    file: UploadFile = File(...),
):
    """
    Transcribe audio file in the specified Indian language

    Parameters:
    - language: Language code (e.g., 'hi' for Hindi, 'bn' for Bengali)
    - file: Audio file in WAV format (extension checked case-insensitively)

    Returns:
    - Transcription text and language

    Errors:
    - 400 if the file is not a .wav file or the language is unsupported
    - 500 if the model cannot be loaded or transcription fails
    """
    # Validate file format. The upload's filename may be None, and
    # "clip.WAV" should be accepted just like "clip.wav".
    filename = file.filename or ""
    if not filename.lower().endswith(".wav"):
        raise HTTPException(status_code=400, detail="Only WAV files are supported")

    # Get the appropriate model (raises its own 400/500 HTTPException).
    model = get_model(language)

    # NOTE(review): model.transcribe is synchronous and runs on the event loop
    # here, so it blocks other requests while it executes.
    temp_path = None
    try:
        # Write the upload to a temp file and close the handle before
        # transcribing/deleting it: Windows refuses to unlink an open file.
        with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_path = temp_file.name
            shutil.copyfileobj(file.file, temp_file)

        # Perform transcription
        transcriptions = model.transcribe([temp_path])
        text = _transcription_text(transcriptions)
        if text is None:
            raise HTTPException(status_code=500, detail="Transcription failed")

        return TranscriptionResponse(text=text, language=language)

    except HTTPException:
        # Propagate our own errors unchanged rather than re-wrapping them
        # under "Error during transcription".
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error during transcription: {e}"
        ) from e
    finally:
        # Clean up temporary file (absent if its creation itself failed).
        if temp_path is not None:
            os.unlink(temp_path)


@app.get("/languages/")
async def get_supported_languages() -> Dict[str, str]:
    """
    Get the supported language codes and the model name used for each
    """
    # Return a shallow copy so no caller can mutate the shared registry that
    # get_model() validates requests against.
    return dict(LANGUAGE_MODELS)


@app.get("/health/")
async def health_check():
    """
    Report service health with a constant payload (does not check model availability)
    """
    status = "healthy"
    return {"status": status}


if __name__ == "__main__":
    # Serve the app directly with uvicorn on all network interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")