File size: 1,511 Bytes
807daef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea38473
 
807daef
 
ea38473
807daef
 
 
 
 
 
 
 
 
 
 
 
 
ea38473
807daef
 
ea38473
807daef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import torch
import librosa

model_id = "facebook/mms-lid-1024"

# Feature extractor and sequence-classification model for the MMS LID checkpoint.
processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)

# Sampling rate (Hz) that input audio is resampled to before feature extraction.
LID_SAMPLING_RATE = 16_000
# Minimum top-1 probability required before predictions are reported.
LID_THRESHOLD = 0.33

# ISO code -> human-readable language name.
# Each line holds "<iso><whitespace><name>". Only the first whitespace run
# separates the fields, so names that contain spaces stay intact and both
# space- and tab-separated files parse.
LID_LANGUAGES = {}
with open("data/lid/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line
        iso, name = line.split(maxsplit=1)
        LID_LANGUAGES[iso] = name.strip()

def identify_language(audio=None, top_k=5):
    """Predict the spoken language of an audio recording with the MMS LID model.

    Args:
        audio: Input accepted by ``librosa.load`` (e.g. a file path) for the
            recording to classify, or ``None`` when no audio was provided.
        top_k: Number of highest-probability languages to report. Clamped to
            the range ``[1, number of model labels]``. Defaults to 5.

    Returns:
        A dict mapping language name to probability for the ``top_k`` best
        candidates, in descending probability order. If a predicted ISO code is
        missing from ``LID_LANGUAGES``, the code itself is used as the key.
        Instead of a dict, an error string is returned when ``audio`` is None.
        A low-confidence notice string is returned when the best probability
        is below ``LID_THRESHOLD``.
    """
    if audio is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Resample to the model's expected rate and downmix to a single channel.
    audio_samples = librosa.load(audio, sr=LID_SAMPLING_RATE, mono=True)[0]
    inputs = processor(audio_samples, sampling_rate=LID_SAMPLING_RATE, return_tensors="pt")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # One clip means a batch of one. Take row 0 explicitly, because squeeze()
    # would collapse a single-label output to a 0-d tensor.
    log_probs = torch.log_softmax(logits[0], dim=-1)
    k = max(1, min(top_k, log_probs.shape[-1]))
    top_log_probs, top_indices = torch.topk(log_probs, k, dim=-1)
    probs = torch.exp(top_log_probs).cpu().tolist()
    label_ids = top_indices.cpu().tolist()
    iso2score = {model.config.id2label[int(i)]: p for p, i in zip(probs, label_ids)}

    if max(iso2score.values()) < LID_THRESHOLD:
        return "Low confidence in the language identification predictions. Output is not shown!"

    # Fall back to the raw ISO code rather than raising KeyError when the
    # language table has no entry for a model label.
    return {LID_LANGUAGES.get(iso, iso): score for iso, score in iso2score.items()}