Demosthene-OR commited on
Commit
e73113b
·
1 Parent(s): 7a07c71

Update chatbot_tab.py

Browse files
Files changed (1) hide show
  1. tabs/chatbot_tab.py +73 -55
tabs/chatbot_tab.py CHANGED
@@ -90,14 +90,52 @@ human_message1=""
90
  thread_id =""
91
  virulence = 1
92
  question = []
 
 
 
 
 
 
 
 
 
 
93
  if 'model' in st.session_state:
 
94
  used_model = st.session_state.model
95
 
96
- # @st.cache_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def init():
98
- global config,thread_id, context,human_message1,ai_message1,language, app, model_speech,prompt,model,question
99
  global selected_index1, selected_index2, selected_index3, selected_indices4,selected_indices5,selected_indices6,selected_indices7
100
- global selected_options4,selected_options5,selected_options6,selected_options7, selected_index8, virulence, used_model
101
 
102
  model_speech = whisper.load_model("base")
103
 
@@ -229,7 +267,7 @@ Attention: Si le vendeur aborde des points qui ne concerne pas cette simulation,
229
  SystemMessage(content=context),
230
  HumanMessage(content=human_message1),
231
  AIMessage(content=ai_message1),
232
- HumanMessage(content=tr("Commençons la conversation. Attention, je parle le premier"))
233
  ]
234
 
235
  st.write("")
@@ -240,31 +278,13 @@ Attention: Si le vendeur aborde des points qui ne concerne pas cette simulation,
240
  to_init = False
241
  else:
242
  to_init = True
243
-
244
  if to_init:
245
- thread_id = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
246
- config = {"configurable": {"thread_id": thread_id}}
247
- app.invoke(
248
- {"messages": messages, "language": language},
249
- config,
250
- )
251
- st.session_state.thread_id = thread_id
252
- st.session_state.config = config
253
- st.session_state.messages_init = messages
254
- st.session_state.context = context
255
- st.session_state.human_message1 = human_message1
256
- st.session_state.messages = []
257
- if 'model' in st.session_state and (st.session_state.model[:3]=="gpt") and ("OPENAI_API_KEY" in st.session_state):
258
- model = ChatOpenAI(model=st.session_state.model,
259
- temperature=0.8, # Adjust creativity level
260
- max_tokens=150 # Define max output token limit
261
- )
262
- else:
263
- model = ChatMistralAI(model=st.session_state.model)
264
- if 'model' in st.session_state:
265
- used_model=st.session_state.model
266
-
267
- return config, thread_id
268
 
269
  # Fonction pour générer et jouer le texte en speech
270
  def play_audio(custom_sentence, Lang_target, speed=1.0):
@@ -298,7 +318,7 @@ def play_audio(custom_sentence, Lang_target, speed=1.0):
298
 
299
 
300
  def run():
301
- global thread_id, config, model_speech, language,prompt,model, model_name, question
302
 
303
  st.write("")
304
  st.write("")
@@ -325,27 +345,17 @@ def run():
325
  model = ChatMistralAI(model=st.session_state.model)
326
 
327
 
328
- config,thread_id = init()
329
  query = ""
330
- st.button(label=tr("Validez"), type="primary")
331
- st.write("**thread_id:** "+thread_id)
332
  elif (chosen_id == "tab2"):
333
  try:
334
- config
335
- # On ne fait rien
336
  except NameError:
337
- config,thread_id = init()
338
  with st.container():
339
  # Diviser l'écran en deux colonnes
340
  col1, col2 = st.columns(2)
341
- # with col1:
342
- # st.markdown(
343
- # """
344
- # <div style="height: 400px;">
345
- # </div>
346
- # """,
347
- # unsafe_allow_html=True,
348
- # )
349
  with col1:
350
  st.write("**thread_id:** "+thread_id)
351
  query = ""
@@ -430,19 +440,27 @@ def run():
430
  with st.chat_message(message["role"]):
431
  st.markdown(message["content"])
432
  else:
 
 
433
  st.write("**thread_id:** "+thread_id)
434
  for i in range(8,len(question)):
435
  st.write("")
436
 
437
- q = st.text_input(label="", value=tr(question[i]),label_visibility="collapsed")
438
- output = app.invoke(
439
- {"messages": q,"language": language},
440
- config,
441
- )
442
- custom_sentence = output["messages"][-1].content
443
- st.write(custom_sentence)
444
- st.write("")
445
- if (used_model[:3] == 'mis'):
446
- time.sleep(2)
447
-
448
- st.divider()
 
 
 
 
 
 
 
90
  thread_id =""
91
  virulence = 1
92
  question = []
93
+ thread_id = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
94
+ config = {"configurable": {"thread_id": thread_id}}
95
+ to_init = True
96
+ initialized = False
97
+ messages = [
98
+ SystemMessage(content=""),
99
+ HumanMessage(content=""),
100
+ AIMessage(content=""),
101
+ HumanMessage(content="")
102
+ ]
103
  if 'model' in st.session_state:
104
+ model = st.session_state.model
105
  used_model = st.session_state.model
106
 
107
+ def init_run():
108
+ global initialized, to_init, thread_id, config, app, context, human_message1, model, used_model, messages
109
+
110
+ initialized = True
111
+ to_init = False
112
+ thread_id = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
113
+ config = {"configurable": {"thread_id": thread_id}}
114
+ app.invoke(
115
+ {"messages": messages, "language": language},
116
+ config,
117
+ )
118
+ st.session_state.thread_id = thread_id
119
+ st.session_state.config = config
120
+ st.session_state.messages_init = messages
121
+ st.session_state.context = context
122
+ st.session_state.human_message1 = human_message1
123
+ st.session_state.messages = []
124
+ if 'model' in st.session_state and (st.session_state.model[:3]=="gpt") and ("OPENAI_API_KEY" in st.session_state):
125
+ model = ChatOpenAI(model=st.session_state.model,
126
+ temperature=0.8, # Adjust creativity level
127
+ max_tokens=150 # Define max output token limit
128
+ )
129
+ else:
130
+ model = ChatMistralAI(model=st.session_state.model)
131
+ if 'model' in st.session_state:
132
+ used_model=st.session_state.model
133
+ return
134
+
135
  def init():
136
+ global config,thread_id, context,human_message1,ai_message1,language, app, model_speech,prompt,model,question, to_init, initialized
137
  global selected_index1, selected_index2, selected_index3, selected_indices4,selected_indices5,selected_indices6,selected_indices7
138
+ global selected_options4,selected_options5,selected_options6,selected_options7, selected_index8, virulence, used_model, messages
139
 
140
  model_speech = whisper.load_model("base")
141
 
 
267
  SystemMessage(content=context),
268
  HumanMessage(content=human_message1),
269
  AIMessage(content=ai_message1),
270
+ HumanMessage(content=tr("Commençons la conversation. Attention, je suis le vendeur et je parle le premier. Tu es le propect."))
271
  ]
272
 
273
  st.write("")
 
278
  to_init = False
279
  else:
280
  to_init = True
281
+
282
  if to_init:
283
+ if st.button(label=tr("Validez"), on_click=init_run,type="primary"):
284
+ initialized=True
285
+ else: initialized = False
286
+ st.write("**thread_id:** "+thread_id)
287
+ return config, thread_id, messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  # Fonction pour générer et jouer le texte en speech
290
  def play_audio(custom_sentence, Lang_target, speed=1.0):
 
318
 
319
 
320
  def run():
321
+ global thread_id, config, model_speech, language,prompt,model, model_name, question, to_init, initialized, messages
322
 
323
  st.write("")
324
  st.write("")
 
345
  model = ChatMistralAI(model=st.session_state.model)
346
 
347
 
348
+ config,thread_id, messages = init()
349
  query = ""
 
 
350
  elif (chosen_id == "tab2"):
351
  try:
352
+ if to_init and not initialized:
353
+ init_run()
354
  except NameError:
355
+ config,thread_id, messages = init()
356
  with st.container():
357
  # Diviser l'écran en deux colonnes
358
  col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
359
  with col1:
360
  st.write("**thread_id:** "+thread_id)
361
  query = ""
 
440
  with st.chat_message(message["role"]):
441
  st.markdown(message["content"])
442
  else:
443
+ if to_init and not initialized:
444
+ init_run()
445
  st.write("**thread_id:** "+thread_id)
446
  for i in range(8,len(question)):
447
  st.write("")
448
 
449
+ q = st.text_input(label=".", value=tr(question[i]),label_visibility="collapsed")
450
+ if (q !=""):
451
+ input_messages = [HumanMessage(q)]
452
+ output = app.invoke(
453
+ {"messages": input_messages, "language": language},
454
+ config,
455
+ )
456
+ # output = app.invoke(
457
+ # {"messages": q,"language": language},
458
+ # config,
459
+ # )
460
+ custom_sentence = output["messages"][-1].content
461
+ st.write(custom_sentence)
462
+ st.write("")
463
+ if (used_model[:3] == 'mis'):
464
+ time.sleep(2)
465
+
466
+ st.divider()