import gradio as gr
import spaces
import torch
import os
import sys
from pathlib import Path
from datetime import datetime
import json

# Add the duckdb-nsql directory to the Python path
sys.path.append('duckdb-nsql')

# Import necessary functions and classes from predict.py and evaluate.py
from eval.predict import cli as predict_cli, predict, console, get_manifest, DefaultLoader, PROMPT_FORMATTERS
from eval.evaluate import cli as evaluate_cli, evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json

zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔


@spaces.GPU
def run_evaluation(model_name):
    print(zero.device)  # <-- 'cuda:0' 🤗
    results = []

    if "OPENROUTER_API_KEY" not in os.environ:
        return "Error: OPENROUTER_API_KEY not found in environment variables."

    try:
        # Set up the arguments similar to the CLI in predict.py
        dataset_path = "eval/data/dev.json"
        table_meta_path = "eval/data/tables.json"
        output_dir = "output/"
        prompt_format = "duckdbinstgraniteshort"
        stop_tokens = [';']
        max_tokens = 30000
        temperature = 0.1
        num_beams = -1
        manifest_client = "openrouter"
        manifest_engine = model_name
        manifest_connection = "http://localhost:5000"
        overwrite_manifest = True
        parallel = False

        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )
        results.append(f"Using model: {manifest_engine}")

        # Load data and metadata
        results.append("Loading metadata and data...")
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)
        data = data_formatter.load_data(dataset_path)

        # Generate output filename
        date_today = datetime.now().strftime("%y-%m-%d")
        pred_filename = f"{prompt_format}_0docs_{manifest_engine.split('/')[-1]}_{Path(dataset_path).stem}_{date_today}.json"
        pred_path = Path(output_dir) / pred_filename
        results.append(f"Prediction will be saved to: {pred_path}")

        # Run prediction
        results.append("Starting prediction...")
        predict(
            dataset_path=dataset_path,
            table_meta_path=table_meta_path,
            output_dir=output_dir,
            prompt_format=prompt_format,
            stop_tokens=stop_tokens,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            manifest_client=manifest_client,
            manifest_engine=manifest_engine,
            manifest_connection=manifest_connection,
            overwrite_manifest=overwrite_manifest,
            parallel=parallel
        )
        results.append("Prediction completed.")

        # Run evaluation
        results.append("Starting evaluation...")

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = "eval/data/databases/"
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in pred_path.open("r").readlines()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )
        results.append("Evaluation completed.")

        # Format and add the evaluation metrics to the results
        if metrics:
            to_print = get_to_print({"all": metrics}, "all", model_name, len(gold_sqls))
            formatted_metrics = "\n".join(
                [f"{k}: {v}" for k, v in to_print.items() if k not in ["slice", "model"]]
            )
            results.append(f"Evaluation metrics:\n{formatted_metrics}")
        else:
            results.append("No evaluation metrics returned.")

    except Exception as e:
        results.append(f"An unexpected error occurred: {str(e)}")

    return "\n\n".join(results)


with gr.Blocks() as demo:
    gr.Markdown("# DuckDB SQL Evaluation App")

    model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
    start_btn = gr.Button("Start Evaluation")
    output = gr.Textbox(label="Output", lines=20)

    start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)

demo.launch()