# xaman2 / app.py
# salomonsky's picture
# Update app.py
# 7255a1d verified
# raw
# history blame
# 4.9 kB
import io
import base64
import webrtcvad
import threading
import numpy as np
from gtts import gTTS
import streamlit as st
import sounddevice as sd
import speech_recognition as sr
from huggingface_hub import InferenceClient
# Print every audio device known to sounddevice so a valid input device
# index can be picked from the console output.
devices = sd.query_devices()
print(devices)

# Initialise the per-session values only once; they are read later by the
# prompt-building and result-display helpers.
if "history" not in st.session_state:
    st.session_state.history = []
if "pre_prompt_sent" not in st.session_state:
    st.session_state.pre_prompt_sent = False

# NOTE(review): this variable is not read anywhere in this file; presumably
# it was meant to be the Streamlit `gatherUsageStats` config option — confirm.
gatherUsageStats = "false"

# Instruction that format_prompt places at the start of the model prompt.
# The accented character was mis-encoded ("ser谩n"); restored to valid Spanish.
pre_prompt_text = "eres una IA conductual, tus respuestas serán breves."
def recognize_speech(audio_data, show_messages=True):
    """Transcribe a recorded audio source to Spanish text.

    Args:
        audio_data: Audio source handed to ``sr.AudioFile``.
        show_messages: When true, show the transcript and a success notice
            in the Streamlit page after a successful recognition.

    Returns:
        The recognized text, or ``""`` when the recognizer raises
        ``sr.UnknownValueError`` or ``sr.RequestError``.
    """
    recognizer = sr.Recognizer()
    with sr.AudioFile(audio_data) as source:
        audio = recognizer.record(source)

    # Only the recognizer call can raise the handled errors, so keep the
    # try body down to that single line.
    try:
        audio_text = recognizer.recognize_google(audio, language="es-ES")
    except sr.UnknownValueError:
        # The leading "¿" was mis-encoded in the original literal.
        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
        return ""
    except sr.RequestError:
        st.error("Hablame para comenzar!")
        return ""

    if show_messages:
        st.subheader("Texto Reconocido:")
        st.write(audio_text)
        st.success("Reconocimiento de voz completado.")
    return audio_text
def format_prompt(message, history):
    """Build the ``[INST]``-tagged prompt string for the model.

    Args:
        message: The new user message, placed last in the prompt.
        history: Iterable of ``(user_prompt, bot_response)`` pairs that are
            replayed before the new message.

    Returns:
        The prompt: ``<s>``, the pre-prompt (while the session's
        ``pre_prompt_sent`` flag is false), every past exchange, then
        *message*.
    """
    pieces = ["<s>"]
    if not st.session_state.pre_prompt_sent:
        pieces.append(f"[INST]{pre_prompt_text}[/INST]")
    for user_turn, bot_turn in history:
        pieces.extend((
            f"[INST] {user_turn} [/INST]",
            f" {bot_turn}</s> ",
        ))
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate a reply to *audio_text* and synthesize it as speech.

    Args:
        audio_text: The user's message.
        history: Past ``(user_prompt, bot_response)`` pairs passed on to
            ``format_prompt``.
        temperature: Sampling temperature; ``None`` means 0.9. Values below
            0.01 are raised to 0.01.
        max_new_tokens: Forwarded to the text-generation call.
        top_p: Forwarded to the text-generation call (as a float).
        repetition_penalty: Forwarded to the text-generation call.

    Returns:
        ``(response, audio_file)``: the cleaned reply text and the value
        returned by ``text_to_speech`` for it, or ``None`` as the audio when
        the reply is empty.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    temperature = 0.9 if temperature is None else float(temperature)
    temperature = max(temperature, 1e-2)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(audio_text, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=True,
    )

    raw_text = "".join(chunk.token.text for chunk in stream)
    # Drop the end-of-sequence marker *before* collapsing whitespace so no
    # stray or doubled spaces are left where the marker used to be.
    response = " ".join(raw_text.replace("</s>", "").split())

    # gTTS refuses empty text, so skip synthesis when there is nothing to say.
    audio_file = text_to_speech(response, speed=1.3) if response else None
    return response, audio_file
def text_to_speech(text, speed=1.3):
    """Synthesize *text* as Spanish speech into an in-memory buffer.

    Args:
        text: The text to speak.
        speed: Accepted for the callers' sake but not used by this function.

    Returns:
        An ``io.BytesIO`` holding the audio written by gTTS, rewound to the
        start so it can be read immediately.
    """
    speech = gTTS(text=text, lang='es')
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)
    return buffer
def audio_play(audio_fp):
    """Render an audio player in the Streamlit page for *audio_fp*.

    Args:
        audio_fp: Readable binary stream; it is read from its current
            position to the end.
    """
    audio_bytes = audio_fp.read()
    st.audio(audio_bytes, format="audio/mp3", start_time=0)
def display_recognition_result(audio_text, output, audio_file):
    """Record an exchange in the session history and auto-play its audio.

    Args:
        audio_text: The recognized user text; the exchange is stored only
            when this is non-empty.
        output: The generated reply stored alongside *audio_text*.
        audio_file: Readable binary stream with the reply audio, or ``None``
            to skip the player.
    """
    if audio_text:
        st.session_state.history.append((audio_text, output))

    if audio_file is None:
        return

    # Embed the audio as a base64 data URI so the browser can autoplay it.
    encoded_audio = base64.b64encode(audio_file.read()).decode()
    player_html = f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{encoded_audio}" type="audio/mp3" id="audio_player"></audio>"""
    st.markdown(player_html, unsafe_allow_html=True)
def voice_activity_detection(audio_data):
    """Ask the module-level ``vad`` whether *audio_data* contains speech.

    Args:
        audio_data: Audio frame passed to ``vad.is_speech`` together with
            the module-level ``sample_rate``.

    Returns:
        Whatever ``vad.is_speech`` returns for the frame.
    """
    speech_detected = vad.is_speech(audio_data, sample_rate)
    return speech_detected
def audio_callback(indata, frames, time, status):
    """Run voice-activity detection on one captured audio block.

    Takes the four arguments a sounddevice input-stream callback receives.

    Args:
        indata: Captured samples, indexed as ``[frame, channel]``.
        frames: Number of frames in *indata*.
        time: Timing information from the stream (unused).
        status: Stream status flags; printed when set.
    """
    if status:
        print(status)
    if frames != block_size:
        # An explicit check instead of ``assert``, which is stripped under
        # ``python -O``; a block of the wrong length is skipped.
        print(f"Skipping block of {frames} frames (expected {block_size})")
        return

    samples = indata[::downsample, mapping]
    if np.issubdtype(samples.dtype, np.floating):
        # Floating-point samples are expected in [-1.0, 1.0]; rescale them
        # to the signed 16-bit range.
        samples = np.clip(samples, -1.0, 1.0) * 32767
    # webrtcvad works on 16-bit PCM bytes. The previous float16 encoding had
    # the right byte count but not the right sample values.
    pcm_bytes = samples.astype(np.int16).tobytes()

    detection = voice_activity_detection(pcm_bytes)
    print(detection)
def start_stream():
    """Start the module-level ``stream``."""
    active_stream = stream
    active_stream.start()
class Threader(threading.Thread):
    """Thread that starts itself as soon as it is constructed.

    When the thread is named ``'mythread'`` its body announces itself on
    stdout and calls ``start_stream``; any other name does nothing.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``threading.Thread`` and start running."""
        super().__init__(*args, **kwargs)
        self.start()

    def run(self):
        """Start the audio stream when this is the ``'mythread'`` thread."""
        if self.name != 'mythread':
            return
        print("Started mythread")
        start_stream()
if __name__ == "__main__":
    # Module-level settings read by audio_callback / voice_activity_detection.
    vad = webrtcvad.Vad(1)
    channels = [1]
    mapping = [c - 1 for c in channels]  # 1-based channel numbers -> column indices
    input_device = 16
    device_info = sd.query_devices(input_device, 'input')
    # NOTE(review): webrtcvad only accepts 8000/16000/32000/48000 Hz; confirm
    # the device's default rate is one of those.
    sample_rate = int(device_info['default_samplerate'])
    interval_size = 10  # block length in milliseconds
    downsample = 1
    block_size = int(sample_rate * interval_size / 1000)

    # start_stream() reads the module-level name `stream`, which was never
    # created before. Keep the stream in session_state so a Streamlit rerun
    # reuses it instead of opening another one.
    if "audio_stream" not in st.session_state:
        st.session_state.audio_stream = sd.InputStream(
            device=input_device,
            channels=max(channels),
            samplerate=sample_rate,
            blocksize=block_size,
            callback=audio_callback,
        )
        stream = st.session_state.audio_stream
        Threader(name='mythread')
    stream = st.session_state.audio_stream

    # The button's result used to be discarded, so it could not stop anything.
    if st.button("Detener Stream"):
        stream.stop()

    # Accented characters in the texts below were mis-encoded; restored.
    st.text("Esperando entrada de voz...")
    st.text("Puedes detener el stream manualmente usando el botón 'Detener Stream'.")
    st.text("Nota: El código actual imprime los resultados de VAD en la consola.")
    st.text("Puedes personalizar la lógica de VAD según tus necesidades.")
    st.text("La transcripción de voz y la generación de texto se manejarán una vez que se detecte actividad de voz.")
    st.text("Inicia la grabación y espera a que aparezcan los resultados.")