mariagrandury committed
Commit 2614912 · 1 Parent(s): bed03be

remove unused imports and function, rename functions and fix llmchain init progress

Files changed (1)
  1. app.py +10 -27
app.py CHANGED
@@ -2,20 +2,15 @@ import os
 import re
 from pathlib import Path

-import accelerate
 import chromadb
 import gradio as gr
-import torch
-import tqdm
-import transformers
-from langchain.chains import ConversationalRetrievalChain, ConversationChain
+from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFaceEndpoint, HuggingFacePipeline
+from langchain_community.llms import HuggingFaceEndpoint
 from langchain_community.vectorstores import Chroma
-from transformers import AutoTokenizer
 from unidecode import unidecode

 list_llm = [
@@ -31,8 +26,7 @@ list_llm = [
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]


-# Load PDF document and create doc splits
-def load_doc(list_file_path, chunk_size, chunk_overlap):
+def load_doc_and_create_splits(list_file_path, chunk_size, chunk_overlap):
     # Processing for one document only
     # loader = PyPDFLoader(file_path)
     # pages = loader.load()
@@ -48,8 +42,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     return doc_splits


-# Create vector database
-def create_db(splits, collection_name):
+def create_vector_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.EphemeralClient()
     vectordb = Chroma.from_documents(
@@ -61,21 +54,10 @@ def create_db(splits, collection_name):
     return vectordb


-# Load vector database
-def load_db():
-    embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(embedding_function=embedding)
-    return vectordb
-
-
-# Initialize langchain LLM chain
 def initialize_llmchain(
     llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
 ):
-    progress(0.1, desc="Initializing HF tokenizer...")
-
-    # HuggingFaceHub uses HF inference endpoints
-    progress(0.5, desc="Initializing HF Hub...")
+    progress(0.1, desc="Initializing HF Hub...")
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
         llm = HuggingFaceEndpoint(
             repo_id=llm_model,
@@ -92,14 +74,14 @@ def initialize_llmchain(
             top_k=top_k,
         )

-    progress(0.75, desc="Defining buffer memory...")
+    progress(0.6, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
         memory_key="chat_history", output_key="answer", return_messages=True
     )
     # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
     retriever = vector_db.as_retriever()

-    progress(0.8, desc="Defining retrieval chain...")
+    progress(0.75, desc="Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
@@ -108,6 +90,7 @@ def initialize_llmchain(
         return_source_documents=True,
         verbose=False,
     )
+
     progress(0.9, desc="Done!")
     return qa_chain

@@ -148,10 +131,10 @@ def initialize_database(
     collection_name = create_collection_name(list_file_path[0])

     progress(0.25, desc="Loading document...")
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+    doc_splits = load_doc_and_create_splits(list_file_path, chunk_size, chunk_overlap)

     progress(0.5, desc="Generating vector database...")
-    vector_db = create_db(doc_splits, collection_name)
+    vector_db = create_vector_db(doc_splits, collection_name)

     progress(0.9, desc="Done!")
     return vector_db, collection_name, "Complete!"
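For context, a minimal sketch of how the renamed helpers compose after this commit. It assumes initialize_database keeps the signature its body implies, the langchain 0.1-era APIs the imports suggest, and a HUGGINGFACEHUB_API_TOKEN in the environment; the file name and chunking values are illustrative, not taken from the commit.

# Hypothetical end-to-end wiring of the Space's helpers (not from the commit).
from app import initialize_database, initialize_llmchain

# Split the PDF and index it in the ephemeral Chroma collection.
# Chunk sizes are example values; the Space's own defaults may differ.
vector_db, collection_name, status = initialize_database(
    list_file_path=["example.pdf"],  # hypothetical input file
    chunk_size=600,
    chunk_overlap=40,
)

# Wire HuggingFaceEndpoint, buffer memory, and the retriever into the chain.
# Outside a running Gradio event the gr.Progress() callbacks may need stubbing.
qa_chain = initialize_llmchain(
    llm_model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    temperature=0.7,
    max_tokens=1024,
    top_k=3,
    vector_db=vector_db,
)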
 
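And a hypothetical invocation of the returned chain: because a ConversationBufferMemory with memory_key="chat_history" is attached, only the question needs to be passed, and return_source_documents=True plus output_key="answer" returns the retrieved chunks alongside the answer.

# Illustrative query against the indexed PDF (question text is made up).
response = qa_chain.invoke({"question": "What is the document about?"})
print(response["answer"])
for doc in response["source_documents"]:
    # PyPDFLoader stores the originating page number in each chunk's metadata.
    print(doc.metadata.get("page"), doc.page_content[:80])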