from sys import platform

import logging

import pandas as pd
import torch
from pyannote.audio import Pipeline
from pyannote.core import Segment
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available

languages = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
"Hebrew": "iw", |
|
"Ukrainian": "uk", |
|
"Greek": "el", |
|
"Malay": "ms", |
|
"Czech": "cs", |
|
"Romanian": "ro", |
|
"Danish": "da", |
|
"Hungarian": "hu", |
|
"Tamil": "ta", |
|
"Norwegian": "no", |
|
"Thai": "th", |
|
"Urdu": "ur", |
|
"Croatian": "hr", |
|
"Bulgarian": "bg", |
|
"Lithuanian": "lt", |
|
"Latin": "la", |
|
"Maori": "mi", |
|
"Malayalam": "ml", |
|
"Welsh": "cy", |
|
"Slovak": "sk", |
|
"Telugu": "te", |
|
"Persian": "fa", |
|
"Latvian": "lv", |
|
"Bengali": "bn", |
|
"Serbian": "sr", |
|
"Azerbaijani": "az", |
|
"Slovenian": "sl", |
|
"Kannada": "kn", |
|
"Estonian": "et", |
|
"Macedonian": "mk", |
|
"Breton": "br", |
|
"Basque": "eu", |
|
"Icelandic": "is", |
|
"Armenian": "hy", |
|
"Nepali": "ne", |
|
"Mongolian": "mn", |
|
"Bosnian": "bs", |
|
"Kazakh": "kk", |
|
"Albanian": "sq", |
|
"Swahili": "sw", |
|
"Galician": "gl", |
|
"Marathi": "mr", |
|
"Punjabi": "pa", |
|
"Sinhala": "si", |
|
"Khmer": "km", |
|
"Shona": "sn", |
|
"Yoruba": "yo", |
|
"Somali": "so", |
|
"Afrikaans": "af", |
|
"Occitan": "oc", |
|
"Georgian": "ka", |
|
"Belarusian": "be", |
|
"Tajik": "tg", |
|
"Sindhi": "sd", |
|
"Gujarati": "gu", |
|
"Amharic": "am", |
|
"Yiddish": "yi", |
|
"Lao": "lo", |
|
"Uzbek": "uz", |
|
"Faroese": "fo", |
|
"Haitian creole": "ht", |
|
"Pashto": "ps", |
|
"Turkmen": "tk", |
|
"Nynorsk": "nn", |
|
"Maltese": "mt", |
|
"Sanskrit": "sa", |
|
"Luxembourgish": "lb", |
|
"Myanmar": "my", |
|
"Tibetan": "bo", |
|
"Tagalog": "tl", |
|
"Malagasy": "mg", |
|
"Assamese": "as", |
|
"Tatar": "tt", |
|
"Hawaiian": "haw", |
|
"Lingala": "ln", |
|
"Hausa": "ha", |
|
"Bashkir": "ba", |
|
"Javanese": "jw", |
|
"Sundanese": "su", |
|
} |
|
|
|
# Pick the best available device: CUDA, then Apple Silicon (MPS), then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif platform == "darwin" and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Half precision is only reliable on CUDA; keep float32 on MPS and CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

def get_text_with_timestamp(transcribe_res):
    """Convert pipeline chunks into (Segment, text) pairs."""
    timestamp_texts = []
    for item in transcribe_res["chunks"]:
        start, end = item["timestamp"]
        if end is None:
            # The pipeline may leave the final chunk open-ended.
            end = start
        timestamp_texts.append((Segment(start, end), item["text"]))
    return timestamp_texts
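
# For reference, the HF ASR pipeline (with return_timestamps=True) returns a
# dict shaped like:
#     {"text": " Hello there. How are you?",
#      "chunks": [{"timestamp": (0.0, 1.5), "text": " Hello there."},
#                 {"timestamp": (1.5, 3.0), "text": " How are you?"}]}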

def add_speaker_info_to_text(timestamp_texts, ann):
    """Label each transcript segment with its dominant speaker."""
    spk_text = []
    for seg, text in timestamp_texts:
        # argmax() picks the speaker with the most speech inside this segment.
        spk = ann.crop(seg).argmax()
        spk_text.append((seg, spk, text))
    return spk_text

def merge_cache(text_cache):
    """Collapse a run of cached chunks into one (Segment, speaker, sentence)."""
    sentence = "".join(item[-1] for item in text_cache)
    spk = text_cache[0][1]
    start = text_cache[0][0].start
    end = text_cache[-1][0].end
    return Segment(start, end), spk, sentence

# Punctuation that ends a sentence and closes the current merge window.
PUNC_SENT_END = [".", "?", "!"]

def merge_sentence(spk_text):
    """Merge consecutive chunks into sentences, splitting on speaker changes."""
    merged_spk_text = []
    pre_spk = None
    text_cache = []
    for seg, spk, text in spk_text:
        if spk != pre_spk and pre_spk is not None and len(text_cache) > 0:
            # Speaker changed: flush the cache and start a new one.
            merged_spk_text.append(merge_cache(text_cache))
            text_cache = [(seg, spk, text)]
        elif text and text[-1] in PUNC_SENT_END:
            # Sentence-final punctuation: flush, including this chunk.
            text_cache.append((seg, spk, text))
            merged_spk_text.append(merge_cache(text_cache))
            text_cache = []
        else:
            text_cache.append((seg, spk, text))
        pre_spk = spk
    if len(text_cache) > 0:
        merged_spk_text.append(merge_cache(text_cache))
    return merged_spk_text
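
# Worked example for merge_sentence (hypothetical chunks): two chunks from one
# speaker, the second ending in ".", then a speaker change:
#
#     [(Segment(0, 2), "SPEAKER_00", " Hi"),
#      (Segment(2, 3), "SPEAKER_00", " there."),
#      (Segment(3, 5), "SPEAKER_01", " Hello.")]
#
# merges into:
#
#     [(Segment(0, 3), "SPEAKER_00", " Hi there."),
#      (Segment(3, 5), "SPEAKER_01", " Hello.")]
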
def diarize_text(transcribe_res, diarization_result):
    """Align a Whisper transcript with a pyannote diarization result."""
    timestamp_texts = get_text_with_timestamp(transcribe_res)
    spk_text = add_speaker_info_to_text(timestamp_texts, diarization_result)
    return merge_sentence(spk_text)

def make_conversation(transcribe_result, diarization_result):
    """Group the diarized transcript into one entry per speaker turn."""
    processed = diarize_text(transcribe_result, diarization_result)
    df = pd.DataFrame(processed, columns=["segment", "speaker", "text"])[
        ["speaker", "text"]
    ]
    # A new key starts whenever the speaker changes between consecutive rows.
    df["key"] = (df["speaker"] != df["speaker"].shift(1)).astype(int).cumsum()
    conversation = df.groupby(["key", "speaker"])["text"].apply(" ".join).reset_index()
    return list(zip(conversation.text, conversation.speaker))
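
# Example output of make_conversation (hypothetical values): consecutive turns
# by the same speaker are collapsed into a single entry:
#
#     [(" Hi there.", "SPEAKER_00"), (" Hello.", "SPEAKER_01")]
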

def transcriber(
    input: str,
    model: str,
    language: str,
    translate: bool,
    diarize: bool,
    input_diarization_token: str,
) -> str:
    """Transcribe an audio file with an OpenAI Whisper model.

    Args:
        input: path to the audio file (any format ffmpeg can decode)
        model: Hugging Face model id of the Whisper checkpoint to use
        language: name of the language the audio is recorded in
        translate: if True, translate the transcript to English
        diarize: if True, label speaker turns with pyannote
        input_diarization_token: Hugging Face token for the pyannote pipeline

    Returns:
        The transcript, one chunk per line, or one speaker turn per line
        when diarization is enabled.
    """
    model_id = model

    if diarize:
        # Speaker diarization via pyannote; requires a Hugging Face token
        # with access to pyannote/speaker-diarization-3.1.
        pipeline_diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=input_diarization_token,
        )
        pipeline_diarization.to(device)
        diarization = pipeline_diarization(input)
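
        # For orientation (not executed here): `diarization` is a
        # pyannote.core.Annotation whose speaker turns can be iterated as
        #
        #     for turn, _, speaker in diarization.itertracks(yield_label=True):
        #         print(f"{turn.start:.1f}s-{turn.end:.1f}s: {speaker}")
        #
        # add_speaker_info_to_text() above crops this annotation per chunk.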

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        # Prefer FlashAttention 2 when installed; otherwise fall back to
        # PyTorch SDPA (the deprecated use_flash_attention_2 flag is avoided).
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
    )

    logging.info("Using device: %s", device)
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    # Map the human-readable language name to Whisper's code; None lets
    # Whisper auto-detect the language.
    language = languages.get(language, None)
    task = "translate" if translate else None

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=15,
        batch_size=16,
        return_timestamps=True,
        torch_dtype=torch_dtype,
        device=device,
        # Pass both the task and the resolved language through to generate().
        generate_kwargs={"task": task, "language": language},
    )

    results = pipe(input)
    results["text"] = results["text"].strip()

    if diarize:
        # Merge the transcript with the speaker segments and return one
        # line per speaker turn.
        conversation = make_conversation(results, diarization)
        return "\n".join(
            f"{speaker}: {text.strip()}" for text, speaker in conversation
        )

    text = ""
    for chunk in results.get("chunks", []):
        text += chunk["text"] + "\n"

    return text
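
# A minimal usage sketch. The audio path and checkpoint below are placeholders
# (assumptions for illustration), not values shipped with this module.
if __name__ == "__main__":
    transcript = transcriber(
        input="example.wav",              # hypothetical audio file
        model="openai/whisper-large-v3",  # any Whisper checkpoint id works
        language="English",
        translate=False,
        diarize=False,                    # set True and pass an HF token to diarize
        input_diarization_token=None,
    )
    print(transcript)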