Spaces:
Sleeping
Sleeping
added app files
Browse files- app.py +92 -0
- chain.py +49 -0
- combined.txt +0 -0
- data.py +25 -0
- docs.pkl +3 -0
app.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import datetime
|
3 |
+
import pickle
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import langchain
|
7 |
+
from langchain.llms import HuggingFaceHub
|
8 |
+
|
9 |
+
from chain import get_new_chain
|
10 |
+
|
11 |
+
# Hugging Face API token, read from the HF_TOKEN environment variable.
# Indexing os.environ directly raises KeyError at import time if it is unset.
api_token = os.environ["HF_TOKEN"]
|
12 |
+
|
13 |
+
|
14 |
+
def get_faiss_store():
    """Unpickle and return the vector store saved in ``docs.pkl``.

    The path is relative to the current working directory.

    NOTE: ``pickle.load`` can execute arbitrary code, so ``docs.pkl`` must
    only ever come from a trusted source.
    """
    with open("docs.pkl", "rb") as pickled_store:
        return pickle.load(pickled_store)
|
18 |
+
|
19 |
+
def load_model():
    """Assemble the question-answering chain used by the chat handler.

    Loads the pickled vector store, creates a ``HuggingFaceHub`` LLM client
    for ``google/flan-ul2`` and hands both to ``get_new_chain``.

    Returns:
        The chain object produced by ``get_new_chain``.
    """
    # Print where langchain was imported from (shows up in the app logs).
    print(langchain.__file__)

    # Load the store first so a missing/corrupt pickle fails before the
    # LLM client is created.
    store = get_faiss_store()

    generation_settings = dict(temperature=0.1, max_new_tokens=200)
    hub_llm = HuggingFaceHub(
        repo_id="google/flan-ul2",
        model_kwargs=generation_settings,
        huggingfacehub_api_token=api_token,
    )

    return get_new_chain(store, hub_llm)
|
31 |
+
|
32 |
+
|
33 |
+
def chat(inp, agent):
    """Answer one question and return it as a single-turn chat history.

    Args:
        inp: The question text entered by the user.
        agent: Callable invoked as ``agent({"question": inp})`` that returns
            a mapping with an ``"answer"`` key, or ``None`` while the model
            has not been loaded yet.

    Returns:
        A one-element list ``[(inp, reply)]``. ``reply`` is the agent's
        answer, or a "please wait" notice when ``agent`` is ``None``.
    """
    # Guard clause: the model is loaded asynchronously on page load, so the
    # user may submit a question before the agent exists.
    if agent is None:
        return [(inp, "Please wait for model to load (3-5 seconds)")]

    print(f"\n==== date/time: {datetime.datetime.now()} ====")
    print(f"inp: {inp}")

    output = agent({"question": inp})
    result = [(inp, output["answer"])]
    print(result)
    return result
|
46 |
+
|
47 |
+
|
48 |
+
# Gradio UI: a single-page chat front end for the question-answering chain.
block = gr.Blocks(css=".gradio-container {background-color: lightgray}")

with block:
    with gr.Row():
        gr.Markdown("<h3><center>PotterChat</center></h3><p>Ask questions about the Harry Potter Books, Powered by Flan-UL2</p>")

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Who was Harry's godfather?",
            lines=1,
        )
        # NOTE(review): `.style()` appears to be deprecated in newer Gradio
        # releases - confirm against the pinned Gradio version.
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "Which house in Hogwarts was Harry in?",
            "Who were Harry's best friends?",
            "Who taught Potions at Hogwarts?",
        ],
        inputs=message,
    )

    gr.HTML(
        "\nThis simple application uses Langchain, an open-source LLM, and FAISS to do Q&A over the Harry Potter books."
    )

    # The Hugging Face link needs an explicit scheme; a bare "huggingface.co"
    # is treated by browsers as a path relative to the app's own URL.
    gr.HTML(
        "<center>Powered by <a href='https://huggingface.co'>Hugging Face 🤗</a> and <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
    )

    # Per-session holder for the loaded chain; stays None until load_model
    # has finished, which is the case `chat` answers with a "please wait".
    agent_state = gr.State()

    # Build the chain when the page loads and store it in the session state.
    block.load(load_model, inputs=None, outputs=[agent_state])

    # Both the Send button and pressing Enter in the textbox run `chat`.
    # `chat` returns only the latest (question, answer) pair, so the chatbot
    # shows one exchange at a time rather than an accumulated history.
    submit.click(chat, inputs=[message, agent_state], outputs=[chatbot])
    message.submit(chat, inputs=[message, agent_state], outputs=[chatbot])

block.launch(debug=True)
|
chain.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List
|
2 |
+
|
3 |
+
from langchain import PromptTemplate
|
4 |
+
from langchain.chains.base import Chain
|
5 |
+
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
6 |
+
from langchain.chains.question_answering import load_qa_chain
|
7 |
+
from langchain.prompts import PromptTemplate
|
8 |
+
from langchain.vectorstores import FAISS
|
9 |
+
from pydantic import BaseModel
|
10 |
+
|
11 |
+
|
12 |
+
class CustomChain(Chain, BaseModel):
    """Question-answering chain backed by a FAISS similarity search.

    For each question it fetches the five most similar documents from
    ``vstore`` and passes them, together with the original inputs, to
    ``chain.combine_docs`` to produce the answer.
    """

    # Vector store queried for documents similar to the question.
    vstore: FAISS
    # Documents-combining chain that turns the retrieved documents into an answer.
    chain: BaseCombineDocumentsChain

    @property
    def input_keys(self) -> List[str]:
        """Input names this chain expects: just ``"question"``."""
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        """Output names this chain produces: just ``"answer"``."""
        return ["answer"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Retrieve supporting documents and answer ``inputs["question"]``.

        Returns:
            A dict with the single key ``"answer"``.
        """
        query = inputs["question"]
        matches = self.vstore.similarity_search(query, k=5)
        # combine_docs returns a pair; only its first element is used here.
        reply, _unused = self.chain.combine_docs(matches, **inputs)
        return dict(answer=reply)
|
32 |
+
|
33 |
+
|
34 |
+
def get_new_chain(vectorstore, llm):
    """Create a ``CustomChain`` that answers questions from ``vectorstore``.

    Args:
        vectorstore: Store passed to ``CustomChain`` as ``vstore``.
        llm: Language model handed to ``load_qa_chain``.

    Returns:
        A ``CustomChain`` wrapping a "stuff"-type QA chain that uses the
        prompt defined below.
    """
    # The prompt is spelled out line by line so its exact text (including the
    # blank lines and trailing newline) does not depend on source indentation.
    prompt_text = (
        "Use only the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer.\n"
        "\n"
        "{context}\n"
        "\n"
        "Question: {question}\n"
    )
    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["question", "context"],
    )

    stuff_chain = load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt, verbose=True)
    return CustomChain(chain=stuff_chain, vstore=vectorstore)
|
combined.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
from langchain.vectorstores import FAISS
|
3 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
4 |
+
|
5 |
+
|
6 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter

# One-off script: split combined.txt into chunks, embed them into a FAISS
# index and pickle the index to docs.pkl.

# Read the whole corpus; the context manager guarantees the handle is closed
# (the file was previously opened and never closed).
with open("combined.txt", "r") as file:
    contents = file.read()

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# 500-character chunks with a 20-character overlap, measured with len().
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=20,
    length_function=len,
)
texts = text_splitter.create_documents([contents])

print("Beginning construction of FAISS DB")
docs = FAISS.from_documents(texts, embeddings)

print("Beginning pickle")
with open("docs.pkl", "wb") as f:
    pickle.dump(docs, f)
print("pickle over")
|
docs.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:382cb8f95c5e1aa71364e67eb6383368c231b002173aaf5db073ed793f32b5d6
|
3 |
+
size 502076344
|