# Spaces:
# Sleeping
# Sleeping
import streamlit as st # type: ignore | |
import os | |
from datetime import datetime | |
from extra_streamlit_components import tab_bar, TabBarItemData | |
import io | |
from gtts import gTTS | |
import soundfile as sf | |
import wavio | |
from audio_recorder_streamlit import audio_recorder | |
import speech_recognition as sr | |
import whisper | |
import numpy as np | |
from translate_app import tr | |
import getpass | |
from langchain_mistralai import ChatMistralAI | |
from langchain_openai import ChatOpenAI | |
from langgraph.checkpoint.memory import MemorySaver | |
from langgraph.graph import START, END, MessagesState, StateGraph | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from typing import Sequence | |
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, trim_messages | |
from langgraph.graph.message import add_messages | |
from typing_extensions import Annotated, TypedDict | |
from dotenv import load_dotenv | |
import time | |
from tabs.google_drive_read_preprompt import read_param, format_param | |
import warnings | |
# Silence all Python warnings for this process.
warnings.filterwarnings('ignore')

title = "Sales coaching"
sidebar_name = "Sales coaching"
# Assumes the host app stored ``DataPath`` in the session state before this
# module is imported (otherwise this raises at import time).
dataPath = st.session_state.DataPath

# LangSmith tracing configuration for the LangChain / LangGraph calls of this page.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_HUB_API_URL"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "Sales Coaching Chatbot"

# When the host app sets a non-zero ``Cloud`` flag, load the API keys
# (presumably LANGCHAIN_API_KEY, MISTRAL_API_KEY, OPENAI_API_KEY) from a local
# .env file into the environment.
# NOTE(review): the former bare ``os.getenv(...)`` calls discarded their
# results and had no effect; the LangChain clients are presumed to read these
# keys from the environment themselves.
if st.session_state.Cloud != 0:
    load_dotenv()
# System instruction: answer every question as well as possible in the
# ``{language}`` language, even if the question is asked in another language.
_SYSTEM_INSTRUCTION = (
    "Répond à toutes les questions du mieux possible dans la langue {language}, "
    "même si la question est posée dans une autre langue"
)

# Chat prompt = system instruction followed by the conversation messages
# carried in the graph state.
prompt = ChatPromptTemplate.from_messages([
    ("system", _SYSTEM_INSTRUCTION),
    MessagesPlaceholder(variable_name="messages"),
])
class State(TypedDict):
    """State flowing through the LangGraph workflow.

    ``messages`` uses the ``add_messages`` reducer, so the messages a node
    returns are merged into the stored history instead of replacing it.
    ``language`` fills the ``{language}`` slot of ``prompt``.
    """

    # Conversation history, merged by the add_messages reducer.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Answer language injected into the system prompt.
    language: str
def call_model(state: State):
    """Graph node: render ``prompt`` from ``state`` and ask the chat model.

    The module-level ``prompt`` and ``model`` are looked up at call time, so
    whichever chat model is currently bound to ``model`` produces the reply.

    Returns:
        A partial state update ``{"messages": [reply]}``; the ``add_messages``
        reducer of ``State`` appends it to the history.
    """
    reply = (prompt | model).invoke(state)
    return {"messages": [reply]}
# Single-node graph: START -> "model" -> END.
workflow = StateGraph(state_schema=State)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")
workflow.add_edge("model", END)

# In-memory checkpointer: history is stored per ``thread_id`` (passed via the
# ``config`` argument of ``app.invoke``) and is lost when the process restarts.
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)
# Module-level UI / conversation state.
# NOTE(review): these live on the module object, so they presumably persist
# across Streamlit reruns for as long as the module stays imported (and are
# then shared by every session of the process) -- confirm before multi-user use.

# Remembered widget selections of the initialisation tab (indices into the
# option lists returned by format_param(), or the selected values).
selected_index1 = 0
selected_index2 = 0
selected_index3 = 0
selected_indices4 = []
selected_indices5 = []
selected_indices6 = []
selected_indices7 = []
selected_options4 = []
selected_options5 = []
selected_options6 = []
selected_options7 = []
selected_index8 = 0

context = ""         # System-message text, built by init()
human_message1 = ""  # Role-play instruction, built by init()
virulence = 1        # Prospect aggressiveness level (1-5 slider in init())
question = []        # Question texts, loaded by init() via format_param()

# Default conversation thread keyed on the import timestamp; init_run()
# replaces it with a fresh one when a conversation is started.
thread_id = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
config = {"configurable": {"thread_id": thread_id}}

# to_init: the current settings still have to be pushed to the graph.
# initialized: init_run() has been executed for the current settings.
to_init = True
initialized = False

# Placeholder seed messages; init() replaces them with the real context and
# instruction, and init_run() sends whatever is current to the graph.
messages = [
    SystemMessage(content=""),
    HumanMessage(content=""),
    AIMessage(content=""),
    HumanMessage(content=""),
]

# NOTE(review): ``model`` holds only the model *name* here; init_run() and
# run() replace it with a chat-model client before it is used for a call.
if 'model' in st.session_state:
    model = st.session_state.model
    used_model = st.session_state.model
def init_run():
    """Start a new conversation thread and seed it with the current messages.

    Builds the chat model named by ``st.session_state.model``. It is OpenAI
    when the name starts with ``"gpt"`` and ``OPENAI_API_KEY`` is in the
    session state, Mistral otherwise. Then it opens a new timestamp-based
    ``thread_id`` and sends the module-level ``messages`` (context +
    instruction) to the graph, so the checkpointer stores them as the start
    of the thread.

    The settings used are mirrored into ``st.session_state`` so ``init()``
    can detect later changes, and the displayed chat history is cleared.
    ``language`` is the module global set by ``run()``.

    Also used as the ``on_click`` callback of the "Validez" button.
    """
    global initialized, to_init, thread_id, config, app, context, human_message1, model, used_model, messages
    initialized = True
    to_init = False

    # Build the chat model BEFORE invoking the graph: call_model() reads the
    # global ``model`` at call time. Building it afterwards (as before) made
    # the seeding call run on the previous model, or on the bare model-name
    # string set at import time, which cannot be piped into the prompt.
    if 'model' in st.session_state and (st.session_state.model[:3] == "gpt") and ("OPENAI_API_KEY" in st.session_state):
        model = ChatOpenAI(model=st.session_state.model,
                           temperature=0.8,  # Adjust creativity level
                           max_tokens=150    # Define max output token limit
                           )
    else:
        model = ChatMistralAI(model=st.session_state.model)

    thread_id = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    config = {"configurable": {"thread_id": thread_id}}
    app.invoke(
        {"messages": messages, "language": language},
        config,
    )

    st.session_state.thread_id = thread_id
    st.session_state.config = config
    st.session_state.messages_init = messages
    st.session_state.context = context
    st.session_state.human_message1 = human_message1
    st.session_state.messages = []
    # Record which model the thread was started with (compared in init()).
    if 'model' in st.session_state:
        used_model = st.session_state.model
def init():
    """Render the "Initialisation" tab and build the conversation seed.

    The tab shows, in order:
    - the "Nouvelle conversation" reset button;
    - the prospect-profile select boxes (``options[0..2]``);
    - an editable context summary;
    - the multi-selects ``options[3..6]``;
    - the select box ``options[7]`` and the virulence slider.

    From these it assembles:
    - ``context``: the system message describing the prospect;
    - ``human_message1``: the editable role-play instruction.

    While these settings differ from the ones last pushed to the graph (or no
    thread exists), it shows a "Validez" button whose ``on_click`` callback
    is ``init_run``.

    Returns:
        ``(config, thread_id, messages)`` where ``messages`` is
        ``[SystemMessage(context), HumanMessage(human_message1)]``.
    """
    global config, thread_id, context, human_message1, ai_message1, language, app, model_speech, prompt, model, question, to_init, initialized
    global selected_index1, selected_index2, selected_index3, selected_indices4, selected_indices5, selected_indices6, selected_indices7
    global selected_options4, selected_options5, selected_options6, selected_options7, selected_index8, virulence, used_model, messages

    # Whisper is only used by run()'s (currently disabled) language-detection
    # path. It used to be reloaded on every rerun of this tab; load it once
    # and keep it in the module global instead.
    if globals().get("model_speech") is None:
        model_speech = whisper.load_model("base")

    # "New conversation": reset every remembered selection to its default and
    # rebuild the chat model from the current session settings.
    if st.button(label=tr("Nouvelle conversation"), type="primary"):
        selected_index1 = 0
        selected_index2 = 0
        selected_index3 = 0
        selected_indices4 = []
        selected_indices5 = []
        selected_indices6 = []
        selected_indices7 = []
        selected_options4 = []
        selected_options5 = []
        selected_options6 = []
        selected_options7 = []
        selected_index8 = 0
        context = ""
        human_message1 = ""
        thread_id = ""  # an empty thread id forces the "Validez" button below
        virulence = 1
        if 'model' in st.session_state and (st.session_state.model[:3] == "gpt") and ("OPENAI_API_KEY" in st.session_state):
            model = ChatOpenAI(model=st.session_state.model,
                               temperature=0.8,  # Adjust creativity level
                               max_tokens=150    # Define max output token limit
                               )
        else:
            model = ChatMistralAI(model=st.session_state.model)
        if 'model' in st.session_state:
            used_model = st.session_state.model

    label, question, options = format_param()

    # Prospect profile: three single-choice select boxes.
    translated_options1 = [tr(o) for o in options[0]]
    selected_option1 = st.selectbox(tr(label[0]), translated_options1, index=selected_index1)
    selected_index1 = translated_options1.index(selected_option1)
    translated_options2 = [tr(o) for o in options[1]]
    selected_option2 = st.selectbox(tr(label[1]), translated_options2, index=selected_index2)
    selected_index2 = translated_options2.index(selected_option2)
    translated_options3 = [tr(o) for o in options[2]]
    selected_option3 = st.selectbox(tr(label[2]), translated_options3, index=selected_index3)
    selected_index3 = translated_options3.index(selected_option3)

    # System message, built from the untranslated option values, then
    # translated and left editable by the user.
    context = tr(f"""Tu es un {options[0][selected_index1]}, d'une {options[1][selected_index2]}.
Cette entreprise propose des {options[2][selected_index3]}.
""")
    context = st.text_area(label=tr("Résumé du Contexte (modifiable):"), value=context)
    st.markdown('''
------------------------------------------------------------------------------------
''')

    # Topics (options[3..6]): each non-empty selection adds the matching
    # question text followed by a bullet list of the selected items.
    translated_options4 = [tr(o) for o in options[3]]
    selected_options4 = st.multiselect(tr(label[3]), translated_options4, default=[translated_options4[o] for o in selected_indices4])
    selected_indices4 = [translated_options4.index(o) for o in selected_options4]
    problematique = selected_options4
    if problematique != []:
        markdown_text4 = """\n"""+tr(question[3])
        markdown_text4 = markdown_text4+"".join(f"\n- {o}" for o in problematique)
        st.write(markdown_text4)
    else:
        markdown_text4 = ""

    translated_options5 = [tr(o) for o in options[4]]
    selected_options5 = st.multiselect(tr(label[4]), translated_options5, default=[translated_options5[o] for o in selected_indices5])
    selected_indices5 = [translated_options5.index(o) for o in selected_options5]
    processus = selected_options5
    if processus != []:
        markdown_text5 = """\n\n"""+tr(question[4])
        markdown_text5 = markdown_text5+"".join(f"\n- {o}" for o in processus)
        st.write(markdown_text5)
    else:
        markdown_text5 = ""

    translated_options6 = [tr(o) for o in options[5]]
    selected_options6 = st.multiselect(tr(label[5]), translated_options6, default=[translated_options6[o] for o in selected_indices6])
    selected_indices6 = [translated_options6.index(o) for o in selected_options6]
    objectifs = selected_options6
    if objectifs != []:
        markdown_text6 = """\n\n"""+tr(question[5])
        markdown_text6 = markdown_text6+"".join(f"\n- {o}" for o in objectifs)
        st.write(markdown_text6)
    else:
        markdown_text6 = ""

    translated_options7 = [tr(o) for o in options[6]]
    selected_options7 = st.multiselect(tr(label[6]), translated_options7, default=[translated_options7[o] for o in selected_indices7])
    selected_indices7 = [translated_options7.index(o) for o in selected_options7]
    solutions_utilisees = selected_options7
    if solutions_utilisees != []:
        markdown_text7 = """\n\n"""+tr(question[6])
        markdown_text7 = markdown_text7+"".join(f"\n- {o}" for o in solutions_utilisees)
        st.write(markdown_text7)
        st.write("")
    else:
        markdown_text7 = ""

    # Final single-choice question (options[7]); always part of the instruction.
    translated_options8 = [tr(o) for o in options[7]]
    selected_option8 = st.selectbox(tr(label[7]), translated_options8, index=selected_index8)
    selected_index8 = translated_options8.index(selected_option8)
    markdown_text8 = """\n\n"""+tr(question[7])+"""\n"""+(f"""{translated_options8[selected_index8]}""")

    # Slider placed in the first of three columns (presumably to keep it narrow).
    col1, col2, col3 = st.columns(3)
    with col1:
        virulence = st.slider(tr("Virulence (choisissez une valeur entre 1 et 5)"), min_value=1, max_value=5, step=1, value=virulence)
    markdown_text9 = """\n\n"""+tr(f"""Le prospect est très occupé et n'aime pas être dérangé inutilement.
Tu vas utiliser une échelle de 1 à 5 d'agressivité du prospect à l'égard du vendeur.
Pour cette simulation utilise le niveau {virulence}.""")

    # Role-play instruction: intro + selected topics + final answer + virulence
    # + rules of the simulation; then editable by the user.
    human_message1 = tr("""Je souhaite que nous ayons une conversation verbale entre moi le vendeur, et toi que je prospecte.
Mon entreprise propose une solution logicielle pour gérer la proposition de valeur d’entreprise B2B qui commercialise des solutions technologiques.
""")+markdown_text4+markdown_text5+markdown_text6+markdown_text7+markdown_text8+markdown_text9+tr(f"""
Je suis le vendeur.
Répond à mes questions en tant que {options[0][selected_index1]}, connaissant mal le concept de proposition de valeur,
et mon équipe de vente n'est pas performante.
Attention: Ce n'est pas toi qui m'aide, c'est moi qui t'aide avec ma solution.
Attention: Si le vendeur aborde des points qui ne concerne pas cette simulation, lui répondre que c'est hors contexte.
Es tu prêt à commencer ?
""")
    human_message1 = st.text_area(label=tr("Consigne"), value=tr(human_message1), height=300)
    st.markdown('''
------------------------------------------------------------------------------------
''')

    # Scripted first AI reply; kept as a global but currently NOT included in
    # ``messages`` (that turn is disabled).
    ai_message1 = tr(f"J'ai bien compris, je suis un {options[0][selected_index1]} prospecté et je réponds seulement à tes questions. Je réponds à une seule question à la fois, sans commencer mes réponses par 'En tant que {options[0][selected_index1]}'.")
    messages = [
        SystemMessage(content=context),
        HumanMessage(content=human_message1),
    ]
    st.write("")

    # The thread must be (re)started when the context, the instruction or the
    # model changed since the last init_run(), or when no thread exists.
    if ("context" in st.session_state) and ("human_message1" in st.session_state):
        if (st.session_state.context != context) or (st.session_state.human_message1 != human_message1) or (used_model != st.session_state.model) or (thread_id == ""):
            to_init = True
        else:
            to_init = False
    else:
        to_init = True
    if to_init:
        if st.button(label=tr("Validez"), on_click=init_run, type="primary"):
            initialized = True
        else:
            initialized = False
    st.write("**thread_id:** "+thread_id)
    return config, thread_id, messages
def play_audio(custom_sentence, Lang_target, speed=1.0):
    """Synthesize ``custom_sentence`` with gTTS and autoplay it in Streamlit.

    Args:
        custom_sentence: Text to speak.
        Lang_target: Language code passed to gTTS (``lang=``).
        speed: Playback-speed factor, applied by rescaling the sample rate of
            the re-encoded WAV (so pitch changes with speed). ``1.0`` keeps the
            original speed; must be strictly positive.

    Raises:
        ValueError: if ``speed`` is not strictly positive.
    """
    if speed <= 0:
        raise ValueError(f"speed must be > 0, got {speed!r}")

    # Generate the speech into an in-memory buffer (gTTS produces MP3 data).
    source_audio = io.BytesIO()
    gTTS(custom_sentence, lang=Lang_target).write_to_fp(source_audio)
    source_audio.seek(0)

    # Decode to samples. NOTE(review): decoding MP3 with soundfile needs
    # libsndfile >= 1.1.0 -- confirm on the deployment image.
    data, samplerate = sf.read(source_audio)

    # Re-encode the same samples as WAV with a scaled sample rate, which makes
    # the player read them faster or slower.
    played_audio = io.BytesIO()
    sf.write(played_audio, data, int(samplerate * speed), format='wav')
    played_audio.seek(0)

    st.audio(played_audio, start_time=0, autoplay=True)
def run():
    """Render the sales-coaching page and dispatch to the selected tab.

    - "tab1" (Initialisation): rebuild the chat model from the session
      settings and render ``init()``.
    - "tab2" (Conversation): record the seller's voice, transcribe it, send
      it to the graph, then play and show the prospect's reply. The right
      column replays the history kept in ``st.session_state.messages``.
    - "tab3" (Evaluation): send each question of ``question[8:]`` to the same
      conversation thread and print the answers.
    """
    global thread_id, config, model_speech, language, prompt, model, model_name, question, to_init, initialized, messages

    st.write("")
    st.write("")
    st.title(tr(title))

    # Answer language for the system prompt; French unless the host app set one.
    if 'language_label' in st.session_state:
        language = st.session_state['language_label']
    else:
        language = "French"

    chosen_id = tab_bar(data=[
        TabBarItemData(id="tab1", title=tr("Initialisation"), description=tr("d'une nouvelle conversation")),
        TabBarItemData(id="tab2", title=tr("Conversation"), description=tr("avec le prospect")),
        TabBarItemData(id="tab3", title=tr("Evaluation"), description=tr("de l'acte de vente"))],
        default="tab1")

    if chosen_id == "tab1":
        if 'model' in st.session_state and (st.session_state.model[:3] == "gpt") and ("OPENAI_API_KEY" in st.session_state):
            model = ChatOpenAI(model=st.session_state.model,
                               temperature=0.8,  # Adjust creativity level
                               max_tokens=150    # Define max output token limit
                               )
        else:
            model = ChatMistralAI(model=st.session_state.model)
        config, thread_id, messages = init()

    elif chosen_id == "tab2":
        # Start the thread automatically if "Validez" was not clicked.
        try:
            if to_init and not initialized:
                init_run()
        except NameError:
            # NOTE(review): fallback when init_run() hits a module global that
            # is not defined yet; renders the initialisation form instead.
            config, thread_id, messages = init()
        with st.container():
            # Two columns: recorder + current turn on the left, full history
            # on the right.
            col1, col2 = st.columns(2)
            with col1:
                st.write("**thread_id:** "+thread_id)
                query = ""
                audio_bytes = audio_recorder(pause_threshold=2.0, sample_rate=16000, auto_start=False, text=tr("Cliquez pour parler, puis attendre 2sec."),
                                             recording_color="#e8b62c", neutral_color="#1ec3bc", icon_size="6x",)
                if audio_bytes:
                    # Replay the seller's recording.
                    st.audio(audio_bytes, format="audio/wav", autoplay=False)
                    try:
                        # Hard-coded switch: True = local Whisper transcription
                        # with language detection, False = Google speech
                        # recognition in the session language.
                        detection = False
                        if detection:
                            # Decode the WAV bytes with wavio.
                            audio_stream_bytesio = io.BytesIO(audio_bytes)
                            wav = wavio.read(audio_stream_bytesio)
                            audio_data = wav.data
                            # Average the channels and scale by 1/32768 (16-bit
                            # full scale) into float32 for Whisper.
                            audio_input = np.array(audio_data, dtype=np.float32)
                            audio_input = np.mean(audio_input, axis=1)/32768
                            result = model_speech.transcribe(audio_input)
                            Lang_detected = result["language"]
                            query = result["text"]
                        else:
                            # Google transcription via the speech_recognition library.
                            Lang_detected = st.session_state['Language']
                            # NOTE(review): the recorder bytes are passed as
                            # 16-bit frames at 32000 Hz although the recorder
                            # is set to 16000 Hz -- presumably compensating
                            # for a 2-channel recording; confirm.
                            audio_stream = sr.AudioData(audio_bytes, 32000, 2)
                            r = sr.Recognizer()
                            query = r.recognize_google(audio_stream, language=Lang_detected)
                        # Show the transcription as the seller's chat turn.
                        with st.chat_message("user"):
                            st.markdown(query)
                        st.write("")
                        if query != "":
                            input_messages = [HumanMessage(query)]
                            output = app.invoke(
                                {"messages": input_messages, "language": language},
                                config,
                            )
                            # Add the user message to the chat history.
                            st.session_state.messages.append({"role": "user", "content": query})
                            # The prospect's reply is the last message of the thread.
                            custom_sentence = output["messages"][-1].content
                            # Speak the reply.
                            play_audio(custom_sentence, Lang_detected, 1)
                            with st.chat_message("assistant"):
                                st.markdown(custom_sentence)
                            # Add the assistant message to the chat history.
                            st.session_state.messages.append({"role": "assistant", "content": custom_sentence})
                    except KeyboardInterrupt:
                        st.write(tr("Arrêt de la reconnaissance vocale."))
                    except Exception:
                        # Was a bare ``except:``, which also swallowed
                        # Streamlit's script-control exceptions (BaseException
                        # subclasses used to stop/rerun the script), silently
                        # cancelling reruns triggered during the slow calls.
                        st.write(tr("Problème, essayer de nouveau.."))
                # Blank line to separate the areas.
                st.write("")
            with col2:
                # Replay the chat history on every rerun.
                if "messages" in st.session_state:
                    if st.session_state.messages != []:
                        for message in st.session_state.messages:
                            with st.chat_message(message["role"]):
                                st.markdown(message["content"])

    else:
        # "tab3": evaluation of the sales conversation in the same thread.
        if to_init and not initialized:
            init_run()
        st.write("**thread_id:** "+thread_id)
        for i in range(8, len(question)):
            st.write("")
            q = st.text_input(label=".", value=tr(question[i]), label_visibility="collapsed")
            if q != "":
                input_messages = [HumanMessage(q)]
                output = app.invoke(
                    {"messages": input_messages, "language": language},
                    config,
                )
                custom_sentence = output["messages"][-1].content
                st.write(custom_sentence)
                st.write("")
                # Pause between successive calls with Mistral models
                # (presumably to respect the API rate limit).
                if used_model[:3] == 'mis':
                    time.sleep(2)
            st.divider()