cfahlgren1 committed
Commit 470a9a5 · 1 Parent(s): 6fdb323

update with examples and save to dataset

Files changed (2)
  1. app.py +13 -2
  2. evaluation_logic.py +58 -6
app.py CHANGED
@@ -7,20 +7,31 @@ def gradio_run_evaluation(inference_api, model_name, prompt_format):
         output.append(result)
         yield "\n".join(output)
 
-with gr.Blocks() as demo:
+with gr.Blocks(gr.themes.Soft()) as demo:
     gr.Markdown("# DuckDB SQL Evaluation App")
 
     inference_api = gr.Dropdown(
         label="Inference API",
-        choices=['openrouter', 'inference_api'],
+        choices=['openrouter'],
         value="openrouter"
     )
     model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
+    gr.Markdown("[View OpenRouter Models](https://openrouter.ai/models?order=top-weekly)")
+
     prompt_format = gr.Dropdown(
         label="Prompt Format",
         choices=['duckdbinst', 'duckdbinstgraniteshort'], #AVAILABLE_PROMPT_FORMATS,
         value="duckdbinstgraniteshort"
     )
+    gr.Examples(
+        examples=[
+            ["openrouter", "qwen/qwen-2.5-72b-instruct", "duckdbinst"],
+            ["openrouter", "meta-llama/llama-3.2-3b-instruct:free", "duckdbinstgraniteshort"],
+            ["openrouter", "mistralai/mistral-nemo", "duckdbinst"],
+        ],
+        inputs=[inference_api, model_name, prompt_format],
+    )
+
     start_btn = gr.Button("Start Evaluation")
     output = gr.Textbox(label="Output", lines=20)
 
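Assembled from the hunk above, the updated UI switches to the Soft theme, restricts the Inference API dropdown to OpenRouter, links to the OpenRouter model list, and pre-fills the three inputs via `gr.Examples`. The sketch below is a minimal runnable approximation of the new app.py: the body of `gradio_run_evaluation` and the `start_btn.click` wiring are not part of this hunk and are assumed here, based on the append/yield pattern visible in the hunk context.

```python
import gradio as gr

def gradio_run_evaluation(inference_api, model_name, prompt_format):
    # Stand-in for the real evaluation loop; only the append/yield pattern
    # is taken from the hunk context above.
    output = []
    for result in ["Starting evaluation...", f"{model_name} via {inference_api} ({prompt_format})"]:
        output.append(result)
        yield "\n".join(output)

with gr.Blocks(gr.themes.Soft()) as demo:
    gr.Markdown("# DuckDB SQL Evaluation App")

    inference_api = gr.Dropdown(label="Inference API", choices=['openrouter'], value="openrouter")
    model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
    gr.Markdown("[View OpenRouter Models](https://openrouter.ai/models?order=top-weekly)")

    prompt_format = gr.Dropdown(
        label="Prompt Format",
        choices=['duckdbinst', 'duckdbinstgraniteshort'],
        value="duckdbinstgraniteshort",
    )

    # Clicking an example row fills all three inputs at once.
    gr.Examples(
        examples=[
            ["openrouter", "qwen/qwen-2.5-72b-instruct", "duckdbinst"],
            ["openrouter", "meta-llama/llama-3.2-3b-instruct:free", "duckdbinstgraniteshort"],
            ["openrouter", "mistralai/mistral-nemo", "duckdbinst"],
        ],
        inputs=[inference_api, model_name, prompt_format],
    )

    start_btn = gr.Button("Start Evaluation")
    output = gr.Textbox(label="Output", lines=20)

    # Assumed wiring (not shown in this hunk): each yield from the generator
    # streams into the output textbox.
    start_btn.click(
        gradio_run_evaluation,
        inputs=[inference_api, model_name, prompt_format],
        outputs=output,
    )

demo.launch()
```
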
evaluation_logic.py CHANGED
@@ -4,14 +4,14 @@ from pathlib import Path
 from datetime import datetime
 import json
 import traceback
+import uuid
+from huggingface_hub import CommitScheduler
 
-# Add the necessary directories to the Python path
 current_dir = Path(__file__).resolve().parent
 duckdb_nsql_dir = current_dir / 'duckdb-nsql'
 eval_dir = duckdb_nsql_dir / 'eval'
 sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])
 
-# Import necessary functions and classes
 from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
 from eval.evaluate import evaluate, compute_metrics, get_to_print
 from eval.evaluate import test_suite_evaluation, read_tables_json
@@ -19,6 +19,54 @@ from eval.schema import TextToSQLParams, Table
 
 AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
 
+prediction_folder = Path("prediction_results/")
+evaluation_folder = Path("evaluation_results/")
+
+file_uuid = uuid.uuid4()
+
+prediction_scheduler = CommitScheduler(
+    repo_id="sql-console/duckdb-nsql-predictions",
+    repo_type="dataset",
+    folder_path=prediction_folder,
+    path_in_repo="data",
+    every=10,
+)
+
+evaluation_scheduler = CommitScheduler(
+    repo_id="sql-console/duckdb-nsql-scores",
+    repo_type="dataset",
+    folder_path=evaluation_folder,
+    path_in_repo="data",
+    every=10,
+)
+
+def save_prediction(inference_api, model_name, prompt_format, question, generated_sql):
+    prediction_file = prediction_folder / f"prediction_{file_uuid}.json"
+    prediction_folder.mkdir(parents=True, exist_ok=True)
+    with prediction_scheduler.lock:
+        with prediction_file.open("a") as f:
+            json.dump({
+                "inference_api": inference_api,
+                "model_name": model_name,
+                "prompt_format": prompt_format,
+                "question": question,
+                "generated_sql": generated_sql,
+                "timestamp": datetime.now().isoformat()
+            }, f)
+
+def save_evaluation(inference_api, model_name, prompt_format, metrics):
+    evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"
+    evaluation_folder.mkdir(parents=True, exist_ok=True)
+    with evaluation_scheduler.lock:
+        with evaluation_file.open("a") as f:
+            json.dump({
+                "inference_api": inference_api,
+                "model_name": model_name,
+                "prompt_format": prompt_format,
+                "metrics": metrics,
+                "timestamp": datetime.now().isoformat()
+            }, f)
+
 def run_prediction(inference_api, model_name, prompt_format, output_file):
     dataset_path = str(eval_dir / "data/dev.json")
     table_meta_path = str(eval_dir / "data/tables.json")
@@ -60,9 +108,6 @@ def run_prediction(inference_api, model_name, prompt_format, output_file):
         else:
             table_params = []
 
-        #if len(table_params) == 0:
-            #yield f"[red] WARNING: No tables found for {db_id} [/red]"
-
         text_to_sql_inputs.append(TextToSQLParams(
             instruction=question,
             database=db_id,
@@ -73,7 +118,7 @@ def run_prediction(inference_api, model_name, prompt_format, output_file):
         generated_sqls = generate_sql(
             manifest=manifest,
             text_to_sql_in=text_to_sql_inputs,
-            retrieved_docs=[[] for _ in text_to_sql_inputs],  # Assuming no retrieved docs
+            retrieved_docs=[[] for _ in text_to_sql_inputs],
             prompt_formatter=prompt_formatter,
             stop_tokens=stop_tokens,
             overwrite_manifest=overwrite_manifest,
@@ -84,12 +129,16 @@ def run_prediction(inference_api, model_name, prompt_format, output_file):
         )
 
         # Save results
+        output_file.parent.mkdir(parents=True, exist_ok=True)
         with output_file.open('w') as f:
             for original_data, (sql, _) in zip(data, generated_sqls):
                 output = {**original_data, "pred": sql}
                 json.dump(output, f)
                 f.write('\n')
 
+                # Save prediction to dataset
+                save_prediction(inference_api, model_name, prompt_format, original_data["question"], sql)
+
         yield f"Prediction completed. Results saved to {output_file}"
     except Exception as e:
         yield f"Prediction failed with error: {str(e)}"
@@ -161,6 +210,9 @@ def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgranitesh
             categories=categories,
         )
 
+        # Save evaluation results to dataset
+        save_evaluation(inference_api, model_name, prompt_format, metrics)
+
         yield "Evaluation completed."
 
         if metrics:
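
The persistence added in this file follows the standard `huggingface_hub.CommitScheduler` pattern: records are appended to a watched local folder, and a background thread commits that folder to the dataset repo on a fixed interval (`every` is in minutes; pushing requires a write token such as `HF_TOKEN`). Below is a minimal sketch of the same pattern reduced to one scheduler. The repo id is the one from the diff, the trailing newline per record is a choice of this sketch to keep the file line-delimited, and the appended record is illustrative only.

```python
import json
import uuid
from datetime import datetime
from pathlib import Path

from huggingface_hub import CommitScheduler

results_folder = Path("prediction_results/")  # local staging folder watched by the scheduler
results_file = results_folder / f"prediction_{uuid.uuid4()}.json"

# Commits the folder's contents to the dataset repo every 10 minutes in a background thread.
scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-predictions",
    repo_type="dataset",
    folder_path=results_folder,
    path_in_repo="data",
    every=10,
)

def append_record(record: dict) -> None:
    """Append one JSON record; the lock keeps writes from racing an in-flight commit."""
    results_folder.mkdir(parents=True, exist_ok=True)
    with scheduler.lock:
        with results_file.open("a") as f:
            json.dump({**record, "timestamp": datetime.now().isoformat()}, f)
            f.write("\n")  # one record per line (sketch's choice)

# Illustrative record only (not output from a real run).
append_record({
    "inference_api": "openrouter",
    "model_name": "qwen/qwen-2.5-72b-instruct",
    "prompt_format": "duckdbinst",
    "question": "how many rows are in the taxi table?",
    "generated_sql": "SELECT COUNT(*) FROM taxi;",
})
```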