import gradio as gr
import whisper
import datetime
import torch
import subprocess
import os
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
import wave
import contextlib
from sklearn.cluster import AgglomerativeClustering
import numpy as np
# Load the Whisper speech-to-text model
model_size = "medium.en"
model = whisper.load_model(model_size)

# pyannote audio helper (used to crop segments) and the speaker-embedding model
audio = Audio()
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
def transcribe_and_diarize(audio_file, num_speakers=2):
    try:
        # gr.Audio(type="filepath") passes the uploaded file's path as a string
        path = audio_file

        # Convert to WAV if necessary
        if not path.endswith('.wav'):
            subprocess.call(['ffmpeg', '-i', path, 'audio.wav', '-y'])
            path = 'audio.wav'

        # Transcribe audio
        result = model.transcribe(path)
        segments = result["segments"]

        # Get audio duration
        with contextlib.closing(wave.open(path, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)

        # Extract a speaker embedding for a single transcript segment
        def segment_embedding(segment):
            start = segment["start"]
            # Whisper can overshoot the end of the audio on the last segment
            end = min(duration, segment["end"])
            clip = Segment(start, end)
            waveform, sample_rate = audio.crop(path, clip)
            return embedding_model(waveform[None])

        # Extract embeddings for each segment (ECAPA-TDNN embeddings are 192-dimensional)
        embeddings = np.zeros(shape=(len(segments), 192))
        for i, segment in enumerate(segments):
            embeddings[i] = segment_embedding(segment)
        embeddings = np.nan_to_num(embeddings)

        # Cluster the segment embeddings into the requested number of speakers
        # (gr.Number delivers a float, so cast before passing it to scikit-learn)
        clustering = AgglomerativeClustering(int(num_speakers)).fit(embeddings)
        labels = clustering.labels_
        for i in range(len(segments)):
            segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)

        # Generate the transcript, adding a speaker header whenever the speaker changes
        transcript = ""
        for i, segment in enumerate(segments):
            if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
                transcript += "\n" + segment["speaker"] + ' ' + str(datetime.timedelta(seconds=round(segment["start"]))) + '\n'
            transcript += segment["text"][1:] + ' '
        transcript += "\n\n"
        return transcript
    except Exception as e:
        return f"An error occurred: {str(e)}"
iface = gr.Interface(
    fn=transcribe_and_diarize,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Number(value=2, label="Number of Speakers")
    ],
    outputs="text",
    title="Audio Transcription and Speaker Diarization",
    description="Upload an audio file to get a transcription with speaker diarization."
)
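
# Optional local smoke test, added here as a sketch (not part of the original Space): it
# bypasses the Gradio UI and calls transcribe_and_diarize() directly so the pipeline can be
# checked from the command line. "sample.wav" and the LOCAL_TEST environment variable are
# hypothetical names; point the path at a real audio file on disk before enabling it.
if os.getenv("LOCAL_TEST"):
    print(transcribe_and_diarize("sample.wav", num_speakers=2))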

iface.launch()