PierreBrunelle committed on
Commit
98cccd6
1 Parent(s): 2aa40c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -30
app.py CHANGED
@@ -1,22 +1,32 @@
1
  import gradio as gr
2
  import pandas as pd
 
 
 
3
  import pixeltable as pxt
4
  from pixeltable.iterators import DocumentSplitter
5
  import numpy as np
6
  from pixeltable.functions.huggingface import sentence_transformer
7
  from pixeltable.functions import openai
 
 
 
 
8
  import os
 
9
 
10
  """## Store OpenAI API Key"""
11
 
12
  if 'OPENAI_API_KEY' not in os.environ:
13
- os.environ['OPENAI_API_KEY'] = getpass.getpass('Enter your OpenAI API key:')
 
 
 
14
 
15
- """Pixeltable Set up"""
 
16
 
17
- # Ensure a clean slate for the demo
18
- pxt.drop_dir('rag_demo', force=True)
19
- pxt.create_dir('rag_demo')
20
 
21
  # Set up embedding function
22
  @pxt.expr_udf
@@ -38,9 +48,12 @@ def create_prompt(top_k_list: list[dict], question: str) -> str:
38
 
39
  {question}'''
40
 
41
- # Gradio Application
42
- def process_files(ground_truth_file, pdf_files):
 
43
  # Ensure a clean slate for the demo by removing and recreating the 'rag_demo' directory
 
 
44
  pxt.drop_dir('rag_demo', force=True)
45
  pxt.create_dir('rag_demo')
46
 
@@ -51,12 +64,14 @@ def process_files(ground_truth_file, pdf_files):
51
  else:
52
  queries_t = pxt.io.import_excel('rag_demo.queries', ground_truth_file.name)
53
 
 
 
54
  # Create a table to store the uploaded PDF documents
55
  documents_t = pxt.create_table(
56
  'rag_demo.documents',
57
  {'document': pxt.DocumentType()}
58
  )
59
-
60
  # Insert the PDF files into the documents table
61
  documents_t.insert({'document': file.name} for file in pdf_files if file.name.endswith('.pdf'))
62
 
@@ -66,11 +81,13 @@ def process_files(ground_truth_file, pdf_files):
66
  documents_t,
67
  iterator=DocumentSplitter.create(
68
  document=documents_t.document,
69
- separators='token_limit',
70
- limit=300
71
  )
72
  )
73
 
 
 
74
  # Add an embedding index to the chunks for similarity search
75
  chunks_t.add_embedding_index('text', string_embed=e5_embed)
76
 
@@ -85,16 +102,16 @@ def process_files(ground_truth_file, pdf_files):
85
  )
86
 
87
  # Add computed columns to the queries table for context retrieval and prompt creation
88
- queries_t['question_context'] = chunks_t.top_k(queries_t.Question)
89
  queries_t['prompt'] = create_prompt(
90
- queries_t.question_context, queries_t.Question
91
  )
92
 
93
  # Prepare messages for the OpenAI API, including system instructions and user prompt
94
- messages = [
95
  {
96
  'role': 'system',
97
- 'content': 'Please read the following passages and answer the question based on their contents.'
98
  },
99
  {
100
  'role': 'user',
@@ -102,16 +119,55 @@ def process_files(ground_truth_file, pdf_files):
102
  }
103
  ]
104
 
105
- # Add OpenAI response column
 
 
106
  queries_t['response'] = openai.chat_completions(
107
- model='gpt-4o-mini-2024-07-18', messages=messages
 
 
 
 
108
  )
109
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  # Extract the answer text from the API response
111
- queries_t['answer'] = queries_t.response.choices[0].message.content.astype(pxt.StringType())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- # Prepare the output dataframe with questions, correct answers, and model-generated answers
114
- df_output = queries_t.select(queries_t.Question, queries_t.correct_answer, queries_t.answer).collect().to_pandas()
115
 
116
  try:
117
  # Return the output dataframe for display
@@ -119,26 +175,122 @@ def process_files(ground_truth_file, pdf_files):
119
  except Exception as e:
120
  return f"An error occurred: {str(e)}", None
121
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  # Gradio interface
123
- with gr.Blocks() as demo:
124
- gr.Markdown("# RAG Demo App")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  # File upload components for ground truth and PDF documents
127
  with gr.Row():
128
- ground_truth_file = gr.File(label="Upload Ground Truth (CSV or XLSX)", file_count="single")
129
  pdf_files = gr.File(label="Upload PDF Documents", file_count="multiple")
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # Button to trigger file processing
132
  process_button = gr.Button("Process Files and Generate Outputs")
133
 
134
  # Output component to display the results
135
- df_output = gr.DataFrame(label="Pixeltable Table")
 
 
 
 
 
 
 
 
136
 
137
- process_button.click(process_files, inputs=[ground_truth_file, pdf_files], outputs=df_output)
138
- #question_input = gr.Textbox(label="Enter your question")
139
- #query_button = gr.Button("Query LLM")
140
-
141
- #query_button.click(query_llm, inputs=question_input, outputs=output_dataframe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  if __name__ == "__main__":
144
- demo.launch()
 
1
  import gradio as gr
2
  import pandas as pd
3
+ import io
4
+ import base64
5
+ import uuid
6
  import pixeltable as pxt
7
  from pixeltable.iterators import DocumentSplitter
8
  import numpy as np
9
  from pixeltable.functions.huggingface import sentence_transformer
10
  from pixeltable.functions import openai
11
+ from pixeltable.functions.fireworks import chat_completions as f_chat_completions
12
+ from pixeltable.functions.mistralai import chat_completions
13
+ from gradio.themes import Monochrome
14
+
15
  import os
16
+ import getpass
17
 
18
  """## Store OpenAI API Key"""
19
 
20
  if 'OPENAI_API_KEY' not in os.environ:
21
+ os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API key:')
22
+
23
+ if 'FIREWORKS_API_KEY' not in os.environ:
24
+ os.environ['FIREWORKS_API_KEY'] = getpass.getpass('Fireworks API Key:')
25
 
26
+ if 'MISTRAL_API_KEY' not in os.environ:
27
+ os.environ['MISTRAL_API_KEY'] = getpass.getpass('Mistral AI API Key:')
28
 
29
+ """## Creating UDFs: Embedding and Prompt Functions"""
 
 
30
 
31
  # Set up embedding function
32
  @pxt.expr_udf
 
48
 
49
  {question}'''
50
 
51
+ """Gradio Application"""
52
+
53
+ def process_files(ground_truth_file, pdf_files, chunk_limit, chunk_separator, show_question, show_correct_answer, show_gpt4omini, show_llamav3p23b, show_mistralsmall, progress=gr.Progress()):
54
  # Ensure a clean slate for the demo by removing and recreating the 'rag_demo' directory
55
+ progress(0, desc="Initializing...")
56
+
57
  pxt.drop_dir('rag_demo', force=True)
58
  pxt.create_dir('rag_demo')
59
 
 
64
  else:
65
  queries_t = pxt.io.import_excel('rag_demo.queries', ground_truth_file.name)
66
 
67
+ progress(0.2, desc="Processing documents...")
68
+
69
  # Create a table to store the uploaded PDF documents
70
  documents_t = pxt.create_table(
71
  'rag_demo.documents',
72
  {'document': pxt.DocumentType()}
73
  )
74
+
75
  # Insert the PDF files into the documents table
76
  documents_t.insert({'document': file.name} for file in pdf_files if file.name.endswith('.pdf'))
77
 
 
81
  documents_t,
82
  iterator=DocumentSplitter.create(
83
  document=documents_t.document,
84
+ separators=chunk_separator,
85
+ limit=chunk_limit if chunk_separator in ["token_limit", "char_limit"] else None
86
  )
87
  )
88
 
89
+ progress(0.4, desc="Generating embeddings...")
90
+
91
  # Add an embedding index to the chunks for similarity search
92
  chunks_t.add_embedding_index('text', string_embed=e5_embed)
93
 
 
102
  )
103
 
104
  # Add computed columns to the queries table for context retrieval and prompt creation
105
+ queries_t['question_context'] = chunks_t.top_k(queries_t.question)
106
  queries_t['prompt'] = create_prompt(
107
+ queries_t.question_context, queries_t.question
108
  )
109
 
110
  # Prepare messages for the OpenAI API, including system instructions and user prompt
111
+ msgs = [
112
  {
113
  'role': 'system',
114
+ 'content': 'Read the following passages and answer the question based on their contents.'
115
  },
116
  {
117
  'role': 'user',
 
119
  }
120
  ]
121
 
122
+ progress(0.6, desc="Querying models...")
123
+
124
+ # Add OpenAI response column
125
  queries_t['response'] = openai.chat_completions(
126
+ model='gpt-4o-mini-2024-07-18',
127
+ messages=msgs,
128
+ max_tokens=300,
129
+ top_p=0.9,
130
+ temperature=0.7
131
  )
132
+
133
+ # Query a Llama model hosted on Fireworks AI with some optional tuning parameters
134
+ queries_t['response_2'] = f_chat_completions(
135
+ messages=msgs,
136
+ model='accounts/fireworks/models/llama-v3p2-3b-instruct',
137
+ # These parameters are optional and can be used to tune model behavior:
138
+ max_tokens=300,
139
+ top_p=0.9,
140
+ temperature=0.7
141
+ )
142
+
143
+ queries_t['response_3'] = chat_completions(
144
+ messages=msgs,
145
+ model='mistral-small-latest',
146
+ # These parameters are optional and can be used to tune model behavior:
147
+ max_tokens=300,
148
+ top_p=0.9,
149
+ temperature=0.7
150
+ )
151
+
152
  # Extract the answer text from the API response
153
+ queries_t['gpt4omini'] = queries_t.response.choices[0].message.content
154
+ queries_t['llamav3p23b'] = queries_t.response_2.choices[0].message.content
155
+ queries_t['mistralsmall'] = queries_t.response_3.choices[0].message.content
156
+
157
+ # Prepare the output dataframe with selected columns
158
+ columns_to_show = []
159
+ if show_question:
160
+ columns_to_show.append(queries_t.question)
161
+ if show_correct_answer:
162
+ columns_to_show.append(queries_t.correct_answer)
163
+ if show_gpt4omini:
164
+ columns_to_show.append(queries_t.gpt4omini)
165
+ if show_llamav3p23b:
166
+ columns_to_show.append(queries_t.llamav3p23b)
167
+ if show_mistralsmall:
168
+ columns_to_show.append(queries_t.mistralsmall)
169
 
170
+ df_output = queries_t.select(*columns_to_show).collect().to_pandas()
 
171
 
172
  try:
173
  # Return the output dataframe for display
 
175
  except Exception as e:
176
  return f"An error occurred: {str(e)}", None
177
 
178
+ def save_dataframe_as_csv(data):
179
+ print(f"Type of data: {type(data)}")
180
+ if isinstance(data, pd.DataFrame):
181
+ print(f"Shape of DataFrame: {data.shape}")
182
+ if isinstance(data, pd.DataFrame) and not data.empty:
183
+ filename = f"results_{uuid.uuid4().hex[:8]}.csv"
184
+ filepath = os.path.join('tmp', filename)
185
+ os.makedirs('tmp', exist_ok=True)
186
+ data.to_csv(filepath, index=False)
187
+ return filepath
188
+ return None
189
+
190
  # Gradio interface
191
+ with gr.Blocks(theme=Monochrome) as demo:
192
+ gr.Markdown(
193
+ """
194
+ <div max-width: 800px; margin: 0 auto;">
195
+ <img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/source/data/pixeltable-logo-large.png" alt="Pixeltable" style="max-width: 200px; margin-bottom: 20px;" />
196
+ <h1 style="margin-bottom: 0.5em;">Multi-LLM RAG Benchmark: Document Q&A with Groundtruth Comparison</h1>
197
+ </div>
198
+ """
199
+ )
200
+ gr.HTML(
201
+ """
202
+ <p>
203
+ <a href="https://github.com/pixeltable/pixeltable" target="_blank" style="color: #F25022; text-decoration: none; font-weight: bold;">Pixeltable</a> is a declarative interface for working with text, images, embeddings, and even video, enabling you to store, transform, index, and iterate on data.
204
+ </p>
205
+ """
206
+ )
207
+
208
+ # Add the disclaimer
209
+ gr.Markdown(
210
+ """
211
+ <div style="background-color: #E5DDD4; border: 1px solid #e9ecef; border-radius: 8px; padding: 15px; margin-bottom: 20px;">
212
+ <strong>Disclaimer:</strong> This Gradio app is running on OpenAI, Mistral, and Fireworks accounts with the developer's personal API keys.
213
+ If you wish to use it with your own hardware or API keys, you can
214
+ <a href="https://huggingface.co/spaces/Pixeltable/Multi-LLM-RAG-with-Groundtruth-Comparison?duplicate=true" target="_blank" style="color: #F25022; text-decoration: none; font-weight: bold;">duplicate this Hugging Face Space</a>
215
+ or run it locally or in Google Colab.
216
+ </div>
217
+ """
218
+ )
219
+
220
+ with gr.Row():
221
+ with gr.Column():
222
+ with gr.Accordion("What This Demo Does", open = True):
223
+ gr.Markdown("""
224
+ 1. **Ingests Documents**: Uploads your PDF documents and a ground truth file (CSV or XLSX).
225
+ 2. **Process and Retrieve Data**: Store, chunk, index, orchestrate, and retrieve all data.
226
+ 3. **Generates Answers**: Leverages OpenAI, Fireworks (Llama), and Mistral to produce answers based on the retrieved context.
227
+ 4. **Compares Results**: Displays the generated answers alongside the ground truth for easy evaluation.
228
+ """)
229
+ with gr.Column():
230
+ with gr.Accordion("How to Use", open = True):
231
+ gr.Markdown("""
232
+ 1. Upload your ground truth file (CSV or XLSX) with the following two columns: **question** and **correct_answer**.
233
+ 2. Upload one or more PDF documents that contain the information to answer these questions.
234
+ 3. Click "Process Files and Generate Output" to start the RAG process.
235
+ 4. View the results in the table below, comparing AI-generated answers to the ground truth.
236
+ """)
237
 
238
  # File upload components for ground truth and PDF documents
239
  with gr.Row():
240
+ ground_truth_file = gr.File(label="Upload Ground Truth (CSV or XLSX) - Format to respect:question | correct_answer", file_count="single")
241
  pdf_files = gr.File(label="Upload PDF Documents", file_count="multiple")
242
 
243
+ # Add controls for chunking parameters
244
+ with gr.Row():
245
+ chunk_limit = gr.Slider(minimum=100, maximum=500, value=300, step=5, label="Chunk Size Limit (only used when the separator is token_/char_limit)")
246
+ chunk_separator = gr.Dropdown(
247
+ choices=["token_limit", "char_limit", "sentence", "paragraph", "heading"],
248
+ value="token_limit",
249
+ label="Chunk Separator"
250
+ )
251
+
252
+ with gr.Row():
253
+ show_question = gr.Checkbox(label="Show Question", value=True)
254
+ show_correct_answer = gr.Checkbox(label="Show Correct Answer", value=True)
255
+ show_gpt4omini = gr.Checkbox(label="Show GPT-4o-mini Answer", value=True)
256
+ show_llamav3p23b = gr.Checkbox(label="Show LLaMA-v3-2-3B Answer", value=True)
257
+ show_mistralsmall = gr.Checkbox(label="Show Mistral-Small Answer", value=True)
258
+
259
  # Button to trigger file processing
260
  process_button = gr.Button("Process Files and Generate Outputs")
261
 
262
  # Output component to display the results
263
+ df_output = gr.DataFrame(label="Pixeltable Table",
264
+ wrap=True
265
+ )
266
+
267
+ with gr.Row():
268
+ with gr.Column(scale=1):
269
+ download_button = gr.Button("Download Results as CSV")
270
+ with gr.Column(scale=2):
271
+ csv_output = gr.File(label="CSV Download")
272
 
273
+ def trigger_download(data):
274
+ csv_path = save_dataframe_as_csv(data)
275
+ return csv_path if csv_path else None
276
+
277
+ process_button.click(process_files,
278
+ inputs=[ground_truth_file,
279
+ pdf_files,
280
+ chunk_limit,
281
+ chunk_separator,
282
+ show_question,
283
+ show_correct_answer,
284
+ show_gpt4omini,
285
+ show_llamav3p23b,
286
+ show_mistralsmall],
287
+ outputs=df_output)
288
+
289
+ download_button.click(
290
+ trigger_download,
291
+ inputs=[df_output],
292
+ outputs=[csv_output]
293
+ )
294
 
295
  if __name__ == "__main__":
296
+ demo.launch(debug=True)