datacipen committed on
Commit
50a8089
•
1 Parent(s): 64762cf

Update main.py

Files changed (1)
  1. main.py +10 -3
main.py CHANGED

@@ -693,6 +693,7 @@ async def construction_FCS(romeListArray,settings):
 async def construction_NCS(romeListArray):
     context = await contexte(romeListArray)
     emploisST = cl.user_session.get("EmploiST")
+    memory = ConversationBufferMemory(return_messages=True)
     ### Mistral Completion ###
     client_llm = await IA()
     structure = str(modele('Note de composante sectorielle'))
@@ -710,11 +711,17 @@ async def construction_NCS(romeListArray):
     prompt = PromptTemplate(template=template, input_variables=["question","context"])
     #llm_chain = LLMChain(prompt=prompt, llm=client_llm)
     #completion_NCS = llm_chain.run({"question":question_p,"context":context_p}, callbacks=[StreamingStdOutCallbackHandler()])
-    chain = prompt | client_llm
+    chain = (
+        RunnablePassthrough.assign(
+            history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
+        )
+        | prompt | client_llm
+    )
     #completion_NCS = chain.invoke({"question":question_p,"context":context_p})

     msg = cl.Message(author="Datapcc : 🌐🌐🌐",content="")
-    async for chunk in chain.astream({"question":question_p,"context":context_p}):
+    async for chunk in chain.astream({"question":question_p,"context":context_p},
+        config=RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)])):
         await msg.stream_token(chunk)

     cl.user_session.set("NCS" + romeListArray[0], msg.content)
@@ -926,7 +933,7 @@ async def IA():
     repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

     llm = HuggingFaceEndpoint(
-        repo_id=repo_id, max_new_tokens=5000, temperature=0.7, streaming=True
+        repo_id=repo_id, max_new_tokens=5000, temperature=0.7, task="text2text-generation", streaming=True
     )
     return llm

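For reference, the chain added here follows LangChain's documented memory-in-LCEL pattern: RunnablePassthrough.assign runs memory.load_memory_variables on the chain input and merges the result in under the "history" key before the prompt is formatted. One caveat: the loaded buffer only reaches the model if the template declares a {history} placeholder, which the prompt in main.py (input_variables=["question","context"]) does not. A minimal offline sketch of the pattern, with FakeListLLM standing in for the HuggingFace endpoint and an illustrative template:

from operator import itemgetter

from langchain.memory import ConversationBufferMemory
from langchain_core.language_models.fake import FakeListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

memory = ConversationBufferMemory(return_messages=True)

# Illustrative template: unlike the one in main.py, it declares {history}
# so the loaded buffer actually reaches the model.
prompt = PromptTemplate(
    template="History: {history}\nContext: {context}\nQuestion: {question}\nAnswer:",
    input_variables=["history", "question", "context"],
)
llm = FakeListLLM(responses=["réponse simulée"])  # stand-in for HuggingFaceEndpoint

chain = (
    # Merge {"history": [...]} into the input dict before the prompt is formatted;
    # itemgetter unwraps the mapping returned by load_memory_variables.
    RunnablePassthrough.assign(
        history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
    )
    | prompt
    | llm
)

answer = chain.invoke({"question": "Quels métiers ?", "context": "secteur industriel"})
# Persist the turn so the next invocation sees it under "history".
memory.save_context({"question": "Quels métiers ?"}, {"output": answer})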
 
 
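The streaming change is standard LCEL as well: chain.astream yields completion chunks as they arrive, and the RunnableConfig carries the Chainlit callback handler (the cl.AsyncLangchainCallbackHandler(stream_final_answer=True) arguments are taken from the diff as-is, not re-validated here). A runnable sketch outside the Chainlit app, with FakeStreamingListLLM as the endpoint stand-in and print() in place of msg.stream_token:

import asyncio

from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig

prompt = PromptTemplate.from_template("Context: {context}\nQuestion: {question}\nAnswer:")
chain = prompt | FakeStreamingListLLM(responses=["Une note sectorielle simulée."])

async def main() -> None:
    # In main.py the callbacks list holds cl.AsyncLangchainCallbackHandler(...).
    config = RunnableConfig(callbacks=[])
    async for chunk in chain.astream(
        {"question": "Quels métiers ?", "context": "secteur industriel"},
        config=config,
    ):
        print(chunk, end="", flush=True)  # in the app: await msg.stream_token(chunk)

asyncio.run(main())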
 
939