tdoehmen commited on
Commit
edbe15e
·
1 Parent(s): ee5875c

added prompt template and openai api key

Browse files
app.py CHANGED
@@ -1,40 +1,120 @@
1
  import gradio as gr
2
- from evaluation_logic import run_evaluation, AVAILABLE_PROMPT_FORMATS
 
 
3
 
4
- def gradio_run_evaluation(inference_api, model_name, prompt_format):
 
 
 
 
 
 
 
 
 
 
5
  output = []
6
- for result in run_evaluation(inference_api, str(model_name).strip(), prompt_format):
7
  output.append(result)
8
  yield "\n".join(output)
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  with gr.Blocks(gr.themes.Soft()) as demo:
11
  gr.Markdown("# DuckDB SQL Evaluation App")
12
 
13
- inference_api = gr.Dropdown(
14
- label="Inference API",
15
- choices=['openrouter'],
16
- value="openrouter"
17
- )
18
- model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
19
- gr.Markdown("[View OpenRouter Models](https://openrouter.ai/models?order=top-weekly)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- prompt_format = gr.Dropdown(
22
- label="Prompt Format",
23
- choices=['duckdbinst', 'duckdbinstgraniteshort'], #AVAILABLE_PROMPT_FORMATS,
24
- value="duckdbinstgraniteshort"
25
- )
26
  gr.Examples(
27
  examples=[
28
- ["openrouter", "qwen/qwen-2.5-72b-instruct", "duckdbinst"],
29
- ["openrouter", "meta-llama/llama-3.2-3b-instruct:free", "duckdbinstgraniteshort"],
30
- ["openrouter", "mistralai/mistral-nemo", "duckdbinst"],
31
  ],
32
- inputs=[inference_api, model_name, prompt_format],
33
  )
34
 
35
  start_btn = gr.Button("Start Evaluation")
36
  output = gr.Textbox(label="Output", lines=20)
37
 
38
- start_btn.click(fn=gradio_run_evaluation, inputs=[inference_api, model_name, prompt_format], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  demo.queue().launch()
 
1
  import gradio as gr
2
+ import os
3
+ from evaluation_logic import run_evaluation
4
+ from eval.predict import PROMPT_FORMATTERS
5
 
6
+ PROMPT_TEMPLATES = {
7
+ "duckdbinstgraniteshort": PROMPT_FORMATTERS["duckdbinstgraniteshort"]().PROMPT_TEMPLATE,
8
+ "duckdbinst": PROMPT_FORMATTERS["duckdbinst"]().PROMPT_TEMPLATE,
9
+ }
10
+
11
+ def gradio_run_evaluation(inference_api, model_name, prompt_format, openrouter_token=None, custom_prompt=None):
12
+ # Set environment variable if OpenRouter token is provided
13
+ if inference_api == "openrouter":
14
+ os.environ["OPENROUTER_API_KEY"] = str(openrouter_token)
15
+
16
+ # We now pass both the format name and content to evaluation
17
  output = []
18
+ for result in run_evaluation(inference_api, str(model_name).strip(), prompt_format, custom_prompt):
19
  output.append(result)
20
  yield "\n".join(output)
21
 
22
+ def update_token_visibility(api):
23
+ """Update visibility of the OpenRouter token input"""
24
+ return gr.update(visible=api == "openrouter")
25
+
26
+ def update_prompt_template(prompt_format):
27
+ """Update the template content when a preset is selected"""
28
+ if prompt_format in PROMPT_TEMPLATES:
29
+ return PROMPT_FORMATTERS[prompt_format]()
30
+ return ""
31
+
32
+ def handle_template_edit(prompt_format, new_template):
33
+ """Handle when user edits the template"""
34
+ # If the template matches a preset exactly, keep the preset name
35
+ for format_name, template in PROMPT_TEMPLATES.items():
36
+ if template.strip() == new_template.strip():
37
+ return format_name
38
+ # Otherwise switch to custom
39
+ return "custom"
40
+
41
  with gr.Blocks(gr.themes.Soft()) as demo:
42
  gr.Markdown("# DuckDB SQL Evaluation App")
43
 
44
+ with gr.Row():
45
+ with gr.Column():
46
+ inference_api = gr.Dropdown(
47
+ label="Inference API",
48
+ choices=['openrouter'],
49
+ value="openrouter"
50
+ )
51
+
52
+ openrouter_token = gr.Textbox(
53
+ label="OpenRouter API Token",
54
+ placeholder="Enter your OpenRouter API token",
55
+ type="password",
56
+ visible=True
57
+ )
58
+
59
+ model_name = gr.Textbox(
60
+ label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)"
61
+ )
62
+
63
+ gr.Markdown("[View OpenRouter Models](https://openrouter.ai/models?order=top-weekly)")
64
+
65
+ with gr.Row():
66
+ with gr.Column():
67
+ # Add 'custom' to the choices
68
+ prompt_format = gr.Dropdown(
69
+ label="Prompt Format",
70
+ choices=['duckdbinst', 'duckdbinstgraniteshort', 'custom'],
71
+ value="duckdbinstgraniteshort"
72
+ )
73
+
74
+ custom_prompt = gr.TextArea(
75
+ label="Prompt Template Content",
76
+ placeholder="Enter your custom prompt template here or select a preset format above.",
77
+ lines=10,
78
+ value=PROMPT_TEMPLATES['duckdbinstgraniteshort'] # Set initial value
79
+ )
80
 
 
 
 
 
 
81
  gr.Examples(
82
  examples=[
83
+ ["openrouter", "qwen/qwen-2.5-72b-instruct", "duckdbinst", "", PROMPT_TEMPLATES['duckdbinst']],
84
+ ["openrouter", "meta-llama/llama-3.2-3b-instruct:free", "duckdbinstgraniteshort", "", PROMPT_TEMPLATES['duckdbinstgraniteshort']],
85
+ ["openrouter", "mistralai/mistral-nemo", "duckdbinst", "", PROMPT_TEMPLATES['duckdbinst']],
86
  ],
87
+ inputs=[inference_api, model_name, prompt_format, openrouter_token, custom_prompt],
88
  )
89
 
90
  start_btn = gr.Button("Start Evaluation")
91
  output = gr.Textbox(label="Output", lines=20)
92
 
93
+ # Update token visibility
94
+ inference_api.change(
95
+ fn=update_token_visibility,
96
+ inputs=[inference_api],
97
+ outputs=[openrouter_token]
98
+ )
99
+
100
+ # Update template content when preset is selected
101
+ prompt_format.change(
102
+ fn=update_prompt_template,
103
+ inputs=[prompt_format],
104
+ outputs=[custom_prompt]
105
+ )
106
+
107
+ # Update format dropdown when template is edited
108
+ custom_prompt.change(
109
+ fn=handle_template_edit,
110
+ inputs=[prompt_format, custom_prompt],
111
+ outputs=[prompt_format]
112
+ )
113
+
114
+ start_btn.click(
115
+ fn=gradio_run_evaluation,
116
+ inputs=[inference_api, model_name, prompt_format, openrouter_token, custom_prompt],
117
+ outputs=output
118
+ )
119
 
120
  demo.queue().launch()
duckdb-nsql/eval/constants.py CHANGED
@@ -16,6 +16,7 @@ from prompt_formatters import (
16
  DuckDBInstFormatterGPTmini,
17
  DuckDBInstFormatterPhiAzure,
18
  DuckDBInstFormatterLlamaSyntax,
 
19
  )
20
 
21
  PROMPT_FORMATTERS = {
@@ -33,5 +34,6 @@ PROMPT_FORMATTERS = {
33
  "duckdbinstgptmini": DuckDBInstFormatterPhi,
34
  "duckdbinstphiazure": DuckDBInstFormatterPhiAzure,
35
  "duckdbinstllamabasic": DuckDBInstFormatterLlamaBasic,
36
- "duckdbinstllamasyntax": DuckDBInstFormatterLlamaSyntax
 
37
  }
 
16
  DuckDBInstFormatterGPTmini,
17
  DuckDBInstFormatterPhiAzure,
18
  DuckDBInstFormatterLlamaSyntax,
19
+ DuckDBInstFormatterCustom,
20
  )
21
 
22
  PROMPT_FORMATTERS = {
 
34
  "duckdbinstgptmini": DuckDBInstFormatterPhi,
35
  "duckdbinstphiazure": DuckDBInstFormatterPhiAzure,
36
  "duckdbinstllamabasic": DuckDBInstFormatterLlamaBasic,
37
+ "duckdbinstllamasyntax": DuckDBInstFormatterLlamaSyntax,
38
+ "custom": DuckDBInstFormatterCustom
39
  }
duckdb-nsql/eval/prompt_formatters.py CHANGED
@@ -958,6 +958,35 @@ Write a DuckDB SQL query for the given question!
958
  )
959
  return instruction
960
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
961
  class DuckDBInstNoShorthandFormatter(DuckDBInstFormatter):
962
  """DuckDB Inst class."""
963
 
 
958
  )
959
  return instruction
960
 
961
+
962
+ class DuckDBInstFormatterCustom(RajkumarFormatter):
963
+ """DuckDB Inst class."""
964
+
965
+ PROMPT_TEMPLATE = ""
966
+
967
+ @classmethod
968
+ def format_retrieved_context(
969
+ cls,
970
+ context: list[str],
971
+ ) -> str:
972
+ """Format retrieved context."""
973
+ context_str = "\n--------\n".join(context)
974
+ return f"\n### Documentation:\n{context_str}\n"
975
+
976
+ @classmethod
977
+ def format_prompt(
978
+ cls,
979
+ instruction: str,
980
+ table_text: str,
981
+ context_text: str,
982
+ ) -> str | list[str]:
983
+ """Get prompt format."""
984
+ instruction = cls.PROMPT_TEMPLATE.format(
985
+ schema=table_text,
986
+ question=instruction
987
+ )
988
+ return instruction
989
+
990
  class DuckDBInstNoShorthandFormatter(DuckDBInstFormatter):
991
  """DuckDB Inst class."""
992
 
evaluation_logic.py CHANGED
@@ -54,7 +54,7 @@ def save_prediction(inference_api, model_name, prompt_format, question, generate
54
  "timestamp": datetime.now().isoformat()
55
  }, f)
56
 
57
- def save_evaluation(inference_api, model_name, prompt_format, metrics):
58
  evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"
59
  evaluation_folder.mkdir(parents=True, exist_ok=True)
60
 
@@ -64,6 +64,7 @@ def save_evaluation(inference_api, model_name, prompt_format, metrics):
64
  "inference_api": inference_api,
65
  "model_name": model_name,
66
  "prompt_format": prompt_format,
 
67
  "timestamp": datetime.now().isoformat()
68
  }
69
 
@@ -82,7 +83,7 @@ def save_evaluation(inference_api, model_name, prompt_format, metrics):
82
  json.dump(flattened_metrics, f)
83
  f.write('\n')
84
 
85
- def run_prediction(inference_api, model_name, prompt_format, output_file):
86
  dataset_path = str(eval_dir / "data/dev.json")
87
  table_meta_path = str(eval_dir / "data/tables.json")
88
  stop_tokens = [';']
@@ -100,7 +101,11 @@ def run_prediction(inference_api, model_name, prompt_format, output_file):
100
  try:
101
  # Initialize necessary components
102
  data_formatter = DefaultLoader()
103
- prompt_formatter = PROMPT_FORMATTERS[prompt_format]()
 
 
 
 
104
 
105
  # Load manifest
106
  manifest = get_manifest(
@@ -159,7 +164,7 @@ def run_prediction(inference_api, model_name, prompt_format, output_file):
159
  yield f"Prediction failed with error: {str(e)}"
160
  yield f"Error traceback: {traceback.format_exc()}"
161
 
162
- def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort"):
163
  if "OPENROUTER_API_KEY" not in os.environ:
164
  yield "Error: OPENROUTER_API_KEY not found in environment variables."
165
  return
@@ -176,6 +181,9 @@ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgranitesh
176
  yield f"Using model: {model_name}"
177
  yield f"Using prompt format: {prompt_format}"
178
 
 
 
 
179
  output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"
180
 
181
  # Ensure the output directory exists
@@ -186,7 +194,7 @@ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgranitesh
186
  yield "Skipping prediction step and proceeding to evaluation."
187
  else:
188
  # Run prediction
189
- for output in run_prediction(inference_api, model_name, prompt_format, output_file):
190
  yield output
191
 
192
  # Run evaluation
@@ -226,7 +234,7 @@ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgranitesh
226
  )
227
 
228
  # Save evaluation results to dataset
229
- save_evaluation(inference_api, model_name, prompt_format, metrics)
230
 
231
  yield "Evaluation completed."
232
 
 
54
  "timestamp": datetime.now().isoformat()
55
  }, f)
56
 
57
+ def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
58
  evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"
59
  evaluation_folder.mkdir(parents=True, exist_ok=True)
60
 
 
64
  "inference_api": inference_api,
65
  "model_name": model_name,
66
  "prompt_format": prompt_format,
67
+ "custom_prompt": str(custom_prompt),
68
  "timestamp": datetime.now().isoformat()
69
  }
70
 
 
83
  json.dump(flattened_metrics, f)
84
  f.write('\n')
85
 
86
+ def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
87
  dataset_path = str(eval_dir / "data/dev.json")
88
  table_meta_path = str(eval_dir / "data/tables.json")
89
  stop_tokens = [';']
 
101
  try:
102
  # Initialize necessary components
103
  data_formatter = DefaultLoader()
104
+ if prompt_format.startswith("custom"):
105
+ prompt_formatter = PROMPT_FORMATTERS["custom"]()
106
+ prompt_formatter.PROMPT_TEMPLATE = custom_prompt
107
+ else:
108
+ prompt_formatter = PROMPT_FORMATTERS[prompt_format]()
109
 
110
  # Load manifest
111
  manifest = get_manifest(
 
164
  yield f"Prediction failed with error: {str(e)}"
165
  yield f"Error traceback: {traceback.format_exc()}"
166
 
167
+ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
168
  if "OPENROUTER_API_KEY" not in os.environ:
169
  yield "Error: OPENROUTER_API_KEY not found in environment variables."
170
  return
 
181
  yield f"Using model: {model_name}"
182
  yield f"Using prompt format: {prompt_format}"
183
 
184
+ if prompt_format == "custom":
185
+ prompt_format = prompt_format+"_"+str(abs(hash(custom_prompt)) % (10 ** 8))
186
+
187
  output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"
188
 
189
  # Ensure the output directory exists
 
194
  yield "Skipping prediction step and proceeding to evaluation."
195
  else:
196
  # Run prediction
197
+ for output in run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
198
  yield output
199
 
200
  # Run evaluation
 
234
  )
235
 
236
  # Save evaluation results to dataset
237
+ save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)
238
 
239
  yield "Evaluation completed."
240