|
import os |
|
import json |
|
import bcrypt |
|
import pandas as pd |
|
import numpy as np |
|
from typing import List |
|
from pathlib import Path |
|
from langchain_huggingface import HuggingFaceEndpoint |
|
from langchain.schema import StrOutputParser |
|
|
|
from langchain.agents import AgentExecutor |
|
from langchain.agents.agent_types import AgentType |
|
from langchain_experimental.agents.agent_toolkits import create_csv_agent |
|
|
|
|
|
import chainlit as cl |
|
from chainlit.input_widget import TextInput, Select, Switch, Slider |
|
|
|
from deep_translator import GoogleTranslator |
|
|
|
@cl.step(type="tool")

async def LLMistral():

    """Build the streaming Mixtral endpoint used to answer survey questions.

    Returns:
        HuggingFaceEndpoint: a streaming client for
        mistralai/Mixtral-8x7B-Instruct-v0.1.

    Raises:
        KeyError: if HUGGINGFACEHUB_API_TOKEN is not set in the environment.
    """

    # Fail fast with a clear message. The original line was a circular
    # no-op (os.environ[X] = os.environ[X]) whose only effect was an
    # opaque KeyError when the token was missing.
    if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
        raise KeyError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

    repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

    # Mixtral-8x7B-Instruct is a decoder-only model: its inference task is
    # "text-generation". "text2text-generation" targets encoder-decoder
    # models (e.g. T5) and is rejected/mismatched for this repo.
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        max_new_tokens=5300,
        temperature=0.1,
        task="text-generation",
        streaming=True,
    )

    return llm
|
|
|
@cl.set_chat_profiles

async def chat_profile():

    """Declare the single survey-processing profile shown in the Chainlit UI."""

    survey_profile = cl.ChatProfile(
        name="Traitement des données d'enquête : «Expé CFA : questionnaire auprès des professionnels de la branche de l'agencement»",
        markdown_description="Vidéo exploratoire autour de l'événement",
        icon="/public/logo-ofipe.png",
    )

    return [survey_profile]
|
|
|
@cl.set_starters

async def set_starters():

    """Provide the starter prompt suggested to the user on a new chat."""

    caa_starter = cl.Starter(
        label="Répartition du nombre de CAA dans les entreprises",
        message="Quel est le nombre de chargé.e d'affaires en agencement dans les entreprises?",
        icon="/public/request-theme.svg",
    )

    return [caa_starter]
|
|
|
@cl.on_message

async def on_message(message: cl.Message):

    """Answer a user question about the survey CSV with the Mixtral agent.

    Builds a CSV agent over the survey export, asks it (in French) to
    analyse the question, then runs the answer through Google Translate
    to force the final message into French before sending it to the chat.

    Args:
        message: the incoming Chainlit chat message.
    """

    # Plain string — the original used an f-string with no placeholder.
    await cl.Message(content="> SURVEYIA").send()

    model = await LLMistral()

    # Pandas-backed agent over the survey export; the agent decides which
    # dataframe operations to run to answer the question.
    agent = create_csv_agent(
        model,
        "./public/ExpeCFA_LP_CAA.csv",
        verbose=False,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    )

    # NOTE(review): the original defined an unused nested PostMessageHandler
    # class subclassing the never-imported BaseCallbackHandler, which raised
    # a NameError on every message; it (and the unused empty cl.Message
    # buffer) has been removed. Re-adding source reporting would require
    # importing BaseCallbackHandler and actually passing the handler to the
    # agent call.
    cb = cl.AsyncLangchainCallbackHandler()

    res = await agent.acall(
        "Réponds en langue française à la question suivante :\n"
        + message.content
        + "\nDétaille la réponse en faisant une analyse complète en 2000 mots minimum.",
        callbacks=[cb],
    )

    answer = res['output']

    # The agent sometimes answers in English; translate the final text to French.
    await cl.Message(
        content=GoogleTranslator(source='auto', target='fr').translate(answer)
    ).send()