File size: 3,041 Bytes
3023ee7
 
 
 
 
 
 
a6620dd
3023ee7
 
 
 
 
a6620dd
3023ee7
 
 
 
 
a6620dd
 
 
 
 
 
3023ee7
 
 
 
 
 
 
 
a6620dd
679e261
49e5bfc
 
 
 
3023ee7
49e5bfc
3023ee7
 
 
 
 
 
49e5bfc
3023ee7
 
49e5bfc
3023ee7
 
97bf953
 
 
 
 
3023ee7
97bf953
 
 
 
 
 
 
 
 
 
 
 
3023ee7
 
d4ed546
3023ee7
 
 
 
 
 
 
d4ed546
3023ee7
 
 
 
 
 
 
 
 
 
97bf953
3023ee7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch

import gradio as gr
import yt_dlp as youtube_dl
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from faster_whisper import WhisperModel


import tempfile
import os

# Whisper checkpoint size handed to faster-whisper's WhisperModel.
MODEL_NAME = "medium"
# Batch size for model inference.
BATCH_SIZE = 8
# Presumably the intended upload size limit in megabytes — TODO confirm it is enforced.
FILE_LIMIT_MB = 1000

# Inference device name. faster-whisper's WhisperModel takes a device *string*
# ("cpu", "cuda", ...), so use a name rather than a bare CUDA index like 0.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Many-to-many MBart-50 checkpoint used to translate the English transcript.
MT_MODEL_NAME = "sanjitaa/mbart-many-to-many"
model = MBartForConditionalGeneration.from_pretrained(MT_MODEL_NAME)
tokenizer = MBart50TokenizerFast.from_pretrained(MT_MODEL_NAME)

# Lazily-created faster-whisper model shared across requests, so the large
# checkpoint is loaded once instead of on every submission.
_whisper_model = None


def _get_whisper_model():
    """Return the shared faster-whisper model, loading it on first use."""
    global _whisper_model
    if _whisper_model is None:
        # faster-whisper expects a device name string; map any CUDA
        # selector (e.g. a bare index such as 0) to "cuda".
        whisper_device = "cpu" if device == "cpu" else "cuda"
        _whisper_model = WhisperModel(MODEL_NAME, device=whisper_device, compute_type="int8")
    return _whisper_model


def translate(inputs, task):
    """Transcribe an audio file to English with Whisper, then translate it to French.

    Args:
        inputs: Path to the uploaded audio file (the Gradio Audio input uses
            ``type="filepath"``), or ``None`` when nothing was submitted.
        task: Value of the "Task" radio button. Currently unused: Whisper is
            always run with ``task="translate"`` (speech -> English text).

    Returns:
        The French translation as a single string ("" if no speech was found).

    Raises:
        gr.Error: If no audio file was submitted.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    # Whisper's "translate" task emits English text, hence en_XX as the MBart
    # source language below. `segments` is lazy; consuming it runs decoding.
    segments, _ = _get_whisper_model().transcribe(inputs, task="translate")
    texts = [segment.text.strip() for segment in segments]
    texts = [text for text in texts if text]
    if not texts:
        return ""

    # The source language must be set *before* tokenizing: it determines the
    # language-code special token the tokenizer adds to the encoded input.
    tokenizer.src_lang = "en_XX"
    french_bos_id = tokenizer.lang_code_to_id["fr_XX"]

    # Translate segment-sized pieces in batches instead of one giant string.
    # A whole long-form transcript would exceed MBart's maximum input length,
    # and its generation length limit would cut the output short.
    translations = []
    for start in range(0, len(texts), BATCH_SIZE):
        batch = texts[start:start + BATCH_SIZE]
        encoded = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        generated_tokens = model.generate(**encoded, forced_bos_token_id=french_bos_id)
        translations.extend(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))

    # Return a plain string: the "text" output would otherwise show a list repr.
    return " ".join(translations)
       
# Top-level Blocks app that hosts the tabbed UI.
demo = gr.Blocks()

# Upload-an-audio-file interface wired to `translate`.
file_transcribe = gr.Interface(
    fn=translate,
    inputs=[
        gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file"),
        # The default must be one of the radio's choices ("transcribe" was not).
        gr.inputs.Radio(["translate"], label="Task", default="translate"),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    title="Whisper Medium: Transcribe Audio",
    description=(
        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length."
    ),
    allow_flagging="never",
)


# Render the single interface as a tab inside the top-level Blocks app.
with demo:
    gr.TabbedInterface([file_transcribe], ["Audio file"])

# Start the web server with request queuing enabled.
demo.launch(enable_queue=True)