import gradio as gr
import spaces
import torch
import os
import sys
from pathlib import Path
from datetime import datetime
import json
# Add the duckdb-nsql directory to the Python path
sys.path.append('duckdb-nsql')
# Import necessary functions and classes from predict.py and evaluate.py
from eval.predict import cli as predict_cli, predict, console, get_manifest, DefaultLoader, PROMPT_FORMATTERS
from eval.evaluate import cli as evaluate_cli, evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
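# ZeroGPU note (based on the device prints below): at import time the Space runs on CPU,
# so the tensor created here reports 'cpu'; inside a function decorated with @spaces.GPU
# the call is executed on an allocated GPU worker and the same tensor reports 'cuda:0'.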
zero = torch.Tensor([0]).cuda()
print(zero.device) # <-- 'cpu' 🤔
@spaces.GPU
def run_evaluation(model_name):
    print(zero.device)  # <-- 'cuda:0' 🤗
    results = []

    if "OPENROUTER_API_KEY" not in os.environ:
        return "Error: OPENROUTER_API_KEY not found in environment variables."

    try:
        # Set up the arguments similar to the CLI in predict.py
        dataset_path = "eval/data/dev.json"
        table_meta_path = "eval/data/tables.json"
        output_dir = "output/"
        prompt_format = "duckdbinstgraniteshort"
        stop_tokens = [';']
        max_tokens = 30000
        temperature = 0.1
        num_beams = -1
        manifest_client = "openrouter"
        manifest_engine = model_name
        manifest_connection = "http://localhost:5000"
        overwrite_manifest = True
        parallel = False

        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )
        results.append(f"Using model: {manifest_engine}")

        # Load data and metadata
        results.append("Loading metadata and data...")
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)
        data = data_formatter.load_data(dataset_path)

        # Generate output filename
        date_today = datetime.now().strftime("%y-%m-%d")
        pred_filename = f"{prompt_format}_0docs_{manifest_engine.split('/')[-1]}_{Path(dataset_path).stem}_{date_today}.json"
        pred_path = Path(output_dir) / pred_filename
        results.append(f"Prediction will be saved to: {pred_path}")

        # Run prediction
        results.append("Starting prediction...")
        predict(
            dataset_path=dataset_path,
            table_meta_path=table_meta_path,
            output_dir=output_dir,
            prompt_format=prompt_format,
            stop_tokens=stop_tokens,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            manifest_client=manifest_client,
            manifest_engine=manifest_engine,
            manifest_connection=manifest_connection,
            overwrite_manifest=overwrite_manifest,
            parallel=parallel,
        )
        results.append("Prediction completed.")

        # Run evaluation
        results.append("Starting evaluation...")

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = "eval/data/databases/"
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(line) for line in pred_path.open("r").readlines()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )
        results.append("Evaluation completed.")

        # Format and add the evaluation metrics to the results
        if metrics:
            to_print = get_to_print({"all": metrics}, "all", model_name, len(gold_sqls))
            formatted_metrics = "\n".join([f"{k}: {v}" for k, v in to_print.items() if k not in ["slice", "model"]])
            results.append(f"Evaluation metrics:\n{formatted_metrics}")
        else:
            results.append("No evaluation metrics returned.")

    except Exception as e:
        results.append(f"An unexpected error occurred: {str(e)}")

    return "\n\n".join(results)
with gr.Blocks() as demo:
gr.Markdown("# DuckDB SQL Evaluation App")
model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
start_btn = gr.Button("Start Evaluation")
output = gr.Textbox(label="Output", lines=20)
start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)
demo.launch() |