import time

import gradio as gr
import torch
from faster_whisper import WhisperModel
from sentence_transformers import SentenceTransformer, util

from utils import ffmpeg_read, stt, greeting_list

whisper_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2"]
audio_model = WhisperModel("base", compute_type="int8", device="cpu")
text_model = SentenceTransformer('all-MiniLM-L6-v2')
corpus_embeddings = torch.load('corpus_embeddings.pt')
model_type = "whisper"
title = "Greeting detection demo app"

# Minimum semantic-search score for a chunk to count as a greeting.
SIMILARITY_THRESHOLD = 0.8


def speech_to_text(upload_audio):
    """Transcribe an audio input to text.

    When ``model_type`` is ``"whisper"`` the module-level faster-whisper
    model is used with Japanese as the language and VAD filtering enabled;
    the text of all returned segments is joined with single spaces.
    Otherwise the ``stt`` helper imported from ``utils`` is called.

    Args:
        upload_audio: Audio input forwarded unchanged to the transcriber
            (the UI passes a file path).

    Returns:
        The transcribed text.
    """
    if model_type != "whisper":
        return stt(upload_audio)

    transcribe_options = dict(
        task="transcribe",
        language="ja",
        beam_size=5,
        best_of=5,
        vad_filter=True,
    )
    segments_raw, _ = audio_model.transcribe(upload_audio, **transcribe_options)
    return ' '.join(segment.text for segment in segments_raw)


def _split_state(state):
    """Split a serialized state string into ``(count, transcript)``.

    The state is the decimal greeting count immediately followed by the
    accumulated transcript. All leading ASCII digits are consumed, so counts
    of 10 or more round-trip correctly; a state with no leading digits is
    treated as count 0.
    """
    if not state:
        return 0, ""
    transcript = state.lstrip("0123456789")
    digits = state[:len(state) - len(transcript)]
    return (int(digits) if digits else 0), transcript


def voice_detect(audio, recongnize_text=""):
    """Transcribe one audio chunk and update the running greeting count.

    A chunk counts as a greeting when its text contains any entry of
    ``greeting_list`` or, failing that, when its best semantic-search hit
    against ``corpus_embeddings`` scores above ``SIMILARITY_THRESHOLD``.
    Each chunk adds at most one to the count.

    Args:
        audio: Audio chunk passed to ``speech_to_text``. ``None`` leaves the
            state unchanged.
        recongnize_text: Serialized state from the previous call, formatted
            as ``"<count><transcript>"``. Empty on the first call.

    Returns:
        A tuple ``(transcript, state, count)``: the accumulated transcript,
        the new serialized state to feed into the next call, and the
        greeting count including this chunk.
    """
    count_state, transcript = _split_state(recongnize_text)

    if audio is None:
        return transcript, str(count_state) + transcript, count_state

    text = speech_to_text(audio)
    # NOTE(review): presumably a phrase the model hallucinates on silence;
    # chunks containing it are discarded entirely.
    if "ご視聴ありがとうございました" in text:
        text = ""

    detect_greeting = 0
    # Blank chunks carry nothing to display or classify, so skip them rather
    # than growing the transcript with spaces and embedding an empty string.
    if text.strip():
        # The separating space also keeps transcript digits from ever being
        # adjacent to the count prefix in the serialized state.
        transcript = transcript + " " + text
        if any(greeting in text for greeting in greeting_list):
            detect_greeting = 1
        else:
            # Only pay for the embedding when the literal match failed.
            query_embedding = text_model.encode(text, convert_to_tensor=True)
            hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)[0]
            if hits and hits[0]['score'] > SIMILARITY_THRESHOLD:
                detect_greeting = 1

    count_state += detect_greeting
    recongnize_state = str(count_state) + transcript
    return transcript, recongnize_state, count_state


def clear():
    """Return three ``None`` values."""
    return None, None, None


demo = gr.Blocks(title=title)

with demo:
    gr.Markdown('''

挨拶カウンター

''')
    with gr.Row():
        with gr.Column():
            audio_source = gr.Audio(source="microphone", type="filepath", streaming=True)
            # Carries "<count><transcript>" between streaming callbacks.
            state = gr.State(value="")
        with gr.Column():
            greeting_count = gr.Number(label="挨拶回数")
    with gr.Row():
        text_output = gr.Textbox(label="認識されたテキスト")
    audio_source.stream(
        voice_detect,
        inputs=[audio_source, state],
        outputs=[text_output, state, greeting_count],
    )

demo.launch(debug=True)