Ahmed107 committed
Commit f60e2c1 · verified · 1 Parent(s): 1a6b12f

Update app.py

Files changed (1):
  1. app.py +44 -24
app.py CHANGED
@@ -1,30 +1,50 @@
- from transformers import pipeline,AutoConfig,AutoModelForAudioClassification,AutoFeatureExtractor
-
- model_id = "Ahmed107/whisper-tiny-finetuned-eos"
-
- # define mappings as dictionaries
- id2label = {"0": "NOT_EOS", "1": "EOS"}
- label2id = {"NOT_EOS": "0", "EOS": "1"}
-
- # define config
- config = AutoConfig.from_pretrained(model_id, label2id=label2id, id2label=id2label)
- model = AutoModelForAudioClassification.from_pretrained(model_id, config = config)
- feature_extractor = AutoFeatureExtractor.from_pretrained(
-     model_id,
- )
- pipe = pipeline("audio-classification", model=model,feature_extractor=feature_extractor)
-
- def classify_audio(filepath):
-     preds = pipe(filepath)
-     print(preds)
-     outputs = {}
-     for p in preds:
-         outputs[p["label"]] = p["score"]
-     return outputs
-
- import gradio as gr
-
- demo = gr.Interface(
-     fn=classify_audio, inputs=gr.Audio(type="filepath"), outputs=gr.Label()
- )
- demo.launch(debug=True)
+ import gradio as gr
+ import torchaudio
+ from transformers import pipeline
+ from datasets import load_dataset, Audio
+
+ # Load your model
+ classifier = pipeline("audio-classification", model="Ahmed107/whisper-tiny-finetuned-eos")
+
+ # Function to resample audio to 16kHz
+ def resample_audio(audio_file, target_sampling_rate=16000):
+     waveform, original_sample_rate = torchaudio.load(audio_file)
+     if original_sample_rate != target_sampling_rate:
+         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sampling_rate)
+         waveform = resampler(waveform)
+     return waveform.squeeze().numpy(), target_sampling_rate
+
+ # Define the prediction function
+ def classify_audio(audio_file):
+     # Resample the audio to 16kHz
+     resampled_audio, _ = resample_audio(audio_file)
+
+     # Classify the audio
+     prediction = classifier(resampled_audio)
+
+     # Return predictions as a dictionary
+     return {entry['label']: entry['score'] for entry in prediction}
+
+ # Define Gradio interface
+ def demo():
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("## Eos")
+
+         # Input Audio
+         with gr.Row():
+             audio_input = gr.Audio(type="filepath", label="Input Audio")
+
+         # Output Labels
+         with gr.Row():
+             label_output = gr.Label(label="Prediction")
+
+         # Predict Button
+         classify_btn = gr.Button("Classify")
+
+         # Define the interaction
+         classify_btn.click(fn=classify_audio, inputs=audio_input, outputs=label_output)
+
+     return demo
+
+ # Launch the demo
+ demo().launch()
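One caveat with the new `resample_audio`: `torchaudio.load` returns a `(channels, samples)` tensor, so `waveform.squeeze()` stays two-dimensional for stereo uploads, which the audio-classification pipeline cannot consume directly. A minimal mono-safe variant (a sketch, not part of this commit; the helper name `resample_audio_mono` is hypothetical) could look like:

```python
import torchaudio

# Sketch: mono-safe version of resample_audio (hypothetical helper, not in the commit).
def resample_audio_mono(audio_file, target_sampling_rate=16000):
    waveform, original_sample_rate = torchaudio.load(audio_file)  # (channels, samples)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # downmix multichannel audio to mono
    if original_sample_rate != target_sampling_rate:
        waveform = torchaudio.transforms.Resample(
            orig_freq=original_sample_rate, new_freq=target_sampling_rate
        )(waveform)
    return waveform.squeeze(0).numpy(), target_sampling_rate  # 1-D array at 16 kHz
```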