|
"""Run dataset on text2sql zazu experiment. |
|
|
|
See README.md for more details. |
|
""" |
|
import datetime |
|
import json |
|
import multiprocessing |
|
import random |
|
import re |
|
from pathlib import Path |
|
|
|
import click |
|
import numpy as np |
|
from constants import PROMPT_FORMATTERS |
|
from loaders import DefaultLoader |
|
from get_manifest import get_manifest |
|
from manifest import Manifest |
|
from prompt_formatters import RajkumarFormatter |
|
from rich.console import Console |
|
from schema import Table, TextToSQLModelResponse, TextToSQLParams |
|
from text_to_sql import instruction_to_sql, instruction_to_sql_list |
|
from doc_retriever import ( |
|
load_documentation, |
|
split_documents, |
|
embed_documents, |
|
query_docs, |
|
) |
|
from tqdm import tqdm |
|
from transformers import AutoTokenizer |
|
|
|
console = Console(soft_wrap=True) |
|
|
|
|
|
def generate_sql( |
|
manifest: Manifest, |
|
text_to_sql_in: list[TextToSQLParams], |
|
retrieved_docs: list[list[str]], |
|
prompt_formatter: RajkumarFormatter, |
|
stop_tokens: list[str] | None = None, |
|
overwrite_manifest: bool = False, |
|
max_tokens: int = 300, |
|
temperature: float = 0.1, |
|
num_beams: int = 2, |
|
parallel: bool = False, |
|
) -> list[tuple[str, TextToSQLModelResponse]]: |
|
"""Call our text2sql function with manifest of our choice.""" |
|
if parallel: |
|
instruction_to_sql_resps: list[ |
|
TextToSQLModelResponse |
|
] = instruction_to_sql_list( |
|
params=text_to_sql_in, |
|
extra_context=retrieved_docs, |
|
manifest=manifest, |
|
prompt_formatter=prompt_formatter, |
|
overwrite_manifest=overwrite_manifest, |
|
max_tokens=max_tokens, |
|
temperature=0.1, |
|
stop_sequences=stop_tokens, |
|
num_beams=num_beams, |
|
) |
|
else: |
|
instruction_to_sql_resps = [ |
|
instruction_to_sql( |
|
params=_text_to_sql_in, |
|
extra_context=_retrieved_docs, |
|
manifest=manifest, |
|
prompt_formatter=prompt_formatter, |
|
overwrite_manifest=overwrite_manifest, |
|
max_tokens=max_tokens, |
|
temperature=temperature, |
|
stop_sequences=stop_tokens, |
|
num_beams=num_beams, |
|
) |
|
for _retrieved_docs, _text_to_sql_in in tqdm( |
|
zip(retrieved_docs, text_to_sql_in), |
|
desc="Generating SQL", |
|
total=len(text_to_sql_in), |
|
disable=(len(text_to_sql_in) <= 1), |
|
) |
|
] |
|
assert len(instruction_to_sql_resps) == len(text_to_sql_in) |
|
|
|
sql_statements = [] |
|
for i in range(len(instruction_to_sql_resps)): |
|
sql_statement = instruction_to_sql_resps[i].output.strip() |
|
if "<>" in sql_statement: |
|
sql_statement.replace("<>", "!=") |
|
|
|
sql_statement = sql_statement.split("|")[-1].strip() |
|
sql_statements.append(sql_statement) |
|
return list(zip(sql_statements, instruction_to_sql_resps)) |
|
|
|
|
|
def get_text_to_sql_in( |
|
input_question: dict, db_to_tables: dict[str, dict[str, Table]] |
|
) -> TextToSQLParams: |
|
"""Format input question for text2sql function.""" |
|
question = input_question["question"] |
|
db_id = input_question.get("db_id", None) |
|
if db_id != "none": |
|
table_params = list(db_to_tables.get(db_id, {}).values()) |
|
else: |
|
table_params = [] |
|
if len(table_params) == 0: |
|
console.print(f"[red] WARNING: No tables found for {db_id} [/red]") |
|
text_to_sql_in = TextToSQLParams( |
|
instruction=question, |
|
database=db_id, |
|
tables=table_params, |
|
) |
|
return text_to_sql_in |
|
|
|
|
|
@click.group() |
|
def cli() -> None: |
|
"""Entrypoint.""" |
|
pass |
|
|
|
|
|
@cli.command() |
|
@click.argument("dataset-path") |
|
@click.argument("table-meta-path") |
|
@click.option("--output-dir", type=str, default="") |
|
@click.option("--run-name", type=str, default="") |
|
@click.option("--num-run", type=int, default=-1) |
|
@click.option("--num-print", type=int, default=20) |
|
|
|
@click.option("--prompt-format", type=str, default="spider") |
|
|
|
@click.option("--stop-tokens", type=str, default=[], multiple=True) |
|
@click.option("--max-tokens", type=int, default=200) |
|
@click.option("--temperature", type=float, default=0) |
|
@click.option("--num-beams", type=int, default=-1) |
|
@click.option("--max-context-length", type=int, default=-1) |
|
|
|
@click.option( |
|
"--markdown-docs-path", |
|
|
|
|
|
|
|
default="eval/docs/duckdb-web/docs/archive/0.9.2/sql", |
|
) |
|
@click.option("--num-retrieved-docs", type=int, default=0) |
|
|
|
@click.option("--manifest-client", type=str, default="openai") |
|
@click.option("--manifest-engine", type=str, default="gpt-4o") |
|
@click.option("--manifest-connection", type=str, default="http://localhost:5005") |
|
@click.option("--overwrite-manifest", is_flag=True, default=False) |
|
@click.option("--parallel", is_flag=True, default=False) |
|
def predict( |
|
dataset_path: str, |
|
table_meta_path: str, |
|
output_dir: str, |
|
run_name: str, |
|
num_run: int, |
|
num_print: int, |
|
prompt_format: str, |
|
stop_tokens: list[str], |
|
max_tokens: int, |
|
temperature: float, |
|
num_beams: int, |
|
max_context_length: int, |
|
markdown_docs_path: Path, |
|
num_retrieved_docs: int, |
|
manifest_client: str, |
|
manifest_engine: str, |
|
manifest_connection: str, |
|
overwrite_manifest: bool, |
|
parallel: bool, |
|
) -> None: |
|
"""Predict SQL. |
|
|
|
Args: |
|
dataset_path: the dataset path. |
|
table_meta_path: the json path of the table metadata. |
|
database_path: the database path for sqlite. |
|
output_dir: the prediction output directory |
|
run_name: special prefix to add to filename |
|
num_run: the number of examples to run |
|
num_print: the number of examples to print |
|
prompt_format: the format of the prompt. E.g., "rajkumar" |
|
stop_tokens: the stop tokens to try |
|
max_tokens: the max tokens |
|
temperature: the temperature |
|
num_beams: the number of beams |
|
max_context_length: max context length for demonstration truncation (-1 means None) |
|
markdown_docs_path: path to duckdb sql docs |
|
num_retrieved_docs: number of docs to retrieve |
|
manifest_client: the manifest client |
|
manifest_engine: the manifest engine |
|
manifest_connection: the manifest connection |
|
""" |
|
multiprocessing.set_start_method("spawn", force=True) |
|
random.seed(0) |
|
np.random.seed(0) |
|
locals_dict = locals() |
|
locals_dict["markdown_docs_path"] = str(markdown_docs_path) |
|
console.print(json.dumps(locals_dict, indent=2)) |
|
|
|
data_formatter = DefaultLoader() |
|
|
|
if prompt_format not in PROMPT_FORMATTERS: |
|
raise ValueError(f"Unknown prompt format {prompt_format}") |
|
prompt_formatter = PROMPT_FORMATTERS[prompt_format]() |
|
|
|
|
|
manifest = get_manifest( |
|
manifest_client=manifest_client, |
|
manifest_connection=manifest_connection, |
|
manifest_engine=manifest_engine, |
|
) |
|
manifest_params = manifest.client_pool.get_current_client().get_model_params() |
|
console.print(f"Running with {manifest_params} manifest.") |
|
model_name = manifest_params.get("engine", manifest_params["model_name"]) |
|
|
|
if manifest_client in {"openai", "openaichat", "openrouter", "azureendpoint"}: |
|
tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True) |
|
else: |
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) |
|
|
|
if stop_tokens: |
|
stop_tokens = [st.strip("'") for st in stop_tokens] |
|
console.print(f"Stop tokens: {stop_tokens}") |
|
|
|
|
|
full_dataset_path = Path(dataset_path) |
|
|
|
date_today = datetime.datetime.now().strftime("%y-%m-%d") |
|
if run_name: |
|
run_name = f"{run_name}_" |
|
suffix = f"{run_name}{full_dataset_path.stem}_{date_today}.json" |
|
prefix = f"{prompt_format}_{num_retrieved_docs}docs" |
|
if manifest_client in {"openai", "openaiazure"}: |
|
middleix = manifest_engine |
|
elif manifest_client in {"huggingface", "ray"}: |
|
middleix = Path(manifest_params.get("model_path", "")).name.replace("/", "-") |
|
elif manifest_client in {"toma", "openrouter", "openaichat", "azureendpoint"}: |
|
middleix = manifest_engine.split("/")[-1] |
|
else: |
|
raise ValueError(f"Unknown manifest client {manifest_client}") |
|
output_filename = f"{prefix}_{middleix}_{suffix}" |
|
console.print(f"Saving to {Path(output_dir) / output_filename}") |
|
Path(output_dir).mkdir(parents=True, exist_ok=True) |
|
|
|
console.print("Loading metadata...") |
|
db_to_tables = data_formatter.load_table_metadata(table_meta_path) |
|
|
|
console.print("Loading data...") |
|
data = data_formatter.load_data(dataset_path) |
|
if num_run > 0: |
|
console.print(f"Running on {min(len(data), num_run)} examples") |
|
data = data[:num_run] |
|
original_data = data |
|
|
|
|
|
console.print("Formatting data...") |
|
num_print = min(num_print, len(data)) |
|
token_lengths = [] |
|
text_to_sql_in = [ |
|
get_text_to_sql_in(input_question, db_to_tables) for input_question in data |
|
] |
|
|
|
if num_retrieved_docs > 0: |
|
console.print("Loading documenration and indexing...") |
|
retrieved_docs = [] |
|
doc_contents = load_documentation(markdown_docs_path) |
|
chunked_docs = split_documents(doc_contents) |
|
embedded_docs, full_embedding_mat = embed_documents(chunked_docs) |
|
for i in tqdm(range(len(text_to_sql_in)), desc="Retrieving docs"): |
|
_, retrieved_docs_strings = query_docs( |
|
text_to_sql_in[i].instruction, |
|
embedded_docs, |
|
full_embedding_mat, |
|
top_n=num_retrieved_docs, |
|
) |
|
retrieved_docs.append(retrieved_docs_strings) |
|
else: |
|
retrieved_docs = [[] for _ in range(len(text_to_sql_in))] |
|
|
|
for i in range(num_print): |
|
|
|
generated_responses = generate_sql( |
|
manifest=manifest, |
|
text_to_sql_in=[text_to_sql_in[i]], |
|
retrieved_docs=[retrieved_docs[i]], |
|
stop_tokens=stop_tokens, |
|
max_tokens=max_tokens, |
|
temperature=temperature, |
|
num_beams=num_beams, |
|
prompt_formatter=prompt_formatter, |
|
overwrite_manifest=overwrite_manifest, |
|
parallel=parallel, |
|
) |
|
for prediction, model_response in generated_responses: |
|
prediction = re.sub(r"[\s\t\n]+", " ", prediction) |
|
token_lengths.append(len(tokenizer(prediction).input_ids)) |
|
console.print(f"[blue]Prompt:[/blue] {model_response.final_prompt}") |
|
console.print(f"[red]Prediction:[/red] {prediction}") |
|
if data[i].get("query") or data[i].get("sql"): |
|
console.print( |
|
"[purple]Gold:[/purple] " |
|
f"{data[i].get('query') or data[i].get('sql')}" |
|
) |
|
console.print("\n****\n") |
|
|
|
|
|
generated_sqls = generate_sql( |
|
manifest=manifest, |
|
text_to_sql_in=text_to_sql_in, |
|
retrieved_docs=retrieved_docs, |
|
stop_tokens=stop_tokens, |
|
max_tokens=max_tokens, |
|
temperature=temperature, |
|
num_beams=num_beams, |
|
prompt_formatter=prompt_formatter, |
|
overwrite_manifest=overwrite_manifest, |
|
parallel=parallel, |
|
) |
|
|
|
with open(Path(output_dir) / output_filename, "w") as fout: |
|
for i, (prediction, model_response) in enumerate(generated_sqls): |
|
if isinstance(model_response.final_prompt, str): |
|
token_lengths.append( |
|
len(tokenizer(model_response.final_prompt).input_ids) |
|
) |
|
else: |
|
for prompt in model_response.final_prompt: |
|
token_lengths.append(len(tokenizer(prompt["content"]).input_ids)) |
|
entry = { |
|
**original_data[i], |
|
"pred": prediction, |
|
"raw_pred": model_response.output, |
|
"raw_output": model_response.raw_output, |
|
"prompt": model_response.final_prompt, |
|
"tables": [tbl.dict() for tbl in text_to_sql_in[i].tables or []], |
|
} |
|
formatted_entry = data_formatter.format_output(entry) |
|
print(json.dumps(formatted_entry), file=fout) |
|
overflow = len([tl for tl in token_lengths if tl > 2048]) / len(token_lengths) |
|
console.print(f"Overflow 2048 prompt {100*overflow:.2f}%") |
|
console.print(f"Saved to {Path(output_dir) / output_filename}") |
|
|
|
|
|
if __name__ == "__main__": |
|
cli() |
|
|