File size: 1,898 Bytes
544e017
 
 
7a0f405
544e017
 
bf18876
 
544e017
3da96bb
64601f3
7a0f405
544e017
 
 
7a0f405
544e017
 
 
 
 
7a0f405
 
 
 
 
 
 
 
339c131
7a0f405
 
 
 
 
3da96bb
544e017
 
3da96bb
 
544e017
 
 
3da96bb
544e017
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torchaudio
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import numpy as np

# Hugging Face Hub repository shared by the processor and the model below.
MODEL_ID = "ixxan/whisper-small-common-voice-ug"

# Processor (feature extraction + decoding) and the seq2seq speech model,
# both loaded from the same checkpoint so their configurations agree.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID)

# Sampling rate the feature extractor expects; audio at any other rate is
# resampled to this value before feature extraction.
target_sr = processor.feature_extractor.sampling_rate

def _mic_to_mono_float32(samples) -> np.ndarray:
    """Convert a microphone sample array into a 1-D float32 waveform.

    - Signed-integer PCM is scaled by the dtype's full-scale magnitude
      (32768 for int16, the constant previously hard-coded).
    - Floating-point input is only cast to float32, because dividing it by
      32768 would make it near-silent.
    - Any other dtype keeps the original ``/ 32768.0`` scaling.

    A 2-D array is treated as ``(samples, channels)`` (Gradio's numpy audio
    layout) and averaged to mono.
    """
    samples = np.asarray(samples)
    dtype = samples.dtype
    if np.issubdtype(dtype, np.floating):
        waveform = samples.astype(np.float32)
    elif np.issubdtype(dtype, np.signedinteger):
        full_scale = np.float32(-np.iinfo(dtype).min)  # 32768 for int16
        waveform = samples.astype(np.float32) / full_scale
    else:
        # Unrecognised sample format: fall back to the original scaling.
        waveform = (samples / 32768.0).astype(np.float32)
    if waveform.ndim == 2:
        waveform = waveform.mean(axis=1, dtype=np.float32)
    return waveform


def transcribe(audio_data, *, max_length: int = 225) -> str:
    """
    Transcribes audio to text using the Whisper model for Uyghur.

    Args:
    - audio_data: Gradio audio input. Accepted forms:
      - a ``(sampling_rate, samples)`` tuple (microphone / numpy mode);
      - a path string to an audio file (file upload / filepath mode).
    - max_length: Maximum generated sequence length passed to
      ``model.generate`` (keyword-only; defaults to the previous fixed value
      of 225).

    Returns:
    - str: The transcription of the audio. Returns an ``<<ERROR: ...>>``
      string instead when the input is empty or of an unsupported type.
    """
    if not audio_data:
        return "<<ERROR: Empty Audio Input>>"

    if isinstance(audio_data, tuple):
        # microphone: (sampling_rate, numpy samples)
        sampling_rate, samples = audio_data
        # Convert to a torch tensor. The torchaudio resampler below does not
        # accept NumPy arrays, so mic input used to crash whenever resampling
        # was needed.
        waveform = torch.from_numpy(_mic_to_mono_float32(samples))
    elif isinstance(audio_data, str):
        # file upload: torchaudio.load returns (channels, time); average the
        # channels so stereo files are not fed to the processor as a batch.
        waveform, sampling_rate = torchaudio.load(audio_data)
        waveform = waveform.mean(dim=0)
    else:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))

    # A zero-length recording has nothing to transcribe.
    if waveform.numel() == 0:
        return "<<ERROR: Empty Audio Input>>"

    # Resample if needed
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
        waveform = resampler(waveform)

    # Preprocess the 1-D mono waveform into model input features.
    inputs = processor(waveform.numpy(), sampling_rate=target_sr, return_tensors="pt")

    # Move model and inputs to GPU if available (model.to is a no-op when the
    # model is already on that device).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate transcription token ids without tracking gradients.
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_features"], max_length=max_length)

    # Decode the single sequence in the batch to text.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]