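"""Gradio app for evaluating text-to-SQL models on the DuckDB-NSQL benchmark.

Prediction and evaluation reuse the utilities bundled in the duckdb-nsql
checkout; model calls are routed through OpenRouter via the manifest client,
so OPENROUTER_API_KEY must be set in the environment (e.g. as a Space secret).
Progress messages and the final metrics are shown in the Gradio UI.
"""
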
import gradio as gr
import spaces
import torch
import os
import sys
from pathlib import Path
from datetime import datetime
import json

# Add the duckdb-nsql directory to the Python path
sys.path.append('duckdb-nsql')

# Import necessary functions and classes from predict.py and evaluate.py
from eval.predict import predict, get_manifest, DefaultLoader, PROMPT_FORMATTERS
from eval.evaluate import compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json

zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔
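# ZeroGPU note: tensors created at import time report 'cpu'; the GPU is only
# attached while a @spaces.GPU-decorated function is running, hence 'cuda:0'
# inside run_evaluation below.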

@spaces.GPU
def run_evaluation(model_name):
    print(zero.device)  # <-- 'cuda:0' 🤗

    results = []

    if "OPENROUTER_API_KEY" not in os.environ:
        return "Error: OPENROUTER_API_KEY not found in environment variables."

    try:
        # Set up the arguments similar to the CLI in predict.py
        dataset_path = "eval/data/dev.json"
        table_meta_path = "eval/data/tables.json"
        output_dir = "output/"
        prompt_format = "duckdbinstgraniteshort"
        stop_tokens = [';']
        max_tokens = 30000
        temperature = 0.1
        num_beams = -1
        manifest_client = "openrouter"
        manifest_engine = model_name
        manifest_connection = "http://localhost:5000"
        overwrite_manifest = True
        parallel = False

        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        results.append(f"Using model: {manifest_engine}")

        # Load data and metadata
        results.append("Loading metadata and data...")
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)
        data = data_formatter.load_data(dataset_path)

        # Generate output filename
        date_today = datetime.now().strftime("%y-%m-%d")
        pred_filename = f"{prompt_format}_0docs_{manifest_engine.split('/')[-1]}_{Path(dataset_path).stem}_{date_today}.json"
        pred_path = Path(output_dir) / pred_filename
        results.append(f"Prediction will be saved to: {pred_path}")

        # Run prediction
        results.append("Starting prediction...")
        predict(
            dataset_path=dataset_path,
            table_meta_path=table_meta_path,
            output_dir=output_dir,
            prompt_format=prompt_format,
            stop_tokens=stop_tokens,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            manifest_client=manifest_client,
            manifest_engine=manifest_engine,
            manifest_connection=manifest_connection,
            overwrite_manifest=overwrite_manifest,
            parallel=parallel
        )
        results.append("Prediction completed.")

        # Run evaluation
        results.append("Starting evaluation...")

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = "eval/data/databases/"
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        # The gold file is a single JSON array; predictions are written as
        # JSON lines (one object per line).
        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(line) for line in pred_path.open("r", encoding="utf-8")]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        results.append("Evaluation completed.")

        # Format and add the evaluation metrics to the results
        if metrics:
            to_print = get_to_print({"all": metrics}, "all", model_name, len(gold_sqls))
            formatted_metrics = "\n".join([f"{k}: {v}" for k, v in to_print.items() if k not in ["slice", "model"]])
            results.append(f"Evaluation metrics:\n{formatted_metrics}")
        else:
            results.append("No evaluation metrics returned.")

    except Exception as e:
        results.append(f"An unexpected error occurred: {str(e)}")

    return "\n\n".join(results)

with gr.Blocks() as demo:
    gr.Markdown("# DuckDB SQL Evaluation App")

    model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
    start_btn = gr.Button("Start Evaluation")
    output = gr.Textbox(label="Output", lines=20)

    start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)

demo.launch()
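
# Note: the evaluation assets referenced above (eval/data/dev.json,
# eval/data/tables.json, and the databases under eval/data/databases/) must be
# present relative to the working directory for a run to complete.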