Manjot Singh committed on
Commit 7123f83 · 1 Parent(s): 1651ea1

added translation and model selection

Files changed (3)
  1. app.py +15 -7
  2. audio_processing.py +108 -137
  3. requirements.txt +3 -1
app.py CHANGED
@@ -1,24 +1,32 @@
 import gradio as gr
 from audio_processing import process_audio, print_results
-def transcribe_audio(audio_file):
-    language_segments, final_segments = process_audio(audio_file)
+
+def transcribe_audio(audio_file, translate, model_size):
+    language_segments, final_segments = process_audio(audio_file, translate=translate, model_size=model_size)
 
     output = "Detected language changes:\n\n"
     for segment in language_segments:
         output += f"Language: {segment['language']}\n"
         output += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"
 
-    output += "Transcription with language detection and speaker diarization:\n\n"
+    output += f"Transcription with language detection and speaker diarization (using {model_size} model):\n\n"
    for segment in final_segments:
-        output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) Speaker {segment['speaker']}: {segment['text']}\n"
-        # output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}): {segment['text']}\n"
+        output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:\n"
+        output += f"Original: {segment['text']}\n"
+        if translate:
+            output += f"Translated: {segment['translated']}\n"
+        output += "\n"
     return output
 
 iface = gr.Interface(
     fn=transcribe_audio,
-    inputs=gr.Audio(type="filepath"),
+    inputs=[
+        gr.Audio(type="filepath"),
+        gr.Checkbox(label="Enable Translation"),
+        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large","large-v2","large-v3"], label="Whisper Model Size", value="small")
+    ],
     outputs="text",
-    title="WhisperX Audio Transcription"
+    title="WhisperX Audio Transcription and Translation"
 )
 
 iface.launch()
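
Note that Gradio passes the input component values to the callback positionally, so the order of `inputs=[...]` must match the parameter order `(audio_file, translate, model_size)`. For reference, a minimal sketch of exercising the same pipeline the callback wraps, assuming the Space's dependencies and an HF_TOKEN in .env are available locally and that `sample.wav` is a placeholder clip (not part of this commit):

# Hypothetical local smoke test of the new options.
from audio_processing import process_audio, print_results

language_segments, segments = process_audio("sample.wav", translate=True, model_size="small")
print_results(segments)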
audio_processing.py CHANGED
@@ -2,168 +2,139 @@ import whisperx
 import torch
 import numpy as np
 from scipy.signal import resample
-import numpy as np
-import whisperx
 from pyannote.audio import Pipeline
 import os
 from dotenv import load_dotenv
-
 load_dotenv()
-
+import logging
+import time
+from difflib import SequenceMatcher
 hf_token = os.getenv("HF_TOKEN")
-import whisperx
-import torch
-import numpy as np
-
-import whisperx
-import torch
-import numpy as np
 
-import whisperx
-import torch
-import numpy as np
 CHUNK_LENGTH=5
-
-# def process_audio(audio_file):
-#     device = "cuda" if torch.cuda.is_available() else "cpu"
-#     compute_type = "float32"
-#     audio = whisperx.load_audio(audio_file)
-#     model = whisperx.load_model("small", device, compute_type=compute_type)
-
-#     # Initial transcription
-#     result = model.transcribe(audio, batch_size=8)
-
-#     # Sliding window for language detection
-#     window_size = 5 # seconds
-#     step_size = 1 # seconds
-#     sample_rate = 16000
-
-#     language_probs = []
-#     audio_duration = len(audio) / sample_rate
-
-#     if audio_duration <= window_size:
-#         # If audio is shorter than or equal to window size, detect language for entire audio
-#         lang = model.detect_language(audio)
-#         language_probs.append((0, lang))
-#     else:
-#         for i in range(0, len(audio) - window_size * sample_rate + 1, step_size * sample_rate):
-#             window = audio[i:i + window_size * sample_rate]
-#             lang = model.detect_language(window)
-#             language_probs.append((i / sample_rate, lang))
-
-#     # Detect language changes
-#     language_segments = []
-#     current_lang = language_probs[0][1]
-#     start_time = 0
-#     for time, lang in language_probs[1:]:
-#         if lang != current_lang:
-#             language_segments.append({
-#                 "language": current_lang,
-#                 "start": start_time,
-#                 "end": time
-#             })
-#             current_lang = lang
-#             start_time = time
-
-#     # Add the last segment
-#     language_segments.append({
-#         "language": current_lang,
-#         "start": start_time,
-#         "end": audio_duration
-#     })
-
-#     # Re-transcribe each language segment
-#     final_segments = []
-#     for segment in language_segments:
-#         start_sample = int(segment["start"] * sample_rate)
-#         end_sample = int(segment["end"] * sample_rate)
-#         segment_audio = audio[start_sample:end_sample]
-
-#         segment_result = model.transcribe(segment_audio, language=segment["language"])
-
-#         for seg in segment_result["segments"]:
-#             seg["start"] += segment["start"]
-#             seg["end"] += segment["start"]
-#             seg["language"] = segment["language"]
-#             final_segments.append(seg)
-
-#     return language_segments, final_segments
-
+OVERLAP=2
 import whisperx
 import torch
 import numpy as np
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
-def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000): # 30 seconds at 16kHz
+def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000): # 2 seconds overlap
     chunks = []
-    for i in range(0, len(audio), chunk_size):
+    for i in range(0, len(audio), chunk_size - overlap):
         chunk = audio[i:i+chunk_size]
         if len(chunk) < chunk_size:
             chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
         chunks.append(chunk)
     return chunks
 
-def process_audio(audio_file):
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    compute_type = "float32"
-    audio = whisperx.load_audio(audio_file)
-    model = whisperx.load_model("small", device, compute_type=compute_type)
-
-    # Initialize speaker diarization pipeline
-    diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
-    diarization_pipeline = diarization_pipeline.to(torch.device(device))
+def process_audio(audio_file, translate=False, model_size="small"):
+    start_time = time.time()
+
+    try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        compute_type = "float32"
+        audio = whisperx.load_audio(audio_file)
+        model = whisperx.load_model(model_size, device, compute_type=compute_type)
 
-    # Perform diarization on the entire audio
-    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
+        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
+        diarization_pipeline = diarization_pipeline.to(torch.device(device))
 
+        diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
 
-    # Preprocess audio into consistent chunks
-    chunks = preprocess_audio(audio)
+        chunks = preprocess_audio(audio)
 
-    language_segments = []
-    final_segments = []
-
-    for i, chunk in enumerate(chunks):
-        # Detect language for this chunk
-        lang = model.detect_language(chunk)
-
-        # Transcribe this chunk
-        result = model.transcribe(chunk, language=lang)
+        language_segments = []
+        final_segments = []
 
-        chunk_start_time = i * 5 # Each chunk is 30 seconds
-
-        # Adjust timestamps and add language information
-        for segment in result["segments"]:
-            segment_start = chunk_start_time + segment["start"]
-            segment_end = chunk_start_time + segment["end"]
-            segment["start"] = segment_start
-            segment["end"] = segment_end
-            segment["language"] = lang
+        overlap_duration = 2 # 2 seconds overlap
+        for i, chunk in enumerate(chunks):
+            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
+            chunk_end_time = chunk_start_time + CHUNK_LENGTH
+            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
+            lang = model.detect_language(chunk)
+            result_transcribe = model.transcribe(chunk, language=lang)
+            if translate:
+                result_translate = model.transcribe(chunk, task="translate")
+            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
+            for j, t_seg in enumerate(result_transcribe["segments"]):
+                segment_start = chunk_start_time + t_seg["start"]
+                segment_end = chunk_start_time + t_seg["end"]
+                # Skip segments in the overlapping region of the previous chunk
+                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
+                    print(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
+                    continue
 
-            speakers = []
-            for turn, track, speaker in diarization_result.itertracks(yield_label=True):
-                if turn.start <= segment_end and turn.end >= segment_start:
-                    speakers.append(speaker)
-            if speakers:
-                segment["speaker"] = max(set(speakers), key=speakers.count)
-            else:
-                segment["speaker"] = "Unknown"
+                # Skip segments in the overlapping region of the next chunk
+                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
+                    print(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
+                    continue
+
+                speakers = []
+                for turn, track, speaker in diarization_result.itertracks(yield_label=True):
+                    if turn.start <= segment_end and turn.end >= segment_start:
+                        speakers.append(speaker)
+
+                segment = {
+                    "start": segment_start,
+                    "end": segment_end,
+                    "language": lang,
+                    "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
+                    "text": t_seg["text"],
+                }
+
+                if translate:
+                    segment["translated"] = result_translate["segments"][j]["text"]
+
+                final_segments.append(segment)
 
-            final_segments.append(segment)
-        # Add language segment
-        language_segments.append({
-            "language": lang,
-            "start": chunk_start_time,
-            "end": chunk_start_time + 5
-        })
+            language_segments.append({
+                "language": lang,
+                "start": chunk_start_time,
+                "end": chunk_start_time + CHUNK_LENGTH
+            })
+            chunk_end_time = time.time()
+            logger.info(f"Chunk {i+1} processed in {chunk_end_time - chunk_start_time:.2f} seconds")
 
-    return language_segments, final_segments
+        final_segments.sort(key=lambda x: x["start"])
+        merged_segments = merge_nearby_segments(final_segments)
 
-def print_results(language, language_probs, segments):
-    print(f"Detected Language: {language}")
-    print("Language Probabilities:")
-    for lang, prob in language_probs.items():
-        print(f" {lang}: {prob:.4f}")
-
-    print("\nTranscription:")
+        end_time = time.time()
+        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
+
+        return language_segments, merged_segments
+    except Exception as e:
+        logger.error(f"An error occurred during audio processing: {str(e)}")
+        raise
+
+def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
+    merged = []
+    for segment in segments:
+        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
+            merged.append(segment)
+        else:
+            # Find the overlap
+            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
+            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
+
+            if match.size / len(segment['text']) > similarity_threshold:
+                # Merge the segments
+                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
+                merged_translated = merged[-1]['translated'] + segment['translated'][match.b + match.size:]
+
+                merged[-1]['end'] = segment['end']
+                merged[-1]['text'] = merged_text
+                merged[-1]['translated'] = merged_translated
+            else:
+                # If no significant overlap, append as a new segment
+                merged.append(segment)
+    return merged
+
+def print_results(segments):
     for segment in segments:
-        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] Speaker {segment['speaker']}: {segment['text']}")
+        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
+        print(f"Original: {segment['text']}")
+        if 'translated' in segment:
+            print(f"Translated: {segment['translated']}")
+        print()
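
One caveat with the new merge step: `merge_nearby_segments` reads `segment['translated']` unconditionally, but that key is only set when `translate=True`, so merging overlapping segments can raise a `KeyError` when translation is disabled. A guarded variant is sketched below as an illustration, not as part of the commit (the helper name `merge_nearby_segments_safe` is hypothetical):

from difflib import SequenceMatcher

def merge_nearby_segments_safe(segments, time_threshold=0.5, similarity_threshold=0.7):
    # Same idea as merge_nearby_segments, but tolerates segments without a 'translated' key.
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
            continue
        prev = merged[-1]
        matcher = SequenceMatcher(None, prev['text'], segment['text'])
        match = matcher.find_longest_match(0, len(prev['text']), 0, len(segment['text']))
        if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
            # Extend the previous segment with the non-overlapping tail of this one.
            prev['end'] = segment['end']
            prev['text'] += segment['text'][match.b + match.size:]
            if 'translated' in prev and 'translated' in segment:
                prev['translated'] += segment['translated'][match.b + match.size:]
        else:
            merged.append(segment)
    return merged

The `if segment['text']` check also avoids a division by zero for empty segment texts; otherwise the merging logic is unchanged.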
requirements.txt CHANGED
@@ -12,4 +12,6 @@ torchaudio>=2
 faster-whisper==1.0.0
 setuptools>=65
 nltk
-python-dotenv
+python-dotenv
+difflib
+pydub
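
A small packaging note: `difflib` ships with the Python standard library, so it does not need to be installed from PyPI and `pip install difflib` will typically fail to resolve. If that turns out to be the case for this Space, the added lines could be trimmed to the genuinely external dependencies (a sketch, not part of the commit):

python-dotenv
pydub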