Commit 5ef51e2
Parent(s): e29cf0b
Update app.py

app.py CHANGED

@@ -46,6 +46,7 @@ def load_scraped_web_info():
 
 
 
+
 @st.cache_resource
 def load_embedding_model():
     embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
@@ -57,6 +58,31 @@ def load_faiss_index():
     vector_database = FAISS.load_local("faiss_index", embedding_model)
     return vector_database
 
+@st.cache_resource
+def load_llm_model():
+    # llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
+    #                                         task= 'text2text-generation',
+    #                                         model_kwargs={ "device_map": "auto",
+    #                                                        "load_in_8bit": True,"max_length": 256, "temperature": 0,
+    #                                                        "repetition_penalty": 1.5})
+
+
+    llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0',
+                                            task= 'text2text-generation',
+
+                                            model_kwargs={ "max_length": 256, "temperature": 0,
+                                                           "torch_dtype": torch.float32,
+                                                           "repetition_penalty": 1.3})
+
+
+    return llm
+
+
+def load_retriever(llm, db):
+    qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
+                                               retriever=db.as_retriever())
+
+    return qa_retriever
 
 #--------------
 
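The hunk above adds the generation side of the app: load_llm_model() wraps lmsys/fastchat-t5-3b-v1.0 in a LangChain HuggingFacePipeline (the commented-out variant was an 8-bit GPU config; the active one pins float32, which also runs on CPU), and load_retriever() ties that LLM to the FAISS store through a RetrievalQA "stuff" chain. A minimal standalone sketch of the same pattern, assuming the 0.0.x-era LangChain import layout this app appears to use; the index path and the sample question are placeholders:

import torch
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS

# The embeddings must be the same model the index was built with.
embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base')
vector_database = FAISS.load_local("faiss_index", embedding_model)

# CPU-friendly pipeline; mirrors the kwargs the commit settles on.
llm = HuggingFacePipeline.from_model_id(
    model_id='lmsys/fastchat-t5-3b-v1.0',
    task='text2text-generation',
    model_kwargs={"max_length": 256, "temperature": 0,
                  "torch_dtype": torch.float32, "repetition_penalty": 1.3})

# "stuff" chain type: concatenate all retrieved chunks into a single prompt.
qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
                                           retriever=vector_database.as_retriever())

print(qa_retriever.run("What is this site about?"))  # placeholder question
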
@@ -65,19 +91,36 @@ def load_faiss_index():
 load_scraped_web_info()
 embedding_model = load_embedding_model()
 vector_database = load_faiss_index()
-
+llm_model = load_llm_model()
+qa_retriever = load_retriever(llm= llm_model, db= vector_database)
 
 
+print("all load done")
+
 
 
 query_input = st.text_input(label= 'your question')
+
+
 def retrieve_document(query_input):
     related_doc = vector_database.similarity_search(query_input)
     return related_doc
 
-
+def retrieve_answer(query_input):
+    answer = qa_retriever.run(query_input)
+    return answer
+
+output_1 = st.text_area(label = "Here is the relevant documents",
                         value = retrieve_document(query_input))
 
 
-
-
+output_2 = st.text_area(label = "Here is the answer",
+                        value = retrieve_answer(query_input))
+
+
+
+# faiss_retriever = vector_database.as_retriever()
+# print("Succesfully had FAISS as retriever")
+
+
+
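As committed, both st.text_area widgets compute their value on every Streamlit rerun, including the first render when query_input is still empty, so the FAISS search and the T5 generation both run on an empty string before the user types anything. A sketch of one common guard, not part of this commit: skip the expensive calls until there is input.

query_input = st.text_input(label='your question')

if query_input:  # only touch FAISS and the LLM once the user has typed something
    related_doc = retrieve_document(query_input)
    answer = retrieve_answer(query_input)
else:
    related_doc, answer = [], ""

# Streamlit casts value to str; being explicit keeps the document list readable.
output_1 = st.text_area(label="Here is the relevant documents", value=str(related_doc))
output_2 = st.text_area(label="Here is the answer", value=answer)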