import datetime
import json
import os
from itertools import zip_longest

import gradio as gr
import tiktoken

from models import select_random_model
from rag import select_random_formatter
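

# Arena-style Gradio app: each question is sent to two randomly selected
# (model, prompt formatter) pairs drawn from the local `models` and `rag`
# modules, the two responses are streamed side by side, and the user votes on
# which answer is more helpful. Every vote is saved to a timestamped JSON
# interaction log that can be downloaded.
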
def error_helper(msg: str, duration: int = 10):
    raise gr.Error(msg, duration=duration)
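

# Enable the chat button only after at least one code file has been uploaded.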
def code_upload(code_file_select):
    if code_file_select is None:
        return gr.Button(interactive=False)
    else:
        return gr.Button(interactive=True)
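

# Look up the per-model context limit. token_limits.json is assumed to map
# model names to token limits (e.g. {"gpt-4": 8192}); unlisted models fall
# back to a very large default.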
def token_limit_getter(model: str) -> int:
    with open("token_limits.json", "r") as f:
        token_limits = json.load(f)
    if model in token_limits:
        return token_limits[model]
    return int(5e6)
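

# Estimate the prompt's token count (tiktoken for GPT models, a ~4 characters
# per token heuristic otherwise) and raise a Gradio error if it meets or
# exceeds the model's limit.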
def check_length(text, model):
    if model.name.startswith("gpt"):
        encoder = lambda s: len(tiktoken.encoding_for_model(model.name).encode(s))
    else:
        encoder = lambda s: len(s) / 4  # ~4 characters per token heuristic
    token_length = encoder(text)
    token_limit = token_limit_getter(model.name)
    if token_length >= token_limit:
        error_helper("Prompt is too long. Please try reducing the size of the prompt or code uploaded.")
def chat_with_llms(prompt, code_files, profile_file, profile_type):
    model1 = select_random_model()
    model2 = select_random_model()
    formatter1 = select_random_formatter()
    formatter2 = select_random_formatter()
    print(f"Selected models: {model1.name} and {model2.name}")

    formatted1 = formatter1.format_prompt(prompt, code_files, profile_file, profile_type, error_fn=error_helper)
    formatted2 = formatter2.format_prompt(prompt, code_files, profile_file, profile_type, error_fn=error_helper)
    if formatted1 is None or formatted2 is None:
        error_helper("Failed to format prompt. Please try again.")

    check_length(formatted1, model1)
    check_length(formatted2, model2)

    response1 = model1.get_response(formatted1)
    response2 = model2.get_response(formatted2)
    if response1 is None:
        error_helper(f"Failed to get response from {model1.name}. Please try again.")
    if response2 is None:
        error_helper(f"Failed to get response from {model2.name}. Please try again.")

    source1 = gr.Markdown(f"{model1.name} + {formatter1.name}", visible=False, elem_classes=[])
    source2 = gr.Markdown(f"{model2.name} + {formatter2.name}", visible=False, elem_classes=[])

    # Keep the vote and download buttons inactive while responses are streaming.
    download_btn = gr.Button(interactive=False)
    vote_buttons = gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)

    # Stream both responses side by side; initialize c1/c2 so the final yield
    # is safe even if both response generators are empty.
    c1, c2 = None, None
    for c1, c2 in zip_longest(response1, response2):
        yield c1 or gr.Textbox(), source1, formatted1, c2 or gr.Textbox(), source2, formatted2, download_btn, *vote_buttons

    # Once both responses have finished, re-enable voting.
    vote_buttons = gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=True)
    yield c1 or gr.Textbox(), source1, formatted1, c2 or gr.Textbox(), source2, formatted2, download_btn, *vote_buttons
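

# Persist a single prompt/response/vote interaction to a timestamped JSON file
# and return its path so it can be offered for download.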
def get_interaction_log(prompt, vote, response1, model1, formatter1, full_prompt1, response2, model2, formatter2, full_prompt2):
    log = {
        "prompt": prompt,
        "full_prompt1": full_prompt1,
        "full_prompt2": full_prompt2,
        "response1": response1,
        "response2": response2,
        "vote": vote,
        "model1": model1,
        "formatter1": formatter1,
        "model2": model2,
        "formatter2": formatter2,
        "timestamp": datetime.datetime.now().isoformat()
    }
    fpath = f"interaction_log_{datetime.datetime.now().isoformat()}.json"
    with open(fpath, "w") as f:
        json.dump(log, f, indent=2)
    return fpath
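

# Record the vote: reveal which model/formatter produced each response,
# highlight the chosen one, enable the log download (unless the vote was
# skipped), and disable further voting for this round.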
def handle_vote(prompt, vote, response1, source1, full_prompt1, response2, source2, full_prompt2):
    model1, formatter1 = source1.split(" + ")
    model2, formatter2 = source2.split(" + ")
    label1_class = ["voted"] if vote == "Vote for Response 1" else []
    label2_class = ["voted"] if vote == "Vote for Response 2" else []
    did_vote = vote != "Skip"
    log_fpath = get_interaction_log(prompt, vote, response1, model1, formatter1, full_prompt1, response2, model2, formatter2, full_prompt2)
    return gr.Markdown(visible=True, elem_classes=label1_class), gr.Markdown(visible=True, elem_classes=label2_class), \
        gr.Button(interactive=did_vote, value=log_fpath), \
        gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)


# Define the Gradio interface
with gr.Blocks(css=".not-voted p { color: black; } .voted p { color: green; } .response { padding: 25px; } .response-md { padding: 20px; }") as interface:
    gr.Markdown("""# Code Performance Chatbot
Welcome to the performance analysis chatbot!
This is a tool for assisting developers in identifying performance bottlenecks in their code and optimizing them using LLMs.
Upload your code files and a performance profile (if available) to get started. Then ask away!

This interface is primarily for data collection and evaluation purposes. You will be presented with outputs from two different LLMs and asked to vote on which response you find more helpful.

---""")
    gr.Markdown("""## Upload Code Files and Performance Profile
You must upload at least one source code file to proceed. You can also upload a performance profile if you have one.
Currently supported formats are HPCToolkit, CProfile, and Caliper. CProfile and Caliper files can be uploaded directly.
HPCToolkit database directories should be zipped before uploading (e.g. `hpctoolkit-database.zip`).""")
    with gr.Row():
        code_files = gr.File(label="Upload Code File", file_count='multiple')
        with gr.Column():
            profile_type = gr.Dropdown(['No Profile', 'HPCToolkit', 'CProfile', 'Caliper'], value='No Profile', multiselect=False, label="Select Profile Type")
            profile_file = gr.File(label="Upload Performance Profile")

    gr.Markdown("---")
    gr.Markdown("""## Ask a Question
Now you can ask a question about your code performance!
Once you receive two responses, vote on which one you found more helpful.""")

    default_question = "Can you help me identify and fix performance bugs in this code?"
    prompt = gr.Textbox(label="Ask a question about your code performance", value=default_question)
    chat_button = gr.Button("Chat About Performance", interactive=False)

    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Accordion("Response 1", elem_classes=["response"]):
                response1 = gr.Markdown(label="Response 1", visible=True, elem_classes=["response-md"])
            source1 = gr.Markdown("", visible=False)
            full_prompt1 = gr.Textbox("", visible=False)
        with gr.Column():
            with gr.Accordion("Response 2", elem_classes=["response"]):
                response2 = gr.Markdown(label="Response 2", visible=True, elem_classes=["response-md"])
            source2 = gr.Markdown("", visible=False)
            full_prompt2 = gr.Textbox("", visible=False)

    # Toggle the chat button based on whether at least one code file has been uploaded.
    code_files.change(code_upload, inputs=[code_files], outputs=[chat_button])

    with gr.Row():
        vote1_button = gr.Button("Vote for Response 1", interactive=False)
        vote2_button = gr.Button("Vote for Response 2", interactive=False)
        tie_button = gr.Button("Vote for Tie", interactive=False)
        skip_button = gr.Button("Skip", interactive=False)
    download_btn = gr.DownloadButton("Download Log", interactive=False)

    # Each vote button passes its own label as the `vote` value to handle_vote.
    vote_btns = [vote1_button, vote2_button, tie_button, skip_button]
    for btn in vote_btns:
        btn.click(handle_vote, inputs=[prompt, btn, response1, source1, full_prompt1, response2, source2, full_prompt2], outputs=[source1, source2, download_btn, *vote_btns])

    # Run the two-model comparison when the chat button is clicked.
    chat_button.click(
        chat_with_llms,
        inputs=[prompt, code_files, profile_file, profile_type],
        outputs=[response1, source1, full_prompt1, response2, source2, full_prompt2, download_btn, *vote_btns]
    )

# Launch the Gradio interface
if __name__ == '__main__':
    interface.launch(share=False)