datacipen commited on
Commit
5431a98
1 Parent(s): 0898c3a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +122 -122
main.py CHANGED
@@ -1,134 +1,134 @@
1
- import os
2
- import json
3
- import bcrypt
4
- import chainlit as cl
5
- from chainlit.input_widget import TextInput, Select, Switch, Slider
6
- from chainlit import user_session
7
- from literalai import LiteralClient
8
- literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
9
-
10
- from operator import itemgetter
11
- from pinecone import Pinecone
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
13
  from langchain_community.llms import HuggingFaceEndpoint
14
- from langchain.memory import ConversationBufferMemory
 
 
15
  from langchain.schema import StrOutputParser
16
- from langchain.schema.runnable import Runnable
17
- from langchain.schema.runnable.config import RunnableConfig
18
- from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableLambda
19
- from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
20
- from langchain_core.prompts import PromptTemplate
21
-
22
- @cl.password_auth_callback
23
- def auth_callback(username: str, password: str):
24
- auth = json.loads(os.environ['CHAINLIT_AUTH_LOGIN'])
25
- ident = next(d['ident'] for d in auth if d['ident'] == username)
26
- pwd = next(d['pwd'] for d in auth if d['ident'] == username)
27
- resultLogAdmin = bcrypt.checkpw(username.encode('utf-8'), bcrypt.hashpw(ident.encode('utf-8'), bcrypt.gensalt()))
28
- resultPwdAdmin = bcrypt.checkpw(password.encode('utf-8'), bcrypt.hashpw(pwd.encode('utf-8'), bcrypt.gensalt()))
29
- resultRole = next(d['role'] for d in auth if d['ident'] == username)
30
- if resultLogAdmin and resultPwdAdmin and resultRole == "admindatapcc":
31
- return cl.User(
32
- identifier=ident + " : 🧑‍💼 Admin Datapcc", metadata={"role": "admin", "provider": "credentials"}
33
- )
34
- elif resultLogAdmin and resultPwdAdmin and resultRole == "userdatapcc":
35
- return cl.User(
36
- identifier=ident + " : 🧑‍🎓 User Datapcc", metadata={"role": "user", "provider": "credentials"}
37
- )
38
-
39
- @cl.author_rename
40
- def rename(orig_author: str):
41
- rename_dict = {"LLMMathChain": "Albert Einstein", "Doc Chain Assistant": "Assistant Reviewstream"}
42
- return rename_dict.get(orig_author, orig_author)
43
-
44
- @cl.set_chat_profiles
45
- async def chat_profile():
46
- return [
47
- cl.ChatProfile(name="Reviewstream",markdown_description="Requêter sur les publications de recherche",icon="/public/logo-ofipe.jpg",),
48
- cl.ChatProfile(name="Imagestream",markdown_description="Requêter sur un ensemble d'images",icon="./public/logo-ofipe.jpg",),
49
- ]
50
-
51
- @cl.on_chat_start
52
- async def on_chat_start():
53
- await cl.Message(f"> REVIEWSTREAM").send()
54
- await cl.Message(f"Nous avons le plaisir de vous accueillir dans l'application de recherche et d'analyse des publications.").send()
55
- listPrompts_name = f"Liste des revues de recherche"
56
- contentPrompts = """<p><img src='/public/hal-logo-header.png' width='32' align='absmiddle' /> <strong> Hal Archives Ouvertes</strong> : Une archive ouverte est un réservoir numérique contenant des documents issus de la recherche scientifique, généralement déposés par leurs auteurs, et permettant au grand public d'y accéder gratuitement et sans contraintes.
57
- </p>
58
- <p><img src='/public/logo-persee.png' width='32' align='absmiddle' /> <strong>Persée</strong> : offre un accès libre et gratuit à des collections complètes de publications scientifiques (revues, livres, actes de colloques, publications en série, sources primaires, etc.) associé à une gamme d'outils de recherche et d'exploitation.</p>
59
- """
60
- prompt_elements = []
61
- prompt_elements.append(
62
- cl.Text(content=contentPrompts, name=listPrompts_name, display="side")
63
  )
64
- await cl.Message(content="📚 " + listPrompts_name, elements=prompt_elements).send()
65
- settings = await cl.ChatSettings(
66
- [
67
- Select(
68
- id="Model",
69
- label="Publications de recherche",
70
- values=["---", "HAL", "Persée"],
71
- initial_index=0,
72
- ),
73
- ]
74
- ).send()
75
-
76
-
77
-
78
- @cl.on_message
79
- async def main(message: cl.Message):
80
- os.environ['PINECONE_API_KEY'] = os.environ['PINECONE_API_KEY']
81
- embeddings = HuggingFaceEmbeddings()
82
- index_name = "all-venus"
83
- pc = Pinecone(
84
- api_key=os.environ['PINECONE_API_KEY']
85
  )
86
- index = pc.Index(index_name)
87
- xq = embeddings.embed_query(message.content)
88
- xc = index.query(vector=xq, filter={"categorie": {"$eq": "bibliographie-OPP-DGDIN"}},top_k=150, include_metadata=True)
89
- context_p = ""
90
- for result in xc['matches']:
91
- context_p = context_p + result['metadata']['text']
92
-
93
-
94
- memory = ConversationBufferMemory(return_messages=True)
95
- template = """<s>[INST] Vous êtes un chercheur de l'enseignement supérieur et vous êtes doué pour faire des analyses d'articles de recherche sur les thématiques liées à la pédagogie, en fonction des critères définis ci-avant.
96
-
97
- En fonction des informations suivantes et du contexte suivant seulement et strictement, répondez en langue française strictement à la question ci-dessous à partir du contexte ci-dessous. Si vous ne pouvez pas répondre à la question sur la base des informations, dites que vous ne trouvez pas de réponse ou que vous ne parvenez pas à trouver de réponse. Essayez donc de comprendre en profondeur le contexte et répondez uniquement en vous basant sur les informations fournies. Ne générez pas de réponses non pertinentes.
 
 
 
 
 
 
 
 
98
  {context}
99
- {question} [/INST] </s>
 
100
  """
101
-
102
- os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.environ['HUGGINGFACEHUB_API_TOKEN']
103
- repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 
 
104
 
105
- model = HuggingFaceEndpoint(
106
- repo_id=repo_id, max_new_tokens=8000, temperature=1.0, task="text2text-generation", streaming=True
107
- )
108
-
109
- prompt = ChatPromptTemplate.from_messages(
110
- [
111
- (
112
- "system",
113
- f"Contexte : Vous êtes un chercheur de l'enseignement supérieur et vous êtes doué pour faire des analyses d'articles de recherche sur les thématiques liées à la pédagogie. En fonction des informations suivantes et du contexte suivant seulement et strictement.",
114
- ),
115
- MessagesPlaceholder(variable_name="history"),
116
- ("human", "Contexte : {context}, réponds à la question suivante de la manière la plus pertinente, la plus exhaustive et la plus détaillée possible. {question}."),
117
- ]
118
- )
119
  runnable = (
120
- RunnablePassthrough.assign(
121
- history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
122
- )
123
  | prompt
124
  | model
 
125
  )
126
-
127
- msg = cl.Message(author="Assistant Reviewstream",content="")
128
- async for chunk in runnable.astream({"question": message.content, "context":context_p},
129
- config=RunnableConfig(callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)])):
130
- await msg.stream_token(chunk)
131
-
132
- await msg.send()
133
- memory.chat_memory.add_user_message(message.content)
134
- memory.chat_memory.add_ai_message(msg.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from pathlib import Path
 
 
 
 
 
 
 
 
 
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.llms import HuggingFaceEndpoint
5
+
6
+ #from langchain_openai import ChatOpenAI, OpenAIEmbeddings
7
+ from langchain.prompts import ChatPromptTemplate
8
  from langchain.schema import StrOutputParser
9
+ from langchain_community.document_loaders import (
10
+ PyMuPDFLoader,
11
+ )
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain.vectorstores.chroma import Chroma
14
+ from langchain.indexes import SQLRecordManager, index
15
+ from langchain.schema import Document
16
+ from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
17
+ from langchain.callbacks.base import BaseCallbackHandler
18
+
19
+ import chainlit as cl
20
+
21
+
22
+ chunk_size = 1024
23
+ chunk_overlap = 50
24
+
25
+ embeddings_model = HuggingFaceEmbeddings()
26
+
27
+ PDF_STORAGE_PATH = "./public/pdfs"
28
+
29
+
30
+ def process_pdfs(pdf_storage_path: str):
31
+ pdf_directory = Path(pdf_storage_path)
32
+ docs = [] # type: List[Document]
33
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
34
+
35
+ for pdf_path in pdf_directory.glob("*.pdf"):
36
+ loader = PyMuPDFLoader(str(pdf_path))
37
+ documents = loader.load()
38
+ docs += text_splitter.split_documents(documents)
39
+
40
+ doc_search = Chroma.from_documents(docs, embeddings_model)
41
+
42
+ namespace = "chromadb/my_documents"
43
+ record_manager = SQLRecordManager(
44
+ namespace, db_url="sqlite:///record_manager_cache.sql"
 
 
 
 
 
 
 
 
 
 
 
45
  )
46
+ record_manager.create_schema()
47
+
48
+ index_result = index(
49
+ docs,
50
+ record_manager,
51
+ doc_search,
52
+ cleanup="incremental",
53
+ source_id_key="source",
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  )
55
+
56
+ print(f"Indexing stats: {index_result}")
57
+
58
+ return doc_search
59
+
60
+
61
+ doc_search = process_pdfs(PDF_STORAGE_PATH)
62
+ #model = ChatOpenAI(model_name="gpt-4", streaming=True)
63
+ os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.environ['HUGGINGFACEHUB_API_TOKEN']
64
+ repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
65
+
66
+ model = HuggingFaceEndpoint(
67
+ repo_id=repo_id, max_new_tokens=8000, temperature=1.0, task="text2text-generation", streaming=True
68
+ )
69
+
70
+
71
+ @cl.on_chat_start
72
+ async def on_chat_start():
73
+ template = """Answer the question based only on the following context:
74
+
75
  {context}
76
+
77
+ Question: {question}
78
  """
79
+ prompt = ChatPromptTemplate.from_template(template)
80
+
81
+ def format_docs(docs):
82
+ return "\n\n".join([d.page_content for d in docs])
83
+
84
+ retriever = doc_search.as_retriever()
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  runnable = (
87
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
 
 
88
  | prompt
89
  | model
90
+ | StrOutputParser()
91
  )
92
+
93
+ cl.user_session.set("runnable", runnable)
94
+
95
+
96
+ @cl.on_message
97
+ async def on_message(message: cl.Message):
98
+ runnable = cl.user_session.get("runnable") # type: Runnable
99
+ msg = cl.Message(content="")
100
+
101
+ class PostMessageHandler(BaseCallbackHandler):
102
+ """
103
+ Callback handler for handling the retriever and LLM processes.
104
+ Used to post the sources of the retrieved documents as a Chainlit element.
105
+ """
106
+
107
+ def __init__(self, msg: cl.Message):
108
+ BaseCallbackHandler.__init__(self)
109
+ self.msg = msg
110
+ self.sources = set() # To store unique pairs
111
+
112
+ def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
113
+ for d in documents:
114
+ source_page_pair = (d.metadata['source'], d.metadata['page'])
115
+ self.sources.add(source_page_pair) # Add unique pairs to the set
116
+
117
+ def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
118
+ if len(self.sources):
119
+ sources_text = "\n".join([f"{source}#page={page}" for source, page in self.sources])
120
+ self.msg.elements.append(
121
+ cl.Text(name="Sources", content=sources_text, display="inline")
122
+ )
123
+
124
+ async with cl.Step(type="run", name="QA Assistant"):
125
+ async for chunk in runnable.astream(
126
+ message.content,
127
+ config=RunnableConfig(callbacks=[
128
+ cl.LangchainCallbackHandler(),
129
+ PostMessageHandler(msg)
130
+ ]),
131
+ ):
132
+ await msg.stream_token(chunk)
133
+
134
+ await msg.send()