httpdaniel committed on
Commit
ed02a3d
·
1 Parent(s): 333f45e

Adding gradio interface

Browse files
Files changed (1) hide show
  1. app.py +138 -3
app.py CHANGED
@@ -1,7 +1,142 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
  import gradio as gr
2
+ from langchain_community.document_loaders import PyPDFLoader
3
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
4
+ from langchain_chroma import Chroma
5
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
6
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
7
+ from langchain import hub
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.runnables import RunnablePassthrough
10
 
11
def initialise_vectorstore(pdf, progress=gr.Progress()):
    """Build an in-memory Chroma vectorstore from an uploaded PDF.

    Args:
        pdf: Value delivered by the ``gr.File`` input. Either an object that
            exposes the file path as ``.name`` or a plain path string.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            the idiom Gradio uses to inject a tracker into event handlers.

    Returns:
        A ``(vectorstore, status)`` tuple. ``status`` is a short string meant
        for the "Initialization" textbox.

    Raises:
        gr.Error: If no file was uploaded or no text chunks could be
            extracted from the PDF.
    """
    if pdf is None:
        raise gr.Error("Please upload a PDF file before initialising the vectorstore.")

    progress(0, desc="Reading PDF")

    # Accept both file-like objects (``.name``) and plain path strings.
    pdf_path = getattr(pdf, "name", pdf)
    pages = PyPDFLoader(pdf_path).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(pages)
    if not splits:
        raise gr.Error("No text could be extracted from the uploaded PDF.")

    progress(0.5, desc="Initialising Vectorstore")

    vectorstore = Chroma.from_documents(
        splits,
        embedding=HuggingFaceEmbeddings(),
    )

    progress(1, desc="Complete")

    # Return a status string, not the tracker object: the second output is
    # bound to a Textbox.
    return vectorstore, "Vectorstore initialised"
29
+
30
def _format_docs(docs):
    """Join the text of the retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def initialise_chain(llm, vectorstore, progress=gr.Progress()):
    """Create the retrieval-augmented generation chain for the chatbot.

    Args:
        llm: Hugging Face Hub repo id of the model selected in the UI.
        vectorstore: Vectorstore produced by ``initialise_vectorstore``.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            the idiom Gradio uses to inject a tracker into event handlers.

    Returns:
        A ``(rag_chain, status)`` tuple. ``status`` is a short string meant
        for the "Initialization" textbox.

    Raises:
        gr.Error: If no model is selected or the vectorstore has not been
            initialised yet.
    """
    if not llm:
        raise gr.Error("Please select a language model first.")
    if vectorstore is None:
        raise gr.Error("Please initialise the vectorstore before the RAG chain.")

    progress(0, desc="Initialising LLM")

    # Use a separate local name so the ``llm`` repo-id argument is not rebound.
    endpoint = HuggingFaceEndpoint(
        repo_id=llm,
        task="text-generation",
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.03,
    )

    chat = ChatHuggingFace(
        llm=endpoint,
        verbose=True,
    )

    progress(0.5, desc="Initialising RAG Chain")

    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")
    parser = StrOutputParser()

    # Feed the prompt plain text, not the repr of a list of Document objects.
    rag_chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | chat
        | parser
    )

    progress(1, desc="Complete")

    # Return a status string, not the tracker object: the second output is
    # bound to a Textbox.
    return rag_chain, "RAG chain initialised"
58
+
59
def send(message, rag_chain, chat_history):
    """Ask the RAG chain a question and append the exchange to the history.

    Args:
        message: The user's question.
        rag_chain: Object exposing ``invoke(message)`` that returns the answer.
        chat_history: List of ``(user_message, response)`` tuples, or ``None``
            when the conversation has not started yet.

    Returns:
        A ``("", chat_history)`` tuple: the empty string clears the input
        box, and the history includes the new exchange.

    Raises:
        ValueError: If a non-blank message is sent before the chain exists.
    """
    if chat_history is None:
        chat_history = []

    # Do not spend an LLM call on an empty or whitespace-only message.
    if not message or not message.strip():
        return "", chat_history

    if rag_chain is None:
        raise ValueError("The RAG chain must be initialised before sending messages.")

    response = rag_chain.invoke(message)
    chat_history.append((message, response))
    return "", chat_history
63
+
64
def restart():
    """Return the placeholder status message used by the Restart button."""
    # Plain literal: the original f-string had no placeholders.
    return "Restarting"
66
+
67
+
68
# Gradio UI: three tabs (vectorstore setup, RAG chain setup, chatbot) that
# share state through the two gr.State holders below.
with gr.Blocks() as demo:

    # Per-session holders for the objects built by the two initialise steps.
    vectorstore = gr.State()
    rag_chain = gr.State()

    gr.Markdown("<H1>Talk to Documents</H1>")
    gr.Markdown("<H3>Upload and ask questions about your PDF files</H3>")
    gr.Markdown("<H6>Note: This project uses LangChain to perform RAG (Retrieval Augmented Generation) on PDF files, allowing users to ask any questions related to their contents. When a PDF file is uploaded, it is embedded and stored in an in-memory Chroma vectorstore, which the chatbot uses as a source of knowledge when answering user questions.</H6>")

    # Vectorstore Tab
    with gr.Tab("Vectorstore"):
        with gr.Row():
            input_pdf = gr.File()
        with gr.Row():
            # Empty side columns centre the button in the row.
            with gr.Column(scale=1, min_width=0):
                pass
            with gr.Column(scale=2, min_width=0):
                initialise_vectorstore_btn = gr.Button(
                    "Initialise Vectorstore",
                    variant="primary",
                )
            with gr.Column(scale=1, min_width=0):
                pass
        with gr.Row():
            vectorstore_initialisation_progress = gr.Textbox(value="None", label="Initialization")

    # RAG Chain
    with gr.Tab("RAG Chain"):
        with gr.Row():
            language_model = gr.Radio(["microsoft/Phi-3-mini-4k-instruct", "mistralai/Mistral-7B-Instruct-v0.2", "nvidia/Mistral-NeMo-Minitron-8B-Base"])
        with gr.Row():
            with gr.Column(scale=1, min_width=0):
                pass
            with gr.Column(scale=2, min_width=0):
                initialise_chain_btn = gr.Button(
                    "Initialise RAG Chain",
                    variant="primary",
                )
            with gr.Column(scale=1, min_width=0):
                pass
        with gr.Row():
            chain_initialisation_progress = gr.Textbox(value="None", label="Initialization")

    # Chatbot Tab
    with gr.Tab("Chatbot"):
        with gr.Row():
            chatbot = gr.Chatbot()
        # NOTE(review): the reference fields below are not wired to any event
        # handler in this block.
        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                source3_page = gr.Number(label="Page", scale=1)
        with gr.Row():
            message = gr.Textbox()
        with gr.Row():
            # ``variant`` takes a single string, not a list.
            send_btn = gr.Button(
                "Send",
                variant="primary",
            )
            restart_btn = gr.Button(
                "Restart",
                variant="secondary",
            )

    # Event wiring.
    initialise_vectorstore_btn.click(fn=initialise_vectorstore, inputs=input_pdf, outputs=[vectorstore, vectorstore_initialisation_progress])
    initialise_chain_btn.click(fn=initialise_chain, inputs=[language_model, vectorstore], outputs=[rag_chain, chain_initialisation_progress])
    send_btn.click(fn=send, inputs=[message, rag_chain, chatbot], outputs=[message, chatbot])
    # NOTE(review): restart has no outputs bound, so its return value is
    # discarded and the chat is not actually reset.
    restart_btn.click(fn=restart)
141
 
 
142
  demo.launch()