PierreBrunelle committed on
Commit
57710db
·
verified ·
1 Parent(s): 693d55a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -82
app.py CHANGED
@@ -3,16 +3,17 @@ import pixeltable as pxt
3
  from pixeltable.functions.mistralai import chat_completions
4
  from datetime import datetime
5
  from textblob import TextBlob
 
6
  import nltk
7
  from nltk.tokenize import word_tokenize
8
  from nltk.corpus import stopwords
9
  import os
10
  import getpass
11
- import re
12
 
13
  # Ensure necessary NLTK data is downloaded
14
  nltk.download('punkt', quiet=True)
15
  nltk.download('stopwords', quiet=True)
 
16
 
17
  # Set up Mistral API key
18
  if 'MISTRAL_API_KEY' not in os.environ:
@@ -37,24 +38,24 @@ def calculate_readability(text: str) -> float:
37
  average_words_per_sentence = words / sentences
38
  return 206.835 - 1.015 * average_words_per_sentence
39
 
 
40
  def run_inference_and_analysis(task, system_prompt, input_text, temperature, top_p, max_tokens, min_tokens, stop, random_seed, safe_prompt):
41
  # Initialize Pixeltable
42
  pxt.drop_table('mistral_prompts', ignore_errors=True)
43
  t = pxt.create_table('mistral_prompts', {
44
- 'task': pxt.StringType(),
45
- 'system': pxt.StringType(),
46
- 'input_text': pxt.StringType(),
47
- 'timestamp': pxt.TimestampType(),
48
- 'temperature': pxt.FloatType(),
49
- 'top_p': pxt.FloatType(),
50
- 'max_tokens': pxt.IntType(),
51
- 'min_tokens': pxt.IntType(),
52
- 'stop': pxt.StringType(),
53
- 'random_seed': pxt.IntType(),
54
- 'safe_prompt': pxt.BoolType()
55
  })
56
 
57
- # Insert new row
58
  t.insert([{
59
  'task': task,
60
  'system': system_prompt,
@@ -63,7 +64,6 @@ def run_inference_and_analysis(task, system_prompt, input_text, temperature, top
63
  'temperature': temperature,
64
  'top_p': top_p,
65
  'max_tokens': max_tokens,
66
- 'min_tokens': min_tokens,
67
  'stop': stop,
68
  'random_seed': random_seed,
69
  'safe_prompt': safe_prompt
@@ -80,36 +80,56 @@ def run_inference_and_analysis(task, system_prompt, input_text, temperature, top
80
  'temperature': temperature,
81
  'top_p': top_p,
82
  'max_tokens': max_tokens if max_tokens is not None else 300,
83
- 'min_tokens': min_tokens,
84
  'stop': stop.split(',') if stop else None,
85
  'random_seed': random_seed,
86
  'safe_prompt': safe_prompt
87
  }
88
 
89
- # Run inference with both models
90
- t['open_mistral_nemo'] = chat_completions(model='open-mistral-nemo', **common_params)
91
- t['mistral_medium'] = chat_completions(model='mistral-medium', **common_params)
92
 
93
  # Extract responses
94
- t['omn_response'] = t.open_mistral_nemo.choices[0].message.content
95
- t['ml_response'] = t.mistral_medium.choices[0].message.content
96
 
97
- # Run analysis
98
- t['large_sentiment_score'] = get_sentiment_score(t.ml_response)
99
- t['large_keywords'] = extract_keywords(t.ml_response)
100
- t['large_readability_score'] = calculate_readability(t.ml_response)
101
- t['open_sentiment_score'] = get_sentiment_score(t.omn_response)
102
- t['open_keywords'] = extract_keywords(t.omn_response)
103
- t['open_readability_score'] = calculate_readability(t.omn_response)
104
 
105
- # Get results
106
  results = t.select(
107
  t.omn_response, t.ml_response,
108
  t.large_sentiment_score, t.open_sentiment_score,
109
  t.large_keywords, t.open_keywords,
110
  t.large_readability_score, t.open_readability_score
111
  ).tail(1)
112
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  return (
114
  results['omn_response'][0],
115
  results['ml_response'][0],
@@ -118,63 +138,119 @@ def run_inference_and_analysis(task, system_prompt, input_text, temperature, top
118
  results['large_keywords'][0],
119
  results['open_keywords'][0],
120
  results['large_readability_score'][0],
121
- results['open_readability_score'][0]
 
 
 
 
122
  )
123
 
 
124
  def gradio_interface():
125
- with gr.Blocks() as demo:
126
- gr.Markdown("# LLM Prompt Studio")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  with gr.Row():
129
  with gr.Column():
130
- # Input components
131
- task = gr.Textbox(label="Task")
132
- system_prompt = gr.Textbox(label="System Prompt", lines=3)
133
- input_text = gr.Textbox(label="Input Text", lines=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  with gr.Accordion("Advanced Settings", open=False):
136
  temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature")
137
  top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
138
  max_tokens = gr.Number(label="Max Tokens", value=300)
139
- min_tokens = gr.Number(label="Min Tokens", value=None)
140
  stop = gr.Textbox(label="Stop Sequences (comma-separated)")
141
  random_seed = gr.Number(label="Random Seed", value=None)
142
  safe_prompt = gr.Checkbox(label="Safe Prompt", value=False)
143
 
144
- # Example prompts
145
- examples = [
146
- ["Sentiment Analysis",
147
- "You are an AI trained to analyze the sentiment of text. Provide a detailed analysis of the emotional tone, highlighting key phrases that indicate sentiment.",
148
- "The new restaurant downtown exceeded all my expectations. The food was exquisite, the service impeccable, and the ambiance was perfect for a romantic evening. I can't wait to go back!",
149
- 0.3, 0.95, 200, None, "", None, False],
150
-
151
- ["Story Generation",
152
- "You are a creative writer. Generate a short, engaging story based on the given prompt. Include vivid descriptions and an unexpected twist.",
153
- "In a world where dreams are shared, a young girl discovers she can manipulate other people's dreams.",
154
- 0.9, 0.8, 500, 300, "The end", None, False]
155
- ]
156
-
157
- gr.Examples(
158
- examples=examples,
159
- inputs=[
160
- task, system_prompt, input_text,
161
- temperature, top_p, max_tokens,
162
- min_tokens, stop, random_seed,
163
- safe_prompt
164
- ],
165
- outputs=[
166
- omn_response, ml_response,
167
- large_sentiment, open_sentiment,
168
- large_keywords, open_keywords,
169
- large_readability, open_readability
170
- ],
171
- fn=run_inference_and_analysis
172
- )
173
-
174
- submit_btn = gr.Button("Run Analysis")
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  with gr.Column():
177
- # Output components
178
  omn_response = gr.Textbox(label="Open-Mistral-Nemo Response")
179
  ml_response = gr.Textbox(label="Mistral-Medium Response")
180
 
@@ -190,23 +266,43 @@ def gradio_interface():
190
  large_readability = gr.Number(label="Mistral-Medium Readability")
191
  open_readability = gr.Number(label="Open-Mistral-Nemo Readability")
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  submit_btn.click(
194
  run_inference_and_analysis,
195
- inputs=[
196
- task, system_prompt, input_text,
197
- temperature, top_p, max_tokens,
198
- min_tokens, stop, random_seed,
199
- safe_prompt
200
- ],
201
- outputs=[
202
- omn_response, ml_response,
203
- large_sentiment, open_sentiment,
204
- large_keywords, open_keywords,
205
- large_readability, open_readability
206
- ]
207
  )
208
-
209
  return demo
210
 
 
211
  if __name__ == "__main__":
212
  gradio_interface().launch()
 
3
  from pixeltable.functions.mistralai import chat_completions
4
  from datetime import datetime
5
  from textblob import TextBlob
6
+ import re
7
  import nltk
8
  from nltk.tokenize import word_tokenize
9
  from nltk.corpus import stopwords
10
  import os
11
  import getpass
 
12
 
13
  # Ensure necessary NLTK data is downloaded
14
  nltk.download('punkt', quiet=True)
15
  nltk.download('stopwords', quiet=True)
16
+ nltk.download('punkt_tab', quiet=True)
17
 
18
  # Set up Mistral API key
19
  if 'MISTRAL_API_KEY' not in os.environ:
 
38
  average_words_per_sentence = words / sentences
39
  return 206.835 - 1.015 * average_words_per_sentence
40
 
41
+ # Function to run inference and analysis
42
  def run_inference_and_analysis(task, system_prompt, input_text, temperature, top_p, max_tokens, min_tokens, stop, random_seed, safe_prompt):
43
  # Initialize Pixeltable
44
  pxt.drop_table('mistral_prompts', ignore_errors=True)
45
  t = pxt.create_table('mistral_prompts', {
46
+ 'task': pxt.String,
47
+ 'system': pxt.String,
48
+ 'input_text': pxt.String,
49
+ 'timestamp': pxt.Timestamp,
50
+ 'temperature': pxt.Float,
51
+ 'top_p': pxt.Float,
52
+ 'max_tokens': pxt.Int,
53
+ 'stop': pxt.String,
54
+ 'random_seed': pxt.Int,
55
+ 'safe_prompt': pxt.Bool
 
56
  })
57
 
58
+ # Insert new row into Pixeltable
59
  t.insert([{
60
  'task': task,
61
  'system': system_prompt,
 
64
  'temperature': temperature,
65
  'top_p': top_p,
66
  'max_tokens': max_tokens,
 
67
  'stop': stop,
68
  'random_seed': random_seed,
69
  'safe_prompt': safe_prompt
 
80
  'temperature': temperature,
81
  'top_p': top_p,
82
  'max_tokens': max_tokens if max_tokens is not None else 300,
 
83
  'stop': stop.split(',') if stop else None,
84
  'random_seed': random_seed,
85
  'safe_prompt': safe_prompt
86
  }
87
 
88
+ # Add computed columns for model responses and analysis
89
+ t.add_computed_column(open_mistral_nemo=chat_completions(model='open-mistral-nemo', **common_params))
90
+ t.add_computed_column(mistral_medium=chat_completions(model='mistral-medium', **common_params))
91
 
92
  # Extract responses
93
+ t.add_computed_column(omn_response=t.open_mistral_nemo.choices[0].message.content.astype(pxt.String))
94
+ t.add_computed_column(ml_response=t.mistral_medium.choices[0].message.content.astype(pxt.String))
95
 
96
+ # Add computed columns for analysis
97
+ t.add_computed_column(large_sentiment_score=get_sentiment_score(t.ml_response))
98
+ t.add_computed_column(large_keywords=extract_keywords(t.ml_response))
99
+ t.add_computed_column(large_readability_score=calculate_readability(t.ml_response))
100
+ t.add_computed_column(open_sentiment_score=get_sentiment_score(t.omn_response))
101
+ t.add_computed_column(open_keywords=extract_keywords(t.omn_response))
102
+ t.add_computed_column(open_readability_score=calculate_readability(t.omn_response))
103
 
104
+ # Retrieve results
105
  results = t.select(
106
  t.omn_response, t.ml_response,
107
  t.large_sentiment_score, t.open_sentiment_score,
108
  t.large_keywords, t.open_keywords,
109
  t.large_readability_score, t.open_readability_score
110
  ).tail(1)
111
+
112
+ history = t.select(t.timestamp, t.task, t.system, t.input_text).order_by(t.timestamp, asc=False).collect().to_pandas()
113
+ responses = t.select(t.timestamp, t.omn_response, t.ml_response).order_by(t.timestamp, asc=False).collect().to_pandas()
114
+ analysis = t.select(
115
+ t.timestamp,
116
+ t.open_sentiment_score,
117
+ t.large_sentiment_score,
118
+ t.open_keywords,
119
+ t.large_keywords,
120
+ t.open_readability_score,
121
+ t.large_readability_score
122
+ ).order_by(t.timestamp, asc=False).collect().to_pandas()
123
+ params = t.select(
124
+ t.timestamp,
125
+ t.temperature,
126
+ t.top_p,
127
+ t.max_tokens,
128
+ t.stop,
129
+ t.random_seed,
130
+ t.safe_prompt
131
+ ).order_by(t.timestamp, asc=False).collect().to_pandas()
132
+
133
  return (
134
  results['omn_response'][0],
135
  results['ml_response'][0],
 
138
  results['large_keywords'][0],
139
  results['open_keywords'][0],
140
  results['large_readability_score'][0],
141
+ results['open_readability_score'][0],
142
+ history,
143
+ responses,
144
+ analysis,
145
+ params
146
  )
147
 
148
+ # Gradio interface
149
  def gradio_interface():
150
+ with gr.Blocks(theme=gr.themes.Base(), title="Prompt Engineering and LLM Studio") as demo:
151
+ gr.HTML(
152
+ """
153
+ <div style="margin-bottom: 20px;">
154
+ <img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/resources/pixeltable-logo-large.png" alt="Pixeltable" style="max-width: 150px;" />
155
+ </div>
156
+ """
157
+ )
158
+ gr.Markdown(
159
+ """
160
+ # Prompt Engineering and LLM Studio
161
+ This application demonstrates how [Pixeltable](https://github.com/pixeltable/pixeltable) can be used for rapid and incremental prompt engineering
162
+ and model comparison workflows. It showcases Pixeltable's ability to directly store, version, index,
163
+ and transform data while providing an interactive interface to experiment with different prompts and models.
164
+ Remember, effective prompt engineering often requires experimentation and iteration. Use this tool to systematically improve your prompts and understand how different inputs and parameters affect the LLM outputs.
165
+ """
166
+ )
167
 
168
  with gr.Row():
169
  with gr.Column():
170
+ with gr.Accordion("What does it do?", open=False):
171
+ gr.Markdown(
172
+ """
173
+ 1. **Data Organization**: Pixeltable uses tables and views to organize data, similar to traditional databases but with enhanced capabilities for AI workflows.
174
+ 2. **Computed Columns**: These are dynamically generated columns based on expressions applied to columns.
175
+ 3. **Data Storage**: All prompts, responses, and analysis results are stored in Pixeltable tables.
176
+ 4. **Versioning**: Every operation is automatically versioned, allowing you to track changes over time.
177
+ 5. **UDFs**: Sentiment scores, keywords, and readability scores are computed dynamically.
178
+ 6. **Querying**: The history and analysis tabs leverage Pixeltable's querying capabilities to display results.
179
+ """
180
+ )
181
+
182
+ with gr.Column():
183
+ with gr.Accordion("How does it work?", open=False):
184
+ gr.Markdown(
185
+ """
186
+ 1. **Define your task**: This helps you keep track of different experiments.
187
+ 2. **Set up your prompt**: Enter a system prompt in the "System Prompt" field. Write your specific input or question in the "Input Text" field.
188
+ 3. **Adjust parameters (optional)**: Adjust temperature, top_p, token limits, etc., to control the model's output.
189
+ 4. **Run the analysis**: Click the "Run Inference and Analysis" button.
190
+ 5. **Review the results**: Compare the responses from both models and examine the scores.
191
+ 6. **Iterate and refine**: Based on the results, refine your prompt or adjust parameters.
192
+ """
193
+ )
194
+
195
+ with gr.Row():
196
+ with gr.Column():
197
+ task = gr.Textbox(label="Task (Arbitrary Category)")
198
+ system_prompt = gr.Textbox(label="System Prompt")
199
+ input_text = gr.Textbox(label="Input Text")
200
 
201
  with gr.Accordion("Advanced Settings", open=False):
202
  temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature")
203
  top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
204
  max_tokens = gr.Number(label="Max Tokens", value=300)
 
205
  stop = gr.Textbox(label="Stop Sequences (comma-separated)")
206
  random_seed = gr.Number(label="Random Seed", value=None)
207
  safe_prompt = gr.Checkbox(label="Safe Prompt", value=False)
208
 
209
+ submit_btn = gr.Button("Run Inference and Analysis")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
+ with gr.Tabs():
212
+ with gr.Tab("Prompt Input"):
213
+ history = gr.Dataframe(
214
+ headers=["Task", "System Prompt", "Input Text", "Timestamp"],
215
+ wrap=True
216
+ )
217
+
218
+ with gr.Tab("Model Responses"):
219
+ responses = gr.Dataframe(
220
+ headers=["Timestamp", "Open-Mistral-Nemo Response", "Mistral-Medium Response"],
221
+ wrap=True
222
+ )
223
+
224
+ with gr.Tab("Analysis Results"):
225
+ analysis = gr.Dataframe(
226
+ headers=[
227
+ "Timestamp",
228
+ "Open-Mistral-Nemo Sentiment",
229
+ "Mistral-Medium Sentiment",
230
+ "Open-Mistral-Nemo Keywords",
231
+ "Mistral-Medium Keywords",
232
+ "Open-Mistral-Nemo Readability",
233
+ "Mistral-Medium Readability"
234
+ ],
235
+ wrap=True
236
+ )
237
+
238
+ with gr.Tab("Model Parameters"):
239
+ params = gr.Dataframe(
240
+ headers=[
241
+ "Timestamp",
242
+ "Temperature",
243
+ "Top P",
244
+ "Max Tokens",
245
+ "Min Tokens",
246
+ "Stop Sequences",
247
+ "Random Seed",
248
+ "Safe Prompt"
249
+ ],
250
+ wrap=True
251
+ )
252
+
253
  with gr.Column():
 
254
  omn_response = gr.Textbox(label="Open-Mistral-Nemo Response")
255
  ml_response = gr.Textbox(label="Mistral-Medium Response")
256
 
 
266
  large_readability = gr.Number(label="Mistral-Medium Readability")
267
  open_readability = gr.Number(label="Open-Mistral-Nemo Readability")
268
 
269
+ # Define the examples
270
+ examples = [
271
+ # Example 1: Sentiment Analysis
272
+ ["Sentiment Analysis",
273
+ "You are an AI trained to analyze the sentiment of text. Provide a detailed analysis of the emotional tone, highlighting key phrases that indicate sentiment.",
274
+ "The new restaurant downtown exceeded all my expectations. The food was exquisite, the service impeccable, and the ambiance was perfect for a romantic evening. I can't wait to go back!",
275
+ 0.3, 0.95, 200, ""],
276
+
277
+ # Example 2: Creative Writing
278
+ ["Story Generation",
279
+ "You are a creative writer. Generate a short, engaging story based on the given prompt. Include vivid descriptions and an unexpected twist.",
280
+ "In a world where dreams are shared, a young girl discovers she can manipulate other people's dreams.",
281
+ 0.9, 0.8, 500, 300, "The end"]
282
+ ]
283
+
284
+ gr.Examples(
285
+ examples=examples,
286
+ inputs=[task, system_prompt, input_text, temperature, top_p, max_tokens, stop, random_seed, safe_prompt],
287
+ outputs=[omn_response, ml_response, large_sentiment, open_sentiment, large_keywords, open_keywords, large_readability, open_readability],
288
+ fn=run_inference_and_analysis,
289
+ cache_examples=True,
290
+ )
291
+
292
+ gr.Markdown(
293
+ """
294
+ For more information, visit [Pixeltable's GitHub repository](https://github.com/pixeltable/pixeltable).
295
+ """
296
+ )
297
+
298
  submit_btn.click(
299
  run_inference_and_analysis,
300
+ inputs=[task, system_prompt, input_text, temperature, top_p, max_tokens, stop, random_seed, safe_prompt],
301
+ outputs=[omn_response, ml_response, large_sentiment, open_sentiment, large_keywords, open_keywords, large_readability, open_readability, history, responses, analysis, params]
 
 
 
 
 
 
 
 
 
 
302
  )
303
+
304
  return demo
305
 
306
+ # Launch the Gradio interface
307
  if __name__ == "__main__":
308
  gradio_interface().launch()