rishiraj committed on
Commit
a647c50
1 Parent(s): 698d07a

add audio extractor

__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,146 @@
+ import spaces
+ import gradio as gr
+ import base64
+ import librosa
+ from extractors.asrdiarization.asr_extractor import ASRExtractorConfig, ASRExtractor
+ from indexify_extractor_sdk import Content
+
+ MAX_AUDIO_MINUTES = 60  # won't try to transcribe if longer than this
+
+ asr_extractor = ASRExtractor()
+
+ def check_audio(audio_filepath):
+     """
+     Do not convert the audio; just raise an error if it is too long.
+     """
+     data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+     duration = librosa.get_duration(y=data, sr=sr)
+
+     if duration / 60.0 > MAX_AUDIO_MINUTES:
+         raise gr.Error(
+             f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
+             "If you wish, you may trim the audio using the Audio viewer in Step 1 "
+             "(click on the scissors icon to start trimming audio)."
+         )
+
+     return audio_filepath
+
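+ # transcribe() base64-encodes the uploaded file, wraps it in an Indexify Content
+ # object (the input format ASRExtractor.extract expects), runs the extractor, and
+ # decodes the first returned Content back into text for display.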
+ @spaces.GPU
+ def transcribe(audio_filepath, task, batch_size, chunk_length_s, sampling_rate, language, num_speakers, min_speakers, max_speakers, assisted):
+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
+
+     audio_filepath = check_audio(audio_filepath)
+
+     with open(audio_filepath, "rb") as f:
+         converted_audio_filepath = base64.b64encode(f.read()).decode("utf-8")
+
+     content = Content(content_type="audio/mpeg", data=converted_audio_filepath)
+     config = ASRExtractorConfig(task=task, batch_size=batch_size, chunk_length_s=chunk_length_s, sampling_rate=sampling_rate, language=language, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers, assisted=assisted)
+
+     result = asr_extractor.extract(content, config)
+     text_content = next(item.data.decode('utf-8') for item in result)
+
+     return text_content
+
+ with gr.Blocks(
+     title="ASR + diarization + speculative decoding with Indexify"
+ ) as audio_demo:
+
+     gr.HTML("<h1 style='text-align: center'>ASR + diarization + speculative decoding with Indexify</h1>")
+     gr.HTML("<p style='text-align: center'>Indexify is a scalable, real-time and continuous indexing and structured extraction engine for unstructured data, for building generative AI applications</p>")
+     gr.HTML("<h3 style='text-align: center'>If you like this demo, please ⭐ Star us on <a href='https://github.com/tensorlakeai/indexify' target='_blank'>GitHub</a>!</h3>")
+
+     with gr.Row():
+         with gr.Column():
+             gr.HTML(
+                 "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
+
+                 "<p style='color: #A0A0A0;'>Use this demo for audio files up to 60 minutes long. "
+                 "You can transcribe longer files and try various other extractors locally with "
+                 "<a href='https://getindexify.io/'>Indexify</a>.</p>"
+             )
+
+             audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
+
+             gr.HTML("<p><b>Step 2:</b> Choose the parameters or leave the defaults.</p>")
+
+             task = gr.Dropdown(
+                 choices=["transcribe", "translate"],
+                 value="transcribe",
+                 info="passed to the ASR pipeline",
+                 label="Task:"
+             )
+
+         with gr.Column():
+             batch_size = gr.Number(
+                 value=24,
+                 info="for assisted generation the `batch_size` must be set to 1",
+                 label="Batch Size:"
+             )
+             chunk_length_s = gr.Number(
+                 value=30,
+                 info="passed to the ASR pipeline",
+                 label="Chunk Length:"
+             )
+             sampling_rate = gr.Number(
+                 value=16000,
+                 info="`sampling_rate` indicates the sampling rate of the audio to process and is used for preprocessing",
+                 label="Sampling Rate:"
+             )
+             language = gr.Dropdown(
+                 choices=['english', 'chinese', 'german', 'spanish', 'russian', 'korean', 'french', 'japanese', 'portuguese', 'turkish', 'polish', 'catalan', 'dutch', 'arabic', 'swedish', 'italian', 'indonesian', 'hindi', 'finnish', 'vietnamese', 'hebrew', 'ukrainian', 'greek', 'malay', 'czech', 'romanian', 'danish', 'hungarian', 'tamil', 'norwegian', 'thai', 'urdu', 'croatian', 'bulgarian', 'lithuanian', 'latin', 'maori', 'malayalam', 'welsh', 'slovak', 'telugu', 'persian', 'latvian', 'bengali', 'serbian', 'azerbaijani', 'slovenian', 'kannada', 'estonian', 'macedonian', 'breton', 'basque', 'icelandic', 'armenian', 'nepali', 'mongolian', 'bosnian', 'kazakh', 'albanian', 'swahili', 'galician', 'marathi', 'punjabi', 'sinhala', 'khmer', 'shona', 'yoruba', 'somali', 'afrikaans', 'occitan', 'georgian', 'belarusian', 'tajik', 'sindhi', 'gujarati', 'amharic', 'yiddish', 'lao', 'uzbek', 'faroese', 'haitian creole', 'pashto', 'turkmen', 'nynorsk', 'maltese', 'sanskrit', 'luxembourgish', 'myanmar', 'tibetan', 'tagalog', 'malagasy', 'assamese', 'tatar', 'hawaiian', 'lingala', 'hausa', 'bashkir', 'javanese', 'sundanese', 'cantonese', 'burmese', 'valencian', 'flemish', 'haitian', 'letzeburgesch', 'pushto', 'panjabi', 'moldavian', 'moldovan', 'sinhalese', 'castilian', 'mandarin'],
+                 info="passed to the ASR pipeline",
+                 label="Language:"
+             )
+             num_speakers = gr.Number(
+                 info="passed to the diarization pipeline",
+                 label="Number of Speakers:"
+             )
+             min_speakers = gr.Number(
+                 info="passed to the diarization pipeline",
+                 label="Minimum Speakers:"
+             )
+             max_speakers = gr.Number(
+                 info="passed to the diarization pipeline",
+                 label="Maximum Speakers:"
+             )
+             assisted = gr.Checkbox(
+                 value=False,
+                 info="the `assisted` flag tells the pipeline whether to use speculative decoding",
+                 label="Assisted?",
+             )
+
+         with gr.Column():
+
+             gr.HTML("<p><b>Step 3:</b> Run the extractor.</p>")
+
+             go_button = gr.Button(
+                 value="Run extractor",
+                 variant="primary",  # make "primary" so it stands out (default is "secondary")
+             )
+
+             model_output_text_box = gr.Textbox(
+                 label="Extractor Output",
+                 elem_id="model_output_text_box",
+             )
+
+     with gr.Row():
+
+         gr.HTML(
+             "<p style='text-align: center'>"
+             "Developed with 🫶 by <a href='https://getindexify.io/' target='_blank'>Indexify</a> | "
+             "a <a href='https://www.tensorlake.ai/' target='_blank'>Tensorlake</a> product"
+             "</p>"
+         )
+
+     go_button.click(
+         fn=transcribe,
+         inputs=[audio_file, task, batch_size, chunk_length_s, sampling_rate, language, num_speakers, min_speakers, max_speakers, assisted],
+         outputs=[model_output_text_box]
+     )
+
+ demo = gr.TabbedInterface([audio_demo], ["Audio Extraction"], theme=gr.themes.Soft())
+
+ demo.queue()
+ demo.launch()
extractors/__init__.py ADDED
File without changes
extractors/asrdiarization/__init__.py ADDED
File without changes
extractors/asrdiarization/asr_extractor.py ADDED
@@ -0,0 +1,134 @@
+ import logging
+ import torch
+ import base64
+ import os
+
+ from indexify_extractor_sdk import Content, Extractor, Feature
+ from pyannote.audio import Pipeline
+ from transformers import pipeline, AutoModelForCausalLM
+ from .diarization_utils import diarize
+ from huggingface_hub import HfApi
+ from starlette.exceptions import HTTPException
+
+ from pydantic import BaseModel
+ from pydantic_settings import BaseSettings
+ from typing import Optional, Literal, List, Union
+
+ logger = logging.getLogger(__name__)
+ token = os.getenv('HF_TOKEN')
+
+ class ModelSettings(BaseSettings):
+     asr_model: str = "openai/whisper-large-v3"
+     assistant_model: Optional[str] = "distil-whisper/distil-large-v3"
+     diarization_model: Optional[str] = "pyannote/speaker-diarization-3.1"
+     hf_token: Optional[str] = token
+
+ model_settings = ModelSettings()
+
+ class ASRExtractorConfig(BaseModel):
+     task: Literal["transcribe", "translate"] = "transcribe"
+     batch_size: int = 24
+     assisted: bool = False
+     chunk_length_s: int = 30
+     sampling_rate: int = 16000
+     language: Optional[str] = None
+     num_speakers: Optional[int] = None
+     min_speakers: Optional[int] = None
+     max_speakers: Optional[int] = None
+
+ class ASRExtractor(Extractor):
+     name = "tensorlake/asrdiarization"
+     description = "Powerful ASR + diarization + speculative decoding."
+     system_dependencies = ["ffmpeg"]
+     input_mime_types = ["audio", "audio/mpeg"]
+
+     def __init__(self):
+         super(ASRExtractor, self).__init__()
+
+         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+         logger.info(f"Using device: {device.type}")
+         torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
+
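+         # optional draft model used as `assistant_model` for speculative decoding
+         # (only passed to the ASR pipeline when the `assisted` flag is set)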
+         self.assistant_model = AutoModelForCausalLM.from_pretrained(
+             model_settings.assistant_model,
+             torch_dtype=torch_dtype,
+             low_cpu_mem_usage=True,
+             use_safetensors=True
+         ) if model_settings.assistant_model else None
+
+         if self.assistant_model:
+             self.assistant_model.to(device)
+
+         self.asr_pipeline = pipeline(
+             "automatic-speech-recognition",
+             model=model_settings.asr_model,
+             torch_dtype=torch_dtype,
+             device=device
+         )
+
+         if model_settings.diarization_model:
+             # the diarization pipeline doesn't raise if the token is missing, so validate it up front
+             HfApi().whoami(model_settings.hf_token)
+             self.diarization_pipeline = Pipeline.from_pretrained(
+                 checkpoint_path=model_settings.diarization_model,
+                 use_auth_token=model_settings.hf_token,
+             )
+             self.diarization_pipeline.to(device)
+         else:
+             self.diarization_pipeline = None
+
+     def extract(self, content: Content, params: ASRExtractorConfig) -> List[Union[Feature, Content]]:
+         file = base64.b64decode(content.data)
+         logger.info(f"inference params: {params}")
+
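+         # speculative decoding: pass the draft model only when `assisted` is requested
+         # (in that case the batch size should be 1, as noted in the UI)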
+         generate_kwargs = {
+             "task": params.task,
+             "language": params.language,
+             "assistant_model": self.assistant_model if params.assisted else None
+         }
+
+         try:
+             asr_outputs = self.asr_pipeline(
+                 file,
+                 chunk_length_s=params.chunk_length_s,
+                 batch_size=params.batch_size,
+                 generate_kwargs=generate_kwargs,
+                 return_timestamps=True,
+             )
+         except RuntimeError as e:
+             logger.error(f"ASR inference error: {str(e)}")
+             raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
+         except Exception as e:
+             logger.error(f"Unknown error during ASR inference: {str(e)}")
+             raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}")
+
+         if self.diarization_pipeline:
+             try:
+                 transcript = diarize(self.diarization_pipeline, file, params, asr_outputs)
+             except RuntimeError as e:
+                 logger.error(f"Diarization inference error: {str(e)}")
+                 raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
+             except Exception as e:
+                 logger.error(f"Unknown error during diarization: {str(e)}")
+                 raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}")
+         else:
+             transcript = []
+
+         feature = Feature.metadata(value={"chunks": asr_outputs["chunks"], "text": asr_outputs["text"]})
+         return [Content.from_text(str(transcript), features=[feature])]
+
+     def sample_input(self) -> Content:
+         filepath = "sample.mp3"
+         with open(filepath, 'rb') as f:
+             audio_encoded = base64.b64encode(f.read()).decode("utf-8")
+         return Content(content_type="audio/mpeg", data=audio_encoded)
+
+ if __name__ == "__main__":
+     filepath = "sample.mp3"
+     with open(filepath, 'rb') as f:
+         audio_encoded = base64.b64encode(f.read()).decode("utf-8")
+     data = Content(content_type="audio/mpeg", data=audio_encoded)
+     params = ASRExtractorConfig(batch_size=24)
+     extractor = ASRExtractor()
+     results = extractor.extract(data, params=params)
+     print(results)
extractors/asrdiarization/diarization_utils.py ADDED
@@ -0,0 +1,141 @@
+ import torch
+ import numpy as np
+ from torchaudio import functional as F
+ from transformers.pipelines.audio_utils import ffmpeg_read
+ from starlette.exceptions import HTTPException
+ import sys
+
+ # Code from insanely-fast-whisper:
+ # https://github.com/Vaibhavs10/insanely-fast-whisper
+
+ import logging
+ logger = logging.getLogger(__name__)
+
+ def preprocess_inputs(inputs, sampling_rate):
+     inputs = ffmpeg_read(inputs, sampling_rate)
+
+     if sampling_rate != 16000:
+         inputs = F.resample(
+             torch.from_numpy(inputs), sampling_rate, 16000
+         ).numpy()
+
+     if len(inputs.shape) != 1:
+         logger.error(f"Diarization pipeline expects single-channel audio, received {inputs.shape}")
+         raise HTTPException(
+             status_code=400,
+             detail=f"Diarization pipeline expects single-channel audio, received {inputs.shape}"
+         )
+
+     # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
+     diarizer_inputs = torch.from_numpy(inputs).float()
+     diarizer_inputs = diarizer_inputs.unsqueeze(0)
+
+     return inputs, diarizer_inputs
+
+
+ def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
+     diarization = diarization_pipeline(
+         {"waveform": diarizer_inputs, "sample_rate": parameters.sampling_rate},
+         num_speakers=parameters.num_speakers,
+         min_speakers=parameters.min_speakers,
+         max_speakers=parameters.max_speakers,
+     )
+
+     segments = []
+     for segment, track, label in diarization.itertracks(yield_label=True):
+         segments.append(
+             {
+                 "segment": {"start": segment.start, "end": segment.end},
+                 "track": track,
+                 "label": label,
+             }
+         )
+
+     # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
+     # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
+     new_segments = []
+     prev_segment = cur_segment = segments[0]
+
+     for i in range(1, len(segments)):
+         cur_segment = segments[i]
+
+         # check if we have changed speaker ("label")
+         if cur_segment["label"] != prev_segment["label"] and i < len(segments):
+             # add the start/end times for the super-segment to the new list
+             new_segments.append(
+                 {
+                     "segment": {
+                         "start": prev_segment["segment"]["start"],
+                         "end": cur_segment["segment"]["start"],
+                     },
+                     "speaker": prev_segment["label"],
+                 }
+             )
+             prev_segment = segments[i]
+
+     # add the last segment(s) if there was no speaker change
+     new_segments.append(
+         {
+             "segment": {
+                 "start": prev_segment["segment"]["start"],
+                 "end": cur_segment["segment"]["end"],
+             },
+             "speaker": prev_segment["label"],
+         }
+     )
+
+     return new_segments
+
+
+ def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
+     # get the end timestamps for each chunk from the ASR output
+     end_timestamps = np.array(
+         [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
+     segmented_preds = []
+
+     # align the diarizer timestamps and the ASR timestamps
+     for segment in new_segments:
+         # get the diarizer end timestamp
+         end_time = segment["segment"]["end"]
+         # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
+         upto_idx = np.argmin(np.abs(end_timestamps - end_time))
+
+         if group_by_speaker:
+             segmented_preds.append(
+                 {
+                     "speaker": segment["speaker"],
+                     "text": "".join(
+                         [chunk["text"] for chunk in transcript[: upto_idx + 1]]
+                     ),
+                     "timestamp": (
+                         transcript[0]["timestamp"][0],
+                         transcript[upto_idx]["timestamp"][1],
+                     ),
+                 }
+             )
+         else:
+             for i in range(upto_idx + 1):
+                 segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
+
+         # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
+         transcript = transcript[upto_idx + 1:]
+         end_timestamps = end_timestamps[upto_idx + 1:]
+
+         if len(end_timestamps) == 0:
+             break
+
+     return segmented_preds
+
+
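+ # Run the diarization pipeline on the (resampled) audio and attach a speaker label
+ # to each ASR chunk by aligning diarizer segment end times with ASR chunk end times.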
+ def diarize(diarization_pipeline, file, parameters, asr_outputs):
+     _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
+
+     segments = diarize_audio(
+         diarizer_inputs,
+         diarization_pipeline,
+         parameters
+     )
+
+     return post_process_segments_and_transcripts(
+         segments, asr_outputs["chunks"], group_by_speaker=False
+     )
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ indexify-extractor-sdk
+ accelerate==0.27.2
+ pyannote-audio==3.1.1
+ transformers==4.40.2
+ numpy==1.26.4
+ torchaudio==2.2.0
+ pydantic==2.6.3
+ pydantic-settings==2.2.1
+ librosa==0.10.2
+ torch==2.2.0
+ bitsandbytes
+ peft