mouliraj56 committed on
Commit fff46d0 · 1 Parent(s): ccadf0a

Create run.py

Files changed (1)
  1. run.py +308 -0
run.py ADDED
@@ -0,0 +1,308 @@
import os
import time
from threading import Thread
from datetime import datetime
from uuid import uuid4
import gradio as gr
from time import sleep
import pprint
import torch
from torch import cuda, bfloat16
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain.document_loaders.pdf import UnstructuredPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.llms import HuggingFacePipeline

# model_names = ["tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct", "tiiuae/falcon-rw-1b"]
model_names = ["tiiuae/falcon-7b-instruct"]
models = {}
embedding_function_name = "all-mpnet-base-v2"
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
max_new_tokens = 1024
repetition_penalty = 10.0
temperature = 0
chunk_size = 512
chunk_overlap = 32


def get_uuid():
    return str(uuid4())


def create_embedding_function(embedding_function_name):
    return HuggingFaceEmbeddings(model_name=embedding_function_name,
                                 model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"})

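# Load every checkpoint in model_names once at startup: falcon-40b-instruct is quantized to
# 4-bit (NF4) with bitsandbytes, smaller Falcon checkpoints are loaded in bfloat16.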
def create_models():
    for model_name in model_names:

        if model_name == "tiiuae/falcon-40b-instruct":
            bnb_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type='nf4',
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=bfloat16
            )
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                quantization_config=bnb_config,
                device_map='auto'
            )
        else:
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map='auto'
            )

        model.eval()
        print(f"Model loaded on {device}")
        models[model_name] = model

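# Load the models and the embedding function once at startup, before the UI is built.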
create_models()
embedding_function = create_embedding_function(embedding_function_name)


def user(message, history):
    # Append the user's message to the conversation history
    if history is None:
        history = []
    return "", history + [[message, None]]

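# Answer the latest user message: open the session's persisted Chroma index, wrap the chosen
# Falcon model in a LangChain pipeline, run the retrieval chain in a background thread, and
# stream the generated tokens into the last chat turn.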
def bot(model_name, db_path, chat_mode, history):
    if not history or history[-1][0] == "":
        gr.Info("Please start the conversation by saying something.")
        return None

    chat_hist = history[:-1]
    if chat_hist:
        chat_hist = [tuple([y.replace("\n", ' ').strip(" ") for y in x]) for x in chat_hist]

    print("@" * 20)
    print(f"chat_hist:\n {chat_hist}")
    print("@" * 20)

    print('------------------------------------')
    print(model_name)
    print(db_path)
    print(chat_mode)
    print('------------------------------------')

    # Re-create the LangChain vector store from the persisted DB for each session
    db = Chroma(persist_directory=db_path, embedding_function=embedding_function)

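    # Stop generation as soon as the model starts a new dialogue turn ("Question:", "Answer:", "User:").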
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    stop_token_ids = [
        tokenizer.convert_tokens_to_ids(x) for x in [
            ['Question', ':'],
            ['Answer', ':'],
            ['User', ':'],
        ]
    ]

    class StopOnTokens(StoppingCriteria):
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            for stop_ids in stop_token_ids:
                if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                    return True
            return False

    stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
    stopping_criteria = StoppingCriteriaList([StopOnTokens()])
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_text = transformers.pipeline(
        model=models[model_name], tokenizer=tokenizer,
        return_full_text=True,
        task='text-generation',
        stopping_criteria=stopping_criteria,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        streamer=streamer
    )
    pipeline = HuggingFacePipeline(pipeline=generate_text)

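    # 'Basic' answers each question on its own with RetrievalQA; 'Conversational' also passes the
    # prior chat history to a ConversationalRetrievalChain. Either way the chain runs in a worker
    # thread so the streaming loop below can yield tokens as they are produced.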
    if chat_mode.lower() == 'basic':
        print("chat mode: basic")
        qa = RetrievalQA.from_llm(
            llm=pipeline,
            retriever=db.as_retriever(),
            return_source_documents=True
        )

        def run_basic(history):
            a = qa({"query": history[-1][0]})
            pprint.pprint(a['source_documents'])

        t = Thread(target=run_basic, args=(history,))
        t.start()

    else:
        print("chat mode: conversational")
        qa = ConversationalRetrievalChain.from_llm(
            llm=pipeline,
            retriever=db.as_retriever(),
            return_source_documents=True
        )

        def run_conv(history, chat_hist):
            a = qa({"question": history[-1][0], "chat_history": chat_hist})
            pprint.pprint(a['source_documents'])

        t = Thread(target=run_conv, args=(history, chat_hist))
        t.start()

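    # Stream partial output into the last chat turn as tokens arrive from the generation thread.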
    history[-1][1] = ""
    for new_text in streamer:
        history[-1][1] += new_text
        time.sleep(0.01)
        yield history

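# Index an uploaded PDF: split it into overlapping chunks, embed them, and persist a Chroma DB
# next to the uploaded file. The returned path is stored in the hidden DB_PATH textbox.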
def pdf_changes(pdf_doc):
    print("pdf changes, loading documents")

    # Persistently store the db next to the uploaded pdf
    db_path, file_ext = os.path.splitext(pdf_doc.name)

    timestamp = datetime.now()
    db_path += "_" + timestamp.strftime("%Y-%m-%d-%H-%M-%S")

    loader = UnstructuredPDFLoader(pdf_doc.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)

    db = Chroma.from_documents(texts, embedding_function, persist_directory=db_path)
    db.persist()
    return db_path

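# Build the Gradio UI and wire up the events (PDF upload, message submit, stop, clear).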
def init():
    with gr.Blocks(
            theme=gr.themes.Soft(),
            css=".disclaimer {font-variant-caps: all-small-caps;}",
    ) as demo:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 650px; margin: 0 auto;">
                <div>
                    <img class="logo" src="https://lambdalabs.com/hubfs/logos/lambda-logo.svg" alt="Lambda Logo"
                         style="margin: auto; max-width: 7rem;">
                    <h1 style="font-weight: 900; font-size: 3rem;">
                        Chat With FalconPDF
                    </h1>
                </div>
            </div>
            """
        )

        pdf_doc = gr.File(label="Load a pdf", file_types=['.pdf'], type="file")
        model_id = gr.Radio(label="LLM", choices=model_names, value=model_names[0], interactive=True)
        db_path = gr.Textbox(label="DB_PATH", visible=False)
        chat_mode = gr.Radio(label="Chat mode", choices=['Basic', 'Conversational'], value='Basic',
                             info="Basic: no conversational context. Conversational: uses conversational context.")
        chatbot = gr.Chatbot(height=500)

        with gr.Row():
            with gr.Column():
                msg = gr.Textbox(
                    label="Chat Message Box",
                    placeholder="Chat Message Box",
                    show_label=False,
                    container=False
                )
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit")
                    stop = gr.Button("Stop")
                    clear = gr.Button("Clear")

        gr.Examples(['What is the summary of the paper?',
                     'What is the motivation of the paper?'],
                    inputs=msg)

        def clear_input():
            sleep(1)
            return ""

        with gr.Row():
            gr.HTML(
                """
                <div class="footer">
                    <p> A chatbot tries to give helpful, detailed, and polite answers to the user's questions. Gradio Demo created by <a href="https://lambdalabs.com/">Lambda</a>.</p>
                </div>
                <div class="acknowledgments">
                    <p> It is based on Falcon 7B/40B. More information can be found <a href="https://falconllm.tii.ae/">here</a>.</p>
                </div>
                """
            )

        model_id.change(clear_input, inputs=[], outputs=[msg])

        pdf_doc.upload(pdf_changes, inputs=[pdf_doc], outputs=[db_path]). \
            then(clear_input, inputs=[], outputs=[msg]). \
            then(lambda: None, None, chatbot)

        # enter key event
        submit_event = msg.submit(
            fn=user,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            fn=bot,
            inputs=[
                model_id,
                db_path,
                chat_mode,
                chatbot,
            ],
            outputs=chatbot,
            queue=True,
        )

        # click submit button event
        submit_click_event = submit.click(
            fn=user,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
            queue=False,
        ).then(
            fn=bot,
            inputs=[
                model_id,
                db_path,
                chat_mode,
                chatbot,
            ],
            outputs=chatbot,
            queue=True,
        )

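        # stop button event: cancel any in-flight generation started by the submit events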
        stop.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[submit_event, submit_click_event],
            queue=False,
        )

        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue(max_size=32, concurrency_count=2)

    demo.launch(server_port=8266, inline=False, share=True)


if __name__ == "__main__":
    init()