File size: 3,746 Bytes
3d57cbf
1760204
3d57cbf
7255a1d
d49accf
7255a1d
8842007
d6b9b98
 
 
 
84e2e9f
e66719a
 
 
 
 
6f460e4
faf2e04
3d57cbf
 
faf2e04
d6b9b98
 
 
 
 
 
 
 
 
 
faf2e04
d6b9b98
3d57cbf
 
 
0173625
3d57cbf
 
0173625
6f460e4
 
3d57cbf
80bfac8
3d57cbf
 
 
 
80bfac8
3d57cbf
 
 
 
 
 
 
 
80bfac8
 
3d57cbf
 
 
80bfac8
3d57cbf
 
 
 
 
 
 
 
 
 
574c2e1
 
 
 
 
 
 
 
 
badb078
 
 
 
f0c95ae
badb078
4a7e5ac
1e42336
623eb5e
468e93e
 
e66719a
 
faf2e04
 
468e93e
4a7e5ac
468e93e
5a99f5b
 
 
badb078
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import io
import base64
from gtts import gTTS
import streamlit as st
import speech_recognition as sr
from huggingface_hub import InferenceClient
from streamlit_mic_recorder import mic_recorder
import wave

# Instruction text that format_prompt() inserts ahead of the conversation.
pre_prompt_text = "eres una IA conductual, tus respuestas serán breves."
# Location where main() writes the recorded audio.
temp_audio_file_path = "./output.wav"

# Seed the per-session values once; entries that already exist are kept.
if "history" not in st.session_state:
    st.session_state["history"] = []
if "pre_prompt_sent" not in st.session_state:
    st.session_state["pre_prompt_sent"] = False

def recognize_speech(audio_bytes, show_messages=True):
    """Transcribe recorded audio to Spanish text with Google's recognizer.

    Args:
        audio_bytes: Contents of an audio file in a container that
            ``speech_recognition.AudioFile`` can decode (WAV, AIFF or FLAC).
        show_messages: When true, render the recognized text and a success
            notice on the Streamlit page.

    Returns:
        The recognized text, or ``""`` when nothing could be transcribed
        (a warning or error is shown on the page in that case).
    """
    if not audio_bytes:
        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
        return ""

    recognizer = sr.Recognizer()

    try:
        # recognize_google() requires an AudioData object, not a raw byte
        # stream, so the bytes are decoded through AudioFile/record first.
        with sr.AudioFile(io.BytesIO(audio_bytes)) as source:
            audio_data = recognizer.record(source)
        audio_text = recognizer.recognize_google(audio_data, language="es-ES")
    except (sr.UnknownValueError, ValueError):
        # UnknownValueError: the speech was unintelligible.
        # ValueError: AudioFile could not decode the supplied bytes.
        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
        return ""
    except sr.RequestError:
        st.error("Háblame para comenzar!")
        return ""

    if show_messages:
        st.subheader("Texto Reconocido:")
        st.write(audio_text)
        st.success("Reconocimiento de voz completado.")

    return audio_text

def format_prompt(message, history):
    """Build the ``[INST]``-tagged prompt string for the instruct model.

    Args:
        message: The new user message, placed last in the prompt.
        history: Iterable of ``(user_prompt, bot_response)`` pairs that are
            replayed before the new message.

    Returns:
        The prompt string. The pre-prompt is included only while
        ``st.session_state.pre_prompt_sent`` is false.
    """
    pieces = ["<s>"]

    if not st.session_state.pre_prompt_sent:
        pieces.append(f"[INST]{pre_prompt_text}[/INST]")

    pieces.extend(
        f"[INST] {past_user} [/INST] {past_bot}</s> "
        for past_user, past_bot in history
    )

    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)

def generate(audio_text, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate a model reply for the user's text and synthesize it as speech.

    Args:
        audio_text: The user's message (unformatted); it is passed through
            ``format_prompt`` together with ``history``.
        history: Iterable of ``(user_prompt, bot_response)`` pairs.
        temperature: Sampling temperature forwarded to the model.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold forwarded to the model.
        repetition_penalty: Repetition penalty forwarded to the model.

    Returns:
        A ``(response, audio_file)`` tuple: the generated text and the value
        returned by ``text_to_speech`` for it, or ``None`` as the audio when
        the generated text is blank.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(audio_text, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)

    # Streamed token texts already carry their own whitespace, so they are
    # concatenated directly; joining with " " would split words apart.
    response = "".join(chunk.token.text for chunk in stream).replace('</s>', '')

    # gTTS rejects empty text, so skip synthesis when there is nothing to say.
    audio_file = text_to_speech(response, speed=1.3) if response.strip() else None
    return response, audio_file

def text_to_speech(text, speed=1.3):
    """Synthesize Spanish speech for ``text`` with gTTS.

    Args:
        text: The text to speak.
        speed: Accepted for interface compatibility; the value is not used
            by this function.

    Returns:
        An ``io.BytesIO`` positioned at the start of the synthesized audio,
        or ``None`` when ``text`` is empty or whitespace only.
    """
    # gTTS refuses empty input; report "no audio" instead of raising.
    if not text or not text.strip():
        return None

    tts = gTTS(text=text, lang='es')
    audio_fp = io.BytesIO()
    tts.write_to_fp(audio_fp)
    audio_fp.seek(0)  # rewind so callers can read from the beginning
    return audio_fp

def display_recognition_result(audio_text, output, audio_file):
    """Record the exchange in the session history and play the reply audio.

    Args:
        audio_text: The recognized user text; the exchange is appended to
            ``st.session_state.history`` only when this is non-empty.
        output: The model's reply stored alongside ``audio_text``.
        audio_file: Readable binary stream with the reply audio, or ``None``
            to render no player.
    """
    if audio_text:
        st.session_state.history.append((audio_text, output))

    if audio_file is None:
        return

    # Embed the audio as a base64 data URI in an autoplaying <audio> element.
    encoded_audio = base64.b64encode(audio_file.read()).decode()
    player_html = f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{encoded_audio}" type="audio/mp3" id="audio_player"></audio>"""
    st.markdown(player_html, unsafe_allow_html=True)

def main():
    """Run one pass of the voice chat: record, transcribe, reply, play back.

    Renders the microphone recorder; when a recording is available it is
    played back, written to ``temp_audio_file_path``, transcribed, sent to
    the model, and the spoken reply is displayed.
    """
    audio_data = mic_recorder(start_prompt="▶️", stop_prompt="🛑", key='recorder')

    if not audio_data or 'bytes' not in audio_data:
        return

    audio_bytes = audio_data['bytes']
    st.audio(audio_bytes)

    # NOTE(review): the recorder bytes are written unchanged as frames with
    # fixed mono/16-bit/44.1 kHz parameters; confirm the recorder really
    # delivers raw PCM in that layout. Nothing in this file reads the result.
    with wave.open(temp_audio_file_path, 'w') as wave_file:
        wave_file.setnchannels(1)
        wave_file.setsampwidth(2)
        wave_file.setframerate(44100)
        wave_file.writeframes(audio_bytes)

    audio_text = recognize_speech(audio_bytes)
    if not audio_text:
        # Recognition failed and has already reported it; do not query the
        # model with an empty message.
        return

    # generate() formats the prompt itself, so it receives the raw text;
    # formatting here as well would nest the prompt and duplicate the history.
    response, audio_file = generate(audio_text, st.session_state.history)
    display_recognition_result(audio_text, response, audio_file)

    # Mark the pre-prompt as sent only after a prompt containing it has
    # actually been submitted; setting it earlier means it is never included.
    st.session_state.pre_prompt_sent = True

if __name__ == "__main__":
    main()