import gradio as gr
import pixeltable as pxt
from pixeltable.functions.mistralai import chat_completions
from datetime import datetime
from textblob import TextBlob
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import os
import getpass
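# Download the NLTK tokenizer and stopword data used by the keyword-extraction UDF below.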
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('punkt_tab', quiet=True)
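# Prompt for a Mistral AI API key if one is not already set in the environment.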
if 'MISTRAL_API_KEY' not in os.environ:
    os.environ['MISTRAL_API_KEY'] = getpass.getpass('Mistral AI API Key:')
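# Pixeltable UDFs: plain Python functions registered with @pxt.udf so they can be
# used in computed columns and evaluated automatically over table rows.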
@pxt.udf
def get_sentiment_score(text: str) -> float:
    return TextBlob(text).sentiment.polarity
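# Rank the most frequent non-stopword tokens as rough keywords.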
@pxt.udf
def extract_keywords(text: str, num_keywords: int = 5) -> list:
    stop_words = set(stopwords.words('english'))
    words = word_tokenize(text.lower())
    keywords = [word for word in words if word.isalnum() and word not in stop_words]
    return sorted(set(keywords), key=keywords.count, reverse=True)[:num_keywords]
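# Simplified Flesch reading-ease estimate based on average sentence length only
# (the syllable term of the full formula is omitted).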
@pxt.udf
def calculate_readability(text: str) -> float:
    words = len(re.findall(r'\w+', text))
    sentences = len(re.findall(r'\w+[.!?]', text)) or 1
    average_words_per_sentence = words / sentences
    return 206.835 - 1.015 * average_words_per_sentence
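# Store the prompt, call both Mistral models via computed columns, analyze the responses,
# and return everything the Gradio UI needs.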
def run_inference_and_analysis(task, system_prompt, input_text, temperature, top_p, max_tokens, stop, random_seed, safe_prompt):
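    # Recreate the prompts table from scratch for this run (any previous rows are dropped).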
    pxt.drop_table('mistral_prompts', ignore_errors=True)
    t = pxt.create_table('mistral_prompts', {
        'task': pxt.String,
        'system': pxt.String,
        'input_text': pxt.String,
        'timestamp': pxt.Timestamp,
        'temperature': pxt.Float,
        'top_p': pxt.Float,
        'max_tokens': pxt.Int,
        'stop': pxt.String,
        'random_seed': pxt.Int,
        'safe_prompt': pxt.Bool
    })
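    # Record the prompt, input text, and generation parameters for this run.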
    t.insert([{
        'task': task,
        'system': system_prompt,
        'input_text': input_text,
        'timestamp': datetime.now(),
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens,
        'stop': stop,
        'random_seed': random_seed,
        'safe_prompt': safe_prompt
    }])
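    # Chat messages are built from column references, so the model calls below are evaluated per row.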
    msgs = [
        {'role': 'system', 'content': t.system},
        {'role': 'user', 'content': t.input_text}
    ]
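    # Generation parameters shared by both model calls (max_tokens defaults to 300 if unset).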
    common_params = {
        'messages': msgs,
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens if max_tokens is not None else 300,
        'stop': stop.split(',') if stop else None,
        'random_seed': random_seed,
        'safe_prompt': safe_prompt
    }
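    # Computed columns: each one calls a Mistral model and stores the full API response.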
    t.add_computed_column(open_mistral_nemo=chat_completions(model='open-mistral-nemo', **common_params))
    t.add_computed_column(mistral_medium=chat_completions(model='mistral-medium', **common_params))
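    # Extract the generated text from each response.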
    t.add_computed_column(omn_response=t.open_mistral_nemo.choices[0].message.content.astype(pxt.String))
    t.add_computed_column(ml_response=t.mistral_medium.choices[0].message.content.astype(pxt.String))
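    # Run the analysis UDFs (sentiment, keywords, readability) on each model's response.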
    t.add_computed_column(large_sentiment_score=get_sentiment_score(t.ml_response))
    t.add_computed_column(large_keywords=extract_keywords(t.ml_response))
    t.add_computed_column(large_readability_score=calculate_readability(t.ml_response))
    t.add_computed_column(open_sentiment_score=get_sentiment_score(t.omn_response))
    t.add_computed_column(open_keywords=extract_keywords(t.omn_response))
    t.add_computed_column(open_readability_score=calculate_readability(t.omn_response))
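    # Query the table: the latest row for the result panels, plus the full history for the tabs.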
    results = t.select(
        t.omn_response, t.ml_response,
        t.large_sentiment_score, t.open_sentiment_score,
        t.large_keywords, t.open_keywords,
        t.large_readability_score, t.open_readability_score
    ).tail(1)

    history = t.select(t.timestamp, t.task, t.system, t.input_text).order_by(t.timestamp, asc=False).collect().to_pandas()
    responses = t.select(t.timestamp, t.omn_response, t.ml_response).order_by(t.timestamp, asc=False).collect().to_pandas()
    analysis = t.select(
        t.timestamp,
        t.open_sentiment_score,
        t.large_sentiment_score,
        t.open_keywords,
        t.large_keywords,
        t.open_readability_score,
        t.large_readability_score
    ).order_by(t.timestamp, asc=False).collect().to_pandas()
    params = t.select(
        t.timestamp,
        t.temperature,
        t.top_p,
        t.max_tokens,
        t.stop,
        t.random_seed,
        t.safe_prompt
    ).order_by(t.timestamp, asc=False).collect().to_pandas()
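    # Return values in the order expected by the Gradio outputs.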
    return (
        results['omn_response'][0],
        results['ml_response'][0],
        results['large_sentiment_score'][0],
        results['open_sentiment_score'][0],
        results['large_keywords'][0],
        results['open_keywords'][0],
        results['large_readability_score'][0],
        results['open_readability_score'][0],
        history,
        responses,
        analysis,
        params
    )
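# Gradio UI: prompt and parameter inputs plus history tabs on the left, model responses and scores on the right.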
def gradio_interface():
    with gr.Blocks(theme=gr.themes.Base(), title="Prompt Engineering and LLM Studio") as demo:
        gr.HTML(
            """
            <div style="margin-bottom: 20px;">
                <img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/resources/pixeltable-logo-large.png" alt="Pixeltable" style="max-width: 150px;" />
            </div>
            """
        )
        gr.Markdown(
            """
            # Prompt Engineering and LLM Studio
            This application demonstrates how [Pixeltable](https://github.com/pixeltable/pixeltable) can be used for rapid and incremental prompt engineering
            and model comparison workflows. It showcases Pixeltable's ability to store, version, index,
            and transform data directly while providing an interactive interface for experimenting with different prompts and models.

            Remember, effective prompt engineering often requires experimentation and iteration. Use this tool to systematically improve your prompts and understand how different inputs and parameters affect the LLM outputs.
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Accordion("What does it do?", open=False):
                    gr.Markdown(
                        """
                        1. **Data Organization**: Pixeltable uses tables and views to organize data, similar to traditional databases but with enhanced capabilities for AI workflows.
                        2. **Computed Columns**: Columns whose values are computed automatically from expressions over other columns.
                        3. **Data Storage**: All prompts, responses, and analysis results are stored in Pixeltable tables.
                        4. **Versioning**: Every operation is automatically versioned, allowing you to track changes over time.
                        5. **UDFs**: Sentiment scores, keywords, and readability scores are computed dynamically through user-defined functions.
                        6. **Querying**: The history and analysis tabs leverage Pixeltable's querying capabilities to display results.
                        """
                    )
            with gr.Column():
                with gr.Accordion("How does it work?", open=False):
                    gr.Markdown(
                        """
                        1. **Define your task**: This helps you keep track of different experiments.
                        2. **Set up your prompt**: Enter a system prompt in the "System Prompt" field and write your specific input or question in the "Input Text" field.
                        3. **Adjust parameters (optional)**: Adjust temperature, top_p, token limits, etc., to control the model's output.
                        4. **Run the analysis**: Click the "Run Inference and Analysis" button.
                        5. **Review the results**: Compare the responses from both models and examine the scores.
                        6. **Iterate and refine**: Based on the results, refine your prompt or adjust parameters.
                        """
                    )
        with gr.Row():
            with gr.Column():
                task = gr.Textbox(label="Task (Arbitrary Category)")
                system_prompt = gr.Textbox(label="System Prompt")
                input_text = gr.Textbox(label="Input Text")
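                # Optional generation parameters forwarded to both Mistral calls.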
with gr.Accordion("Advanced Settings", open=False): |
|
temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature") |
|
top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P") |
|
max_tokens = gr.Number(label="Max Tokens", value=300) |
|
stop = gr.Textbox(label="Stop Sequences (comma-separated)") |
|
random_seed = gr.Number(label="Random Seed", value=None) |
|
safe_prompt = gr.Checkbox(label="Safe Prompt", value=False) |
|
|
|
submit_btn = gr.Button("Run Inference and Analysis") |
|
|
|
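                # History tabs, populated from Pixeltable queries after each run.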
                with gr.Tabs():
                    with gr.Tab("Prompt Input"):
                        history = gr.Dataframe(
                            headers=["Timestamp", "Task", "System Prompt", "Input Text"],
                            wrap=True
                        )
with gr.Tab("Model Responses"): |
|
responses = gr.Dataframe( |
|
headers=["Timestamp", "Open-Mistral-Nemo Response", "Mistral-Medium Response"], |
|
wrap=True |
|
) |
|
|
|
with gr.Tab("Analysis Results"): |
|
analysis = gr.Dataframe( |
|
headers=[ |
|
"Timestamp", |
|
"Open-Mistral-Nemo Sentiment", |
|
"Mistral-Medium Sentiment", |
|
"Open-Mistral-Nemo Keywords", |
|
"Mistral-Medium Keywords", |
|
"Open-Mistral-Nemo Readability", |
|
"Mistral-Medium Readability" |
|
], |
|
wrap=True |
|
) |
|
|
|
with gr.Tab("Model Parameters"): |
|
params = gr.Dataframe( |
|
headers=[ |
|
"Timestamp", |
|
"Temperature", |
|
"Top P", |
|
"Max Tokens", |
|
"Min Tokens", |
|
"Stop Sequences", |
|
"Random Seed", |
|
"Safe Prompt" |
|
], |
|
wrap=True |
|
) |
|
|
|
            with gr.Column():
                omn_response = gr.Textbox(label="Open-Mistral-Nemo Response")
                ml_response = gr.Textbox(label="Mistral-Medium Response")

                with gr.Row():
                    large_sentiment = gr.Number(label="Mistral-Medium Sentiment")
                    open_sentiment = gr.Number(label="Open-Mistral-Nemo Sentiment")

                with gr.Row():
                    large_keywords = gr.Textbox(label="Mistral-Medium Keywords")
                    open_keywords = gr.Textbox(label="Open-Mistral-Nemo Keywords")

                with gr.Row():
                    large_readability = gr.Number(label="Mistral-Medium Readability")
                    open_readability = gr.Number(label="Open-Mistral-Nemo Readability")
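        # Example prompts: task, system prompt, input text, then the generation parameters.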
        examples = [
            [
                "Sentiment Analysis",
                "You are an AI trained to analyze the sentiment of text. Provide a detailed analysis of the emotional tone, highlighting key phrases that indicate sentiment.",
                "The new restaurant downtown exceeded all my expectations. The food was exquisite, the service impeccable, and the ambiance was perfect for a romantic evening. I can't wait to go back!",
                0.3, 0.95, 200, "", None, False
            ],
            [
                "Story Generation",
                "You are a creative writer. Generate a short, engaging story based on the given prompt. Include vivid descriptions and an unexpected twist.",
                "In a world where dreams are shared, a young girl discovers she can manipulate other people's dreams.",
                0.9, 0.8, 500, "", 1, False
            ]
        ]
        gr.Examples(
            examples=examples,
            inputs=[task, system_prompt, input_text, temperature, top_p, max_tokens, stop, random_seed, safe_prompt],
            outputs=[omn_response, ml_response, large_sentiment, open_sentiment, large_keywords, open_keywords, large_readability, open_readability, history, responses, analysis, params],
            fn=run_inference_and_analysis,
            cache_examples=True,
        )
        gr.Markdown(
            """
            For more information, visit [Pixeltable's GitHub repository](https://github.com/pixeltable/pixeltable).
            """
        )
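        # Wire the button to the pipeline: run inference, then refresh every output field and history tab.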
        submit_btn.click(
            run_inference_and_analysis,
            inputs=[task, system_prompt, input_text, temperature, top_p, max_tokens, stop, random_seed, safe_prompt],
            outputs=[omn_response, ml_response, large_sentiment, open_sentiment, large_keywords, open_keywords, large_readability, open_readability, history, responses, analysis, params]
        )

    return demo
if __name__ == "__main__": |
|
gradio_interface().launch() |