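"""Run text-to-SQL prediction and evaluation on the duckdb-nsql benchmark.

Generates SQL predictions through a manifest client, writes them as JSON
lines, and reports execution and exact-match metrics overall and per category.
"""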
import os
import sys
from pathlib import Path
from datetime import datetime
import json
import traceback

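# Make the local duckdb-nsql checkout (and its eval/ directory) importable.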
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])

from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table

AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())


def run_prediction(inference_api, model_name, prompt_format, output_file):
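    # Generation settings and manifest connection details.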
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = inference_api
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

yield "Starting prediction..." |
|
|
|
try: |
|
|
|
data_formatter = DefaultLoader() |
|
prompt_formatter = PROMPT_FORMATTERS[prompt_format]() |
|
|
|
|
|
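        # Set up the manifest connection for the chosen inference API and model.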
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

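        # Load the benchmark questions and per-database table metadata.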
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

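        # Build one TextToSQLParams per question, attaching the tables of its database when known.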
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []

            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

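        # Generate SQL for all inputs; no retrieved documents are supplied.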
        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel,
        )

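        # Write predictions as JSON lines, one record per input question.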
        with output_file.open('w') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"


def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort"):
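    # Fail fast if the required API credentials are missing.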
if "OPENROUTER_API_KEY" not in os.environ: |
|
yield "Error: OPENROUTER_API_KEY not found in environment variables." |
|
return |
|
if "HF_TOKEN" not in os.environ: |
|
yield "Error: HF_TOKEN not found in environment variables." |
|
return |
|
|
|
    try:
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

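        # Name the prediction file after the prompt format, model, and run date.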
        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"
        output_dir.mkdir(parents=True, exist_ok=True)

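        # Reuse an existing prediction file; otherwise run the prediction step first.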
        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            for output in run_prediction(inference_api, model_name, prompt_format, output_file):
                yield output

yield "Starting evaluation..." |
|
|
|
|
|
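        # Gold queries, database directory, and schema metadata used for evaluation.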
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

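        # Foreign-key maps and schemas are derived from tables.json.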
        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

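        # Gold records are a JSON array; predictions are JSON lines.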
        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in output_file.open("r").readlines()]

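        # Parallel lists expected by compute_metrics.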
        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

yield "Computing metrics..." |
|
metrics = compute_metrics( |
|
gold_sqls=gold_sqls, |
|
pred_sqls=pred_sqls, |
|
gold_dbs=gold_dbs, |
|
setup_sqls=setup_sqls, |
|
validate_sqls=validate_sqls, |
|
kmaps=kmaps, |
|
db_schemas=db_schemas, |
|
database_dir=db_dir, |
|
lowercase_schema_match=False, |
|
model_name=model_name, |
|
categories=categories, |
|
) |
|
|
|
yield "Evaluation completed." |
|
|
|
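        # Report overall metrics, then a per-category breakdown.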
        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"Count: {overall_metrics['count']}"
            yield f"Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"Exact Match Accuracy: {overall_metrics['exact']:.3f}"
            yield f"Equality: {metrics['equality']['equality']:.3f}"
            yield f"Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

            yield "\nResults by Category:"
            categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']

            for category in categories:
                if category in metrics['exec']:
                    yield f"\n{category}:"
                    category_metrics = metrics['exec'][category]
                    yield f"Count: {category_metrics['count']}"
                    yield f"Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"\n{category}: No data available"
        else:
            yield "No evaluation metrics returned."
    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"


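# Interactive entry point: prompts for the run configuration and streams progress messages.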
if __name__ == "__main__": |
|
model_name = input("Enter the model name: ") |
|
prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort" |
|
for result in run_evaluation(model_name, prompt_format): |
|
print(result, flush=True) |