eventia / main.py
datacipen's picture
Update main.py
927cbbf verified
raw
history blame
9.22 kB
import os
import json
import bcrypt
from typing import List
from pathlib import Path
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import StrOutputParser
from operator import itemgetter
from pinecone import Pinecone
from langchain_pinecone import PineconeVectorStore
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig, RunnableLambda
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
import chainlit as cl
from chainlit.input_widget import TextInput, Select, Switch, Slider
@cl.password_auth_callback
def auth_callback(username: str, password: str):
auth = json.loads(os.environ['CHAINLIT_AUTH_LOGIN'])
ident = next(d['ident'] for d in auth if d['ident'] == username)
pwd = next(d['pwd'] for d in auth if d['ident'] == username)
resultLogAdmin = bcrypt.checkpw(username.encode('utf-8'), bcrypt.hashpw(ident.encode('utf-8'), bcrypt.gensalt()))
resultPwdAdmin = bcrypt.checkpw(password.encode('utf-8'), bcrypt.hashpw(pwd.encode('utf-8'), bcrypt.gensalt()))
resultRole = next(d['role'] for d in auth if d['ident'] == username)
if resultLogAdmin and resultPwdAdmin and resultRole == "admindatapcc":
return cl.User(
identifier=ident + " : 🧑‍💼 Admin Datapcc", metadata={"role": "admin", "provider": "credentials"}
)
elif resultLogAdmin and resultPwdAdmin and resultRole == "userdatapcc":
return cl.User(
identifier=ident + " : 🧑‍🎓 User Datapcc", metadata={"role": "user", "provider": "credentials"}
)
os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.environ['HUGGINGFACEHUB_API_TOKEN']
repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
model = HuggingFaceEndpoint(
repo_id=repo_id, max_new_tokens=5000, temperature=1.0, task="text2text-generation", streaming=True
)
os.environ['PINECONE_API_KEY'] = os.environ['PINECONE_API_KEY']
embeddings = HuggingFaceEmbeddings()
index_name = "all-venus"
#pc = Pinecone(
# api_key=os.environ['PINECONE_API_KEY']
#)
#index = pc.Index(index_name)
#xq = embeddings.embed_query(message.content)
#xc = index.query(vector=xq, filter={"categorie": {"$eq": "bibliographie-OPP-DGDIN"}},top_k=150, include_metadata=True)
#context = ""
#for result in xc['matches']:
# context = context + result['metadata']['text']
vectorstore = PineconeVectorStore(
index_name=index_name, embedding=embeddings
)
retriever = vectorstore.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": .7, "k": 150,"filter": {'categorie': {'$eq': 'bibliographie-OPP-DGDIN'}}})
@cl.on_chat_start
async def on_chat_start():
await cl.Message(f"> REVIEWSTREAM").send()
settings = await cl.ChatSettings(
[
Select(
id="Model",
label="Publications de recherche",
values=["---", "HAL", "Persée"],
initial_index=0,
),
]
).send()
res = await cl.AskActionMessage(
content="<div style='width:100%;text-align:center'> </div>",
actions=[
cl.Action(name="Pédagogie durable", value="Pédagogie durable", label="🔥 Pédagogie durable : exemple : «quels sont les modèles d'apprentissage dans les universités?»"),
cl.Action(name="Lieux d'apprentissage", value="Lieux d'apprentissage", label="🔥 Lieux d'apprentissage : exemple : «donne des exemples de lieu d'apprentissage dans les universités?»"),
cl.Action(name="jdlp", value="Journée de La Pédagogie", label="🔥 Journée de La Pédagogie : exemple : «Quelles sont les bonnes pratiques des plateformes de e-learning?»"),
],
timeout="3600"
).send()
if res:
await cl.Message(f"Vous pouvez requêter sur la thématique : {res.get('value')}").send()
cl.user_session.set("selectRequest", res.get("value"))
########## Chain with streaming ##########
message_history = ChatMessageHistory()
memory = ConversationBufferMemory(memory_key="chat_history",output_key="answer",chat_memory=message_history,return_messages=True)
qa = ConversationalRetrievalChain.from_llm(
model,
memory=memory,
chain_type="stuff",
return_source_documents=True,
verbose=False,
retriever=retriever
)
cl.user_session.set("runnable", qa)
#template = """<s>[INST] Vous êtes un chercheur de l'enseignement supérieur et vous êtes doué pour faire des analyses d'articles de recherche sur les thématiques liées à la pédagogie, en fonction des critères définis ci-avant.
#En fonction des informations suivantes et du contexte suivant seulement et strictement, répondez en langue française strictement à la question ci-dessous à partir du contexte ci-dessous. Si vous ne pouvez pas répondre à la question sur la base des informations, dites que vous ne trouvez pas de réponse ou que vous ne parvenez pas à trouver de réponse. Essayez donc de comprendre en profondeur le contexte et répondez uniquement en vous basant sur les informations fournies. Ne générez pas de réponses non pertinentes.
#{context}
#{question} [/INST] </s>
#"""
#prompt = ChatPromptTemplate.from_messages(
# [
# (
# "system",
# f"Contexte : Vous êtes un chercheur de l'enseignement supérieur et vous êtes doué pour faire des analyses d'articles de recherche sur les thématiques liées à la pédagogie. En fonction des informations suivantes et du contexte suivant seulement et strictement. Contexte : {context}.",
# ),
# MessagesPlaceholder(variable_name="history"),
# ("human", "Réponds à la question suivante de la manière la plus pertinente, la plus exhaustive et la plus détaillée possible. {question}."),
# ]
#)
#runnable = (
# RunnablePassthrough.assign(
# history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
# )
# | prompt
# | model
# | StrOutputParser()
#)
#cl.user_session.set("memory", memory)
#cl.user_session.set("runnable", runnable)
@cl.on_message
async def on_message(message: cl.Message):
memory = cl.user_session.get("memory")
runnable = cl.user_session.get("runnable") # type: Runnable
msg = cl.Message(content="")
class PostMessageHandler(BaseCallbackHandler):
"""
Callback handler for handling the retriever and LLM processes.
Used to post the sources of the retrieved documents as a Chainlit element.
"""
def __init__(self, msg: cl.Message):
BaseCallbackHandler.__init__(self)
self.msg = msg
self.sources = set() # To store unique pairs
def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
for d in documents:
source_page_pair = (d.metadata['source'], d.metadata['page'])
self.sources.add(source_page_pair) # Add unique pairs to the set
def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
if len(self.sources):
sources_text = "\n".join([f"{source}#page={page}" for source, page in self.sources])
self.msg.elements.append(
cl.Text(name="Sources", content=sources_text, display="inline")
)
async with cl.Step(type="run", name="Réponse de Mistral"):
#async for chunk in runnable.astream(
# {"question": message.content},
# config=RunnableConfig(callbacks=[
# cl.AsyncLangchainCallbackHandler(stream_final_answer=True)
# ]),
#):
# await msg.stream_token(chunk)
cb = cl.AsyncLangchainCallbackHandler()
res = await chain.acall("Contexte : Vous êtes un chercheur de l'enseignement supérieur et vous êtes doué pour faire des analyses d'articles de recherche sur les thématiques liées à la pédagogie, en fonction des critères définis ci-avant. En fonction des informations suivantes et du contexte suivant seulement et strictement, répondez en langue française strictement à la question ci-dessous à partir du contexte ci-dessous. Si vous ne pouvez pas répondre à la question sur la base des informations, dites que vous ne trouvez pas de réponse ou que vous ne parvenez pas à trouver de réponse. Essayez donc de comprendre en profondeur le contexte et répondez uniquement en vous basant sur les informations fournies. Ne générez pas de réponses non pertinentes. Question : " + message.content, callbacks=[cb])
answer = res["answer"]
await cl.Message(content=answer).send()
#await msg.send()
memory.chat_memory.add_user_message(message.content)
memory.chat_memory.add_ai_message(msg.content)