sepal commited on
Commit
4315919
1 Parent(s): decd378

Implement basic diarization

Browse files
Files changed (2) hide show
  1. app.py +42 -4
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,8 +1,46 @@
 
 
1
  import gradio as gr
 
 
 
 
 
2
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
6
 
7
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
8
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
  import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ from pyannote.audio import Pipeline
7
+ from pydub import AudioSegment
8
+ from mimetypes import MimeTypes
9
 
10
+ load_dotenv()
11
 
12
+ hg_token = os.getenv("HG_ACCESS_TOKEN")
 
13
 
14
+ if hg_token != None:
15
+ pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hg_token)
16
+ else:
17
+ print('''No hugging face access token set.
18
+ You need to set it via an .env or environment variable HG_ACCESS_TOKEN''')
19
+ exit(1)
20
+
21
+
22
+ def diarization(audio_file: tuple[int, np.array]) -> np.array:
23
+ """
24
+ Receives a tuple with the sample rate and audio data and returns the
25
+ a numpy array containing the audio segments, track names and speakers for
26
+ each segment.
27
+ """
28
+ waveform = torch.tensor(audio_file[1].astype(np.float32, order='C')).reshape(1,-1)
29
+ audio_data = {
30
+ "waveform": waveform,
31
+ "sample_rate": audio_file[0]
32
+ }
33
+
34
+ diarization = pipeline(audio_data)
35
+
36
+ return np.array(list(diarization.itertracks(yield_label=True)))
37
+
38
+
39
+
40
+ demo = gr.Interface(
41
+ fn=diarization,
42
+ inputs=gr.Audio(type="numpy"),
43
+ outputs="text",
44
+ )
45
+
46
+ demo.launch()
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  gradio==3.21.0
2
  openai-whisper==20230314
 
3
  python-dotenv==1.0.0
4
  pyannote.audio==2.1.1
5
  torch==1.13.1
 
1
  gradio==3.21.0
2
  openai-whisper==20230314
3
+ pydub==0.25.1
4
  python-dotenv==1.0.0
5
  pyannote.audio==2.1.1
6
  torch==1.13.1