import hashlib
import os
import tempfile
from typing import List

import librosa
import nltk
import torch
from scipy.io.wavfile import write

# Sample rate the speech model is fed; audio at any other rate is resampled.
_TARGET_SAMPLE_RATE = 16000
# Maximum number of characters of transcript kept by ``transcript``.
_MAX_TRANSCRIPT_CHARS = 3800


def _pooled_embeddings(model, tokenizer, texts, max_length, device):
    """Tokenize *texts*, run *model* without gradients and return its pooled output.

    Shared by :func:`embed_questions` and :func:`embed_passages`, which differ
    only in the model/tokenizer pair they use and in how they wrap the result.
    The model must return an object exposing ``pooler_output``.
    """
    encoded = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pooled = model(
            encoded["input_ids"].to(device), encoded["attention_mask"].to(device)
        ).pooler_output
    return pooled.cpu().numpy()


def embed_questions(
    question_model, question_tokenizer, questions, max_length=128, device="cpu"
):
    """Encode *questions* into dense vectors with the question encoder.

    Args:
        question_model: Encoder whose output has a ``pooler_output`` tensor
            (presumably a Hugging Face DPR question encoder — assumption).
        question_tokenizer: Tokenizer matching ``question_model``.
        questions: A string or list of strings to encode.
        max_length: Every input is padded/truncated to this many tokens.
        device: Device the token tensors are moved to before the forward pass.

    Returns:
        A NumPy array with one pooled embedding per question.
    """
    return _pooled_embeddings(
        question_model, question_tokenizer, questions, max_length, device
    )


def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
    """Encode ``passages["text"]`` into dense vectors with the context encoder.

    Args:
        ctx_model: Encoder whose output has a ``pooler_output`` tensor.
        ctx_tokenizer: Tokenizer matching ``ctx_model``.
        passages: Mapping with a ``"text"`` entry holding the passage strings
            (presumably a batch from ``datasets.Dataset.map(batched=True)`` —
            assumption).
        max_length: Every passage is padded/truncated to this many tokens.
        device: Device the token tensors are moved to before the forward pass.

    Returns:
        ``{"embeddings": <NumPy array with one row per passage>}``.
    """
    return {
        "embeddings": _pooled_embeddings(
            ctx_model, ctx_tokenizer, passages["text"], max_length, device
        )
    }


class Document:
    """Lightweight document container handed to a Haystack-style ``embed_documents``."""

    def __init__(self, meta=None, content: str = "", id_: str = ""):
        # Arbitrary metadata. A fresh dict is created per instance; the former
        # ``meta={}`` default was a single dict shared by all instances.
        self.meta = {} if meta is None else meta
        # The passage text.
        self.content = content
        # Document identifier (stringified index, see _alter_docs_for_haystack).
        self.id = id_


def _alter_docs_for_haystack(passages):
    """Wrap each passage string in a :class:`Document` whose id is its index."""
    return [Document(content=passage, id_=str(i)) for i, passage in enumerate(passages)]


def embed_passages_haystack(
    dpr_model,
    passages,
):
    """Embed ``passages["text"]`` through ``dpr_model.embed_documents``.

    Returns:
        ``{"embeddings": <whatever dpr_model.embed_documents returns>}``.
    """
    documents = _alter_docs_for_haystack(passages["text"])
    return {"embeddings": dpr_model.embed_documents(documents)}


def correct_casing(input_sentence):
    """Capitalize the first character of every sentence in *input_sentence*.

    Used on the lower-cased transcription. Sentences are split with
    ``nltk.sent_tokenize`` and re-joined with single spaces.
    """
    sentences = nltk.sent_tokenize(input_sentence)
    # Slicing instead of ``s[0]`` keeps this safe for an empty sentence.
    return " ".join(s[:1].capitalize() + s[1:] for s in sentences)


def clean_transcript(text):
    """Remove every literal ``[PAD]`` token from decoded transcription text."""
    return text.replace("[PAD]", "")


def add_question_symbols(text):
    """Ensure *text* starts with "¿" and ends with "?".

    Empty input yields "¿?" instead of raising ``IndexError``.
    """
    if not text.startswith("¿"):
        text = "¿" + text
    if not text.endswith("?"):
        text += "?"
    return text


def remove_chars_to_tts(text):
    """Replace commas with spaces (presumably ahead of a TTS step, per the name)."""
    return text.replace(",", " ")


def _transcribe_blocks(input_file, processor, model):
    """Yield the cleaned transcription of each 20-second block of *input_file*."""
    sample_rate = librosa.get_samplerate(input_file)
    # With frame_length == hop_length == sample_rate each frame is one second
    # of audio, so block_length=20 streams the file in 20-second blocks rather
    # than loading it all into memory.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
    )
    for speech in stream:
        if len(speech.shape) > 1:
            # NOTE(review): librosa.stream defaults to mono=True, so this
            # branch looks unreachable; kept from the original. It sums two
            # columns rather than averaging — confirm before relying on it.
            speech = speech[:, 0] + speech[:, 1]
        if sample_rate != _TARGET_SAMPLE_RATE:
            speech = librosa.resample(
                speech, orig_sr=sample_rate, target_sr=_TARGET_SAMPLE_RATE
            )
        input_values = processor(speech, return_tensors="pt").input_values
        # Inference only: skip autograd bookkeeping (as the embedding helpers do).
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(
            predicted_ids[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        yield clean_transcript(transcription)


def transcript(input_file, audio_array, processor, model):
    """Transcribe audio into a sentence-cased question string.

    Args:
        input_file: Path to an audio file readable by ``librosa``. Ignored when
            ``audio_array`` is truthy.
        audio_array: Optional ``(rate, samples)`` pair. When truthy it is
            written to a temporary WAV file, which is transcribed instead of
            ``input_file`` and deleted afterwards.
        processor: Speech processor turning raw audio into ``input_values`` and
            decoding predicted ids (presumably a Wav2Vec2 processor).
        model: CTC speech model whose output exposes ``.logits``.

    Returns:
        The per-block transcriptions, each sentence-cased and followed by
        ". ", concatenated, truncated to 3800 characters and wrapped in "¿...?".
    """
    tmp_path = None
    try:
        if audio_array:
            rate, sample = audio_array
            # A unique temp file per call: a fixed "temp.wav" in the working
            # directory was shared by concurrent calls and never cleaned up.
            fd, tmp_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            write(tmp_path, rate, sample)
            input_file = tmp_path
        # The generator is fully consumed here, before the temp file is removed.
        pieces = [
            correct_casing(text.lower()) + ". "
            for text in _transcribe_blocks(input_file, processor, model)
        ]
    finally:
        if tmp_path is not None:
            os.remove(tmp_path)
    whole_text = "".join(pieces)[:_MAX_TRANSCRIPT_CHARS]
    return add_question_symbols(whole_text)


def parse_final_answer(answer_text: str, contexts: List):
    """Format the generated answer and its supporting contexts for display.

    Returns:
        An ``(answer, docs)`` pair: ``answer`` is *answer_text* surrounded by
        newlines; ``docs`` holds the first five contexts, each cut to 250
        characters and suffixed with "[...]" plus a newline, joined by newlines.
    """
    # NOTE(review): in the flattened source these literals contained raw line
    # breaks (invalid Python), apparently where markup such as HTML tags was
    # stripped. They are reproduced as "\n" escapes; confirm the intended markup.
    answer = f"\n{answer_text}\n\n\n\n"
    docs = "\n".join(("" + context)[:250] + "[...]\n" for context in contexts[:5])
    return answer, docs