Al-Alcoba-Inciarte committed on
Commit
826babd
1 Parent(s): 420e6d4

Update app.py

Files changed (1)
  1. app.py +385 -55
app.py CHANGED
@@ -1,60 +1,390 @@
 import gradio as gr
 
-from haystack.components.generators import HuggingFaceTGIGenerator
-
-generator = HuggingFaceTGIGenerator("mistralai/Mixtral-8x7B-Instruct-v0.1")
-generator.warm_up()
-
-from haystack.components.fetchers.link_content import LinkContentFetcher
-from haystack.components.converters import HTMLToDocument
-from haystack.components.preprocessors import DocumentSplitter
-from haystack.components.rankers import TransformersSimilarityRanker
-from haystack.components.generators import GPTGenerator
-from haystack.components.builders.prompt_builder import PromptBuilder
-from haystack import Pipeline
-
-fetcher = LinkContentFetcher()
-converter = HTMLToDocument()
-document_splitter = DocumentSplitter(split_by="word", split_length=50)
-similarity_ranker = TransformersSimilarityRanker(top_k=3)
-
-prompt_template = """
-According to these documents:
-
-{% for doc in documents %}
-  {{ doc.content }}
-{% endfor %}
-
-Answer the given question: {{question}}
-Answer:
-"""
-prompt_builder = PromptBuilder(template=prompt_template)
-
-pipeline = Pipeline()
-pipeline.add_component("fetcher", fetcher)
-pipeline.add_component("converter", converter)
-pipeline.add_component("splitter", document_splitter)
-pipeline.add_component("ranker", similarity_ranker)
-pipeline.add_component("prompt_builder", prompt_builder)
-pipeline.add_component("llm", generator)
-
-pipeline.connect("fetcher.streams", "converter.sources")
-pipeline.connect("converter.documents", "splitter.documents")
-pipeline.connect("splitter.documents", "ranker.documents")
-pipeline.connect("ranker.documents", "prompt_builder.documents")
-pipeline.connect("prompt_builder.prompt", "llm.prompt")
-
-def respond(prompt, use_rag):
-    if use_rag:
-        result = pipeline.run({"prompt_builder": {"question": prompt},
-                               "ranker": {"query": prompt},
-                               "fetcher": {"urls": ["https://haystack.deepset.ai/blog/introducing-haystack-2-beta-and-advent"]},
-                               "llm": {"generation_kwargs": {"max_new_tokens": 350}}})
-        return result['llm']['replies'][0]
     else:
-        result = generator.run(prompt, generation_kwargs={"max_new_tokens": 350})
-        return result["replies"][0]
-
-iface = gr.Interface(fn=respond, inputs=["text", "checkbox"], outputs="text")
-iface.launch()
+#from haystack.components.generators import HuggingFaceTGIGenerator
+from llama_index.llms import HuggingFaceInferenceAPI
+from llama_index.llms import ChatMessage, MessageRole
+from llama_index.prompts import ChatPromptTemplate
+from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext  #, LLMPredictor, StorageContext, load_index_from_storage
 import gradio as gr
+#import sys
+#import logging
+#import torch
+#from huggingface_hub import InferenceClient
+#import tqdm as notebook_tqdm
+import requests
+import os
+import json
 
+#generator = HuggingFaceTGIGenerator("mistralai/Mixtral-8x7B-Instruct-v0.1")
+#generator.warm_up()
+
+def download_file(url, filename):
+    """
+    Download a file from the specified URL and save it locally under the given filename.
+    """
+
+    response = requests.get(url, stream=True)
+
+    # Check if the request was successful
+
+    if filename in os.listdir('content/'): return
+    if filename == '': return
+
+    if response.status_code == 200:
+        with open('content/' + filename, 'wb') as file:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:  # filter out keep-alive new chunks
+                    file.write(chunk)
+        print(f"Download complete: {filename}")
     else:
+        print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
+
+#def save_answer(prompt, rag_answer, norag_answer):
+#    json_dict = dict()
+#    json_dict['prompt'] = prompt
+#    json_dict['rag_answer'] = rag_answer
+#    json_dict['norag_answer'] = norag_answer
+#
+#    file_path = 'saved_answers.json'
+#
+#    # Check if the file exists
+#    if not os.path.isfile(file_path):
+#        with open(file_path, 'w') as f:
+#            # Create an empty list in the file to store dictionaries
+#            json.dump([], f)
+#            f.write('\n')  # Add a newline to separate the list and future entries
+#
+#    # Open the file in append mode
+#    with open(file_path, 'a+') as f:
+#        # Read the existing data
+#        f.seek(0)
+#        data = json.load(f)
+#
+#        # Append the new dictionary to the list
+#        data.append(json_dict)
+#
+#        # Move the cursor to the beginning of the file
+#        f.seek(0)
+#
+#        # Write the updated list of dictionaries
+#        json.dump(data, f)
+#        f.write('\n')  # Add a newline to separate the list and future entries
+#
+#
+#def check_answer(prompt):
+#    file_path = 'saved_answers.json'
+#
+#    if not os.path.isfile(file_path):
+#        with open(file_path, 'w') as f:
+#            # Create an empty list in the file to store dictionaries
+#            json.dump([], f)
+#            f.write('\n')  # Add a newline to separate the list and future entries
+#    with open('saved_answers.json', 'r') as f:
+#        data = json.load(f)
+#    for entry in data:
+#        if entry['prompt'] == prompt:
+#            return entry['rag_answer'], entry['norag_answer']
+#    return None, None  # Return None if the prompt is not found
+
+
+def save_answer(prompt, rag_answer, norag_answer):
+    file_path = 'saved_answers.jsonl'
+
+    # Create a dictionary for the current answer
+    json_dict = {
+        'prompt': prompt,
+        'rag_answer': rag_answer,
+        'norag_answer': norag_answer
+    }
+
+    # Check if the file exists, and create it if not
+    #if not os.path.isfile(file_path):
+    #    with open(file_path, 'w') as f:
+    #        # Create an empty list in the file to store dictionaries
+    #        json.dump([], f)
+    #        f.write('\n')  # Add a newline to separate the list and future entries
+
+    # Load existing data from the file
+    existing_data = load_jsonl(file_path)
+
+    # Append the new answer to the existing data
+    existing_data.append(json_dict)
+
+    # Save the updated data back to the file
+    write_to_jsonl(file_path, existing_data)
+
+def check_answer(prompt):
+    file_path = 'saved_answers.jsonl'
+
+    ## Check if the file exists, and create it if not
+    #if not os.path.isfile(file_path):
+    #    with open(file_path, 'w') as f:
+    #        # Create an empty list in the file to store dictionaries
+    #        json.dump([], f)
+    #        f.write('\n')  # Add a newline to separate the list and future entries
+
+    # Load existing data from the file
+    try:
+        existing_data = load_jsonl(file_path)
+    except:
+        return None, None
+
+    if len(existing_data) == 0:
+        return None, None
+
+    # Find the answer for the given prompt, if it exists
+    for entry in existing_data:
+        if entry['prompt'] == prompt:
+            return entry['rag_answer'], entry['norag_answer']
+
+    # Return None if the prompt is not found
+    return None, None
+
+# Helper functions
+def load_jsonl(file_path):
+    data = []
+    with open(file_path, 'r') as file:
+        for line in file:
+            # Each line is a JSON object
+            item = json.loads(line)
+            data.append(item)
+    return data
+
+def write_to_jsonl(file_path, data):
+    with open(file_path, 'a+') as file:
+        for item in data:
+            # Convert Python object to JSON string and write it to the file
+            json_line = json.dumps(item)
+            file.write(json_line + '\n')
+
+
+def generate(prompt, history, rag_only, file_link, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
+
+    rag_answer, norag_answer = check_answer(prompt)
+
+    if rag_answer != None:
+        if rag_only:
+            return f'* Mixtral + RAG Output:\n{rag_answer}'
+        else:
+            return f'* Mixtral Output:\n{norag_answer}\n\n* Mixtral + RAG Output:\n{rag_answer}'
+
+    mixtral = HuggingFaceInferenceAPI(
+        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
+        #Mistral-7B-Instruct-v0.2
+    )
+
+    service_context = ServiceContext.from_defaults(
+        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
+    )
+
+    download = download_file(file_link, file_link.split("/")[-1])
+
+    documents = SimpleDirectoryReader("content/").load_data()
+
+    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
+
+    # Text QA Prompt
+    chat_text_qa_msgs = [
+        ChatMessage(
+            role=MessageRole.SYSTEM,
+            content=(
+                "Always answer the question, even if the context isn't helpful."
+            ),
+        ),
+        ChatMessage(
+            role=MessageRole.USER,
+            content=(
+                "Context information is below.\n"
+                "---------------------\n"
+                "{context_str}\n"
+                "---------------------\n"
+                "Given the context information and not prior knowledge, "
+                "answer the question: {query_str}\n"
+            ),
+        ),
+    ]
+    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
+
+    # Refine Prompt
+    chat_refine_msgs = [
+        ChatMessage(
+            role=MessageRole.SYSTEM,
+            content=(
+                "Always answer the question, even if the context isn't helpful."
+            ),
+        ),
+        ChatMessage(
+            role=MessageRole.USER,
+            content=(
+                "We have the opportunity to refine the original answer "
+                "(only if needed) with some more context below.\n"
+                "------------\n"
+                "{context_msg}\n"
+                "------------\n"
+                "Given the new context, refine the original answer to better "
+                "answer the question: {query_str}. "
+                "If the context isn't useful, output the original answer again.\n"
+                "Original Answer: {existing_answer}"
+            ),
+        ),
+    ]
+    refine_template = ChatPromptTemplate(chat_refine_msgs)
+
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    stream = index.as_query_engine(
+        text_qa_template=text_qa_template, refine_template=refine_template, similarity_top_k=6, temperature=temperature,
+        max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty
+    ).query(prompt)
+    print(str(stream))
+
+    output_rag = str(stream)  #""
+
+    #output_norag = mixtral.complete(prompt, details=True, similarity_top_k=6, temperature=temperature,
+    #                                max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty)
+
+    #for response in str(stream):
+    #    output += response
+    #    yield output
+
+    #print(output_norag)
+
+
+    #result = generator.run(prompt, generation_kwargs={"max_new_tokens": 350})
+    #output_norag = result["replies"][0]
+
+
+    ### NORAG
+
+    if rag_only == False:
+        chat_text_qa_msgs_nr = [
+            ChatMessage(
+                role=MessageRole.SYSTEM,
+                content=(
+                    "Always answer the question"
+                ),
+            ),
+            ChatMessage(
+                role=MessageRole.USER,
+                content=(
+                    "answer the question: {query_str}\n"
+                ),
+            ),
+        ]
+        text_qa_template_nr = ChatPromptTemplate(chat_text_qa_msgs_nr)
+
+        # Refine Prompt
+        chat_refine_msgs_nr = [
+            ChatMessage(
+                role=MessageRole.SYSTEM,
+                content=(
+                    "Always answer the question"
+                ),
+            ),
+            ChatMessage(
+                role=MessageRole.USER,
+                content=(
+                    "answer the question: {query_str}. "
+                    "If the context isn't useful, output the original answer again.\n"
+                    "Original Answer: {existing_answer}"
+                ),
+            ),
+        ]
+        refine_template_nr = ChatPromptTemplate(chat_refine_msgs_nr)
+
+        stream_nr = index.as_query_engine(
+            text_qa_template=text_qa_template_nr, refine_template=refine_template_nr, similarity_top_k=6
+        ).query(prompt)
+
+        ###
+
+        output_norag = str(stream_nr)
+        save_answer(prompt, output_rag, output_norag)
+
+        return f'* Mixtral Output:\n{output_norag}\n\n* Mixtral + RAG Output:\n{output_rag}'
+
+    return f'* Mixtral + RAG Output:\n{output_rag}'
+
+    #for response in formatted_output:
+    #    output += response
+    #    yield output
+    #return formatted_output
+
+def upload_file(files):
+    file_paths = [file.name for file in files]
+    return file_paths
+
+additional_inputs=[
+    gr.Checkbox(
+        label="RAG Only",
+        interactive=True,
+        value=False
+    ),
+    gr.Textbox(
+        label="File Link",
+        max_lines=1,
+        interactive=True,
+        value="https://arxiv.org/pdf/2401.10020.pdf"
+    ),
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=1024,
+        minimum=0,
+        maximum=2048,
+        step=64,
+        interactive=True,
+        info="The maximum numbers of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    )
+]
+
+examples=[["What is a trustworthy digital repository, where can you find this information?", None, None, None, None, None, None, ],
+          ["What are things a repository must have?", None, None, None, None, None, None,],
+          ["What principles should record creators follow?", None, None, None, None, None, None,],
+          ["Write a very short summary of Data Sanitation Techniques by Edgar Dale, and write a citation in APA style.", None, None, None, None, None, None,],
+          ["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None, None,],
+          ["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None, None,],
+         ]
 
+gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+    additional_inputs=additional_inputs,
+    title="RAG Demo",
+    examples=examples,
+    #concurrency_limit=20,
+).queue().launch(show_api=False, debug=True, share=True)
 
+#iface = gr.Interface(fn=generate, inputs=["text"], outputs=["text", "text"],
+#                     additional_inputs=additional_inputs, title="RAG Demo", examples=examples)
+#iface.launch()
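For reference, the core of the RAG path this commit switches to can be reduced to the minimal sketch below. It uses only names that appear in the diff (legacy `llama_index` 0.9-style imports, `HuggingFaceInferenceAPI`, `ServiceContext`, `VectorStoreIndex`); the assumption that documents are already present under `content/` and the example question (taken from the Space's `examples` list) are illustrative, and the custom chat templates are omitted in favor of the query-engine defaults.

```python
# Minimal sketch of the new RAG flow (not the full app.py).
# Assumption: files to index have already been downloaded into "content/".
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext

# Remote Mixtral endpoint as the LLM, local BGE model for embeddings.
llm = HuggingFaceInferenceAPI(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
service_context = ServiceContext.from_defaults(
    llm=llm, embed_model="local:BAAI/bge-small-en-v1.5"
)

# Build an in-memory vector index over the downloaded documents and query it.
documents = SimpleDirectoryReader("content/").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
answer = index.as_query_engine(similarity_top_k=6).query(
    "What are things a repository must have?"
)
print(str(answer))
```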