"""Text-to-SQL running."""
import asyncio
import json
import re
import time
from typing import cast

import duckdb
import structlog
from manifest import Manifest
from manifest.response import Response, Usage
from prompt_formatters import MotherDuckFormatter, RajkumarFormatter
from schema import DEFAULT_TABLE_NAME, TextToSQLModelResponse, TextToSQLParams
from tqdm.auto import tqdm

logger = structlog.get_logger()


def clean_whitespace(sql: str) -> str:
    """Collapse runs of whitespace into single spaces."""
    return re.sub(r"\s+", " ", sql)
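

# Illustrative only (nothing in this module calls it at import time):
#     clean_whitespace("SELECT *\n\tFROM users\n WHERE id = 1")
#     == "SELECT * FROM users WHERE id = 1"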


def instruction_to_sql(
    params: TextToSQLParams,
    extra_context: list[str],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
) -> TextToSQLModelResponse:
    """Parse a single instruction into a SQL command."""
    return instruction_to_sql_list(
        params=[params],
        extra_context=[extra_context],
        manifest=manifest,
        prompt_formatter=prompt_formatter,
        overwrite_manifest=overwrite_manifest,
        max_tokens=max_tokens,
        temperature=temperature,
        stop_sequences=stop_sequences,
        num_beams=num_beams,
    )[0]


def run_motherduck_prompt_sql(
    params: list[TextToSQLParams],
) -> list[TextToSQLModelResponse]:
    """Run each instruction through MotherDuck's prompt_sql extension."""
    results = []
    for param in params:
        con = duckdb.connect("md:")
        try:
            sql_query = con.execute(
                "CALL prompt_sql(?);", [param.instruction]
            ).fetchall()[0][0]
        except Exception as e:
            logger.warning(f"prompt_sql failed, falling back to a default query: {e}")
            sql_query = "SELECT * FROM hn.hacker_news LIMIT 1"
        finally:
            con.close()
        # prompt_sql does not report token counts, so usage is zeroed out.
        usage = Usage(
            completion_tokens=0,
            prompt_tokens=0,
            total_tokens=0,
        )
        model_response = TextToSQLModelResponse(
            output=sql_query,
            raw_output=sql_query,
            final_prompt=param.instruction,
            usage=usage,
        )
        results.append(model_response)
    return results


def instruction_to_sql_list(
    params: list[TextToSQLParams],
    extra_context: list[list[str]],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
    verbose: bool = False,
) -> list[TextToSQLModelResponse]:
    """Parse a list of instructions into SQL commands.

    MotherDuck prompts bypass the manifest entirely; every other formatter
    goes through a manifest call (batched when there is more than one prompt).
    """
    if isinstance(prompt_formatter, MotherDuckFormatter):
        return run_motherduck_prompt_sql(params)

    if prompt_formatter is None:
        raise ValueError("Prompt formatter is required.")

    def construct_params(
        param: TextToSQLParams,
        context: list[str],
    ) -> str | list[dict]:
        """Turn a single set of params into a prompt."""
        if prompt_formatter.clean_whitespace:
            instruction = clean_whitespace(param.instruction)
        else:
            instruction = param.instruction

        table_texts = prompt_formatter.format_all_tables(
            param.tables, instruction=instruction
        )

        if table_texts:
            if isinstance(table_texts[0], str):
                table_text = prompt_formatter.table_sep.join(table_texts)
            else:
                # Chat-style formatters return message dicts rather than strings.
                table_text = table_texts
        else:
            table_text = ""

        if context:
            context_text = prompt_formatter.format_retrieved_context(context)
        else:
            context_text = "" if isinstance(table_text, str) else []
        prompt = prompt_formatter.format_prompt(
            instruction,
            table_text,
            context_text,
        )
        return prompt

    if not params:
        return []

    prompts: list[str | list[dict]] = []
    for i, param in tqdm(
        enumerate(params),
        total=len(params),
        desc="Constructing prompts",
        disable=not verbose,
    ):
        predict_str = construct_params(param, extra_context[i] if extra_context else [])
        if isinstance(predict_str, str):
            prompt = predict_str.lstrip()
        else:
            prompt = predict_str
        prompts.append(prompt)

    manifest_params = dict(
        max_tokens=max_tokens,
        overwrite_cache=overwrite_manifest,
        num_beams=num_beams,
        logprobs=5,
        temperature=temperature,
        do_sample=temperature > 0,
        stop_sequences=stop_sequences or prompt_formatter.stop_sequences,
    )

    ret: list[TextToSQLModelResponse] = []
    if len(params) == 1:
        prompt = prompts[0]
        model_response = None
        retries = 0
        while model_response is None and retries < 5:
            try:
                model_response = _run_manifest(
                    prompt,
                    manifest_params,
                    prompt_formatter,
                    manifest,
                    stop_sequences=stop_sequences,
                )
            except Exception as e:
                retries += 1
                logger.warning(f"Manifest run failed (attempt {retries}/5): {e}")
        if model_response is None:
            raise RuntimeError("Manifest run failed after 5 retries.")
        ret.append(model_response)
    else:
        # Batch all prompts through manifest on a dedicated event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        response = cast(
            Response,
            loop.run_until_complete(
                manifest.arun_batch(
                    prompts,
                    **manifest_params,
                ),
            ),
        )
        loop.close()

        response_usage = response.get_usage()
        response_text = response.get_parsed_response()
        for prompt, resp in zip(prompts, response_text):
            sql_query = prompt_formatter.format_model_output(cast(str, resp), prompt)
            # Truncate at the first stop sequence, if any were provided.
            for token in stop_sequences or []:
                sql_query = sql_query.split(token)[0]
            logger.info(f"FINAL OUTPUT: {sql_query}")
            ret.append(
                TextToSQLModelResponse(
                    output=sql_query,
                    raw_output=cast(str, resp),
                    final_prompt=prompt,
                    usage=response_usage,
                )
            )

    return ret


def _run_manifest(
    prompt: str | list[dict],
    manifest_params: dict,
    prompt_formatter: RajkumarFormatter,
    manifest: Manifest,
    stop_sequences: list[str] | None = None,
) -> TextToSQLModelResponse:
    """Run manifest on a single formatted prompt."""
    logger.info(f"PARAMS: {manifest_params}")
    if isinstance(prompt, list):
        # Chat-style prompts are lists of {"role": ..., "content": ...} messages.
        for p in prompt:
            logger.info(f"PROMPT: {p['role']}: {p['content']}")
    else:
        logger.info(f"PROMPT: {prompt}")
    start_time = time.time()

    response = cast(
        Response,
        manifest.run(
            prompt,
            return_response=True,
            client_timeout=1800,
            **manifest_params,
        ),
    )
    logger.info(f"TIME: {time.time() - start_time:.2f}")

    # Sum token usage across every usage record returned for this prompt.
    response_usage = response.get_usage_obj()
    summed_usage = Usage()
    for usage in response_usage.usages:
        summed_usage.completion_tokens += usage.completion_tokens
        summed_usage.prompt_tokens += usage.prompt_tokens
        summed_usage.total_tokens += usage.total_tokens

    sql_query = prompt_formatter.format_model_output(
        cast(str, response.get_response()), prompt
    )

    # Truncate at the first stop sequence, if any were provided.
    for token in stop_sequences or []:
        sql_query = sql_query.split(token)[0]
    logger.info(f"OUTPUT: {sql_query}")
    model_response = TextToSQLModelResponse(
        output=sql_query,
        raw_output=cast(str, response.get_response()),
        final_prompt=prompt,
        usage=summed_usage,
    )
    return model_response
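

# Minimal usage sketch (illustrative only; constructor arguments for Manifest,
# RajkumarFormatter, and TextToSQLParams are assumptions beyond the
# `instruction`/`tables` fields used above and may differ in this codebase):
#
#     manifest = Manifest(client_name="openai")
#     formatter = RajkumarFormatter()
#     params = TextToSQLParams(
#         instruction="How many users signed up last week?",
#         tables=[...],  # table schema objects expected by the formatter
#     )
#     response = instruction_to_sql(
#         params,
#         extra_context=[],
#         manifest=manifest,
#         prompt_formatter=formatter,
#     )
#     print(response.output)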