import gradio as gr
import spaces
import torch
import os
import sys
from pathlib import Path
from datetime import datetime
import json

# Make the bundled duckdb-nsql checkout importable.
sys.path.append('duckdb-nsql')

from eval.predict import cli as predict_cli, predict, console, get_manifest, DefaultLoader, PROMPT_FORMATTERS
from eval.evaluate import cli as evaluate_cli, evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json

# ZeroGPU warm-up tensor: its device is printed here at import time and again
# inside the @spaces.GPU-decorated function, where GPU access is actually granted.
zero = torch.Tensor([0]).cuda()
print(zero.device)


@spaces.GPU
def run_evaluation(model_name):
    print(zero.device)

    results = []

    if "OPENROUTER_API_KEY" not in os.environ:
        return "Error: OPENROUTER_API_KEY not found in environment variables."

    try:
        # Evaluation configuration.
        dataset_path = "eval/data/dev.json"
        table_meta_path = "eval/data/tables.json"
        output_dir = "output/"
        prompt_format = "duckdbinstgraniteshort"
        stop_tokens = [';']
        max_tokens = 30000
        temperature = 0.1
        num_beams = -1
        manifest_client = "openrouter"
        manifest_engine = model_name
        manifest_connection = "http://localhost:5000"
        overwrite_manifest = True
        parallel = False

        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Connect to the model backend through manifest.
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        results.append(f"Using model: {manifest_engine}")

        results.append("Loading metadata and data...")
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)
        data = data_formatter.load_data(dataset_path)

        # The prediction file name encodes the prompt format, model, dataset split, and date.
        date_today = datetime.now().strftime("%y-%m-%d")
        pred_filename = f"{prompt_format}_0docs_{manifest_engine.split('/')[-1]}_{Path(dataset_path).stem}_{date_today}.json"
        pred_path = Path(output_dir) / pred_filename
        results.append(f"Predictions will be saved to: {pred_path}")

        results.append("Starting prediction...")
        predict(
            dataset_path=dataset_path,
            table_meta_path=table_meta_path,
            output_dir=output_dir,
            prompt_format=prompt_format,
            stop_tokens=stop_tokens,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            manifest_client=manifest_client,
            manifest_engine=manifest_engine,
            manifest_connection=manifest_connection,
            overwrite_manifest=overwrite_manifest,
            parallel=parallel,
        )
        results.append("Prediction completed.")
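        # Note: the evaluation below assumes the prediction file is JSON Lines with at
        # least a "pred" field per record (inferred from how it is parsed here, not
        # from the eval package's documentation).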

        results.append("Starting evaluation...")

        gold_path = Path(dataset_path)
        db_dir = "eval/data/databases/"
        tables_path = Path(table_meta_path)

        # Build the foreign-key maps and schemas needed by the test-suite evaluator.
        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(line) for line in pred_path.open("r")]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        results.append("Evaluation completed.")

        if metrics:
            to_print = get_to_print({"all": metrics}, "all", model_name, len(gold_sqls))
            formatted_metrics = "\n".join(
                f"{k}: {v}" for k, v in to_print.items() if k not in ["slice", "model"]
            )
            results.append(f"Evaluation metrics:\n{formatted_metrics}")
        else:
            results.append("No evaluation metrics returned.")

    except Exception as e:
        results.append(f"An unexpected error occurred: {e}")

    return "\n\n".join(results)

# Gradio UI: a single textbox for the OpenRouter model name and a button that
# kicks off prediction and evaluation.
with gr.Blocks() as demo:
    gr.Markdown("# DuckDB SQL Evaluation App")

    model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
    start_btn = gr.Button("Start Evaluation")
    output = gr.Textbox(label="Output", lines=20)

    start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)

demo.launch()
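
# Local usage sketch (assumptions: this script is saved as app.py, the duckdb-nsql
# repo is checked out alongside it, and an OpenRouter key is available):
#
#   OPENROUTER_API_KEY=<your-key> python app.py
#
# Then open the printed Gradio URL and enter a model name such as
# "qwen/qwen-2.5-72b-instruct".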