import html
import json
import logging
import os
import urllib.parse
from typing import Any

import gradio as gr
import requests
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub.repocard import CardData, RepoCard

logger = logging.getLogger(__name__)

example = HuggingfaceHubSearch().example_value()

# Seconds to wait on the dataset-server / LLM HTTP calls before giving up, so a
# stalled upstream cannot hang an event handler forever.
# NOTE(review): values chosen during review; tune if the LLM is routinely slower.
DATASETS_SERVER_TIMEOUT = 30
LLM_TIMEOUT = 120

SYSTEM_PROMPT_TEMPLATE = (
    "You are a SQL query expert assistant that returns DuckDB SQL queries "
    "based on the user's natural language query and dataset features. "
    "You might need to use DuckDB functions for lists and aggregations, "
    "given the features. Only return the SQL query, no other text. The "
    "user may ask you to make various adjustments to the query. Every "
    "time your response should only include the refined SQL query and "
    "nothing else.\n\n"
    "The table being queried is named: {table_name}.\n\n"
    "# Features\n"
    "{features}"
)

# Language tags the LLM may put right after an opening ``` fence.
_FENCE_LANGUAGE_TAGS = frozenset({"sql", "duckdb"})


def get_iframe(hub_repo_id: str, sql_query: str | None = None) -> str:
    """Return HTML embedding the Hub dataset viewer for ``hub_repo_id``.

    When ``sql_query`` is given, the viewer URL enables the SQL console and
    passes the (URL-encoded) query to it.

    Raises:
        ValueError: if ``hub_repo_id`` is empty.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    # Keep "/" so "owner/name" stays a path; percent-encode anything unusual.
    repo_path = urllib.parse.quote(hub_repo_id, safe="/")
    url = f"https://huggingface.co/datasets/{repo_path}/embed/viewer"
    if sql_query:
        url += f"?sql_console=true&sql={urllib.parse.quote(sql_query)}"
    # The repo id and query originate from user input, so escape the URL
    # before placing it inside an HTML attribute.
    return f"""
<iframe
  src="{html.escape(url, quote=True)}"
  frameborder="0"
  width="100%"
  height="800px"
></iframe>
"""


def get_table_info(hub_repo_id: str) -> str:
    """Fetch the dataset-server ``/info`` payload for a Hub dataset.

    Returns:
        The ``dataset_info`` field serialized as a JSON string. Downstream
        code treats it as a mapping of config name -> config info (with
        ``features`` and ``splits`` entries).

    Raises:
        gr.Error: if the request fails, returns a non-2xx status, or the body
            is not JSON.
    """
    try:
        response = requests.get(
            "https://datasets-server.huggingface.co/info",
            # Let requests URL-encode the dataset id.
            params={"dataset": hub_repo_id},
            timeout=DATASETS_SERVER_TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # gr.Error only surfaces in the UI when raised.
        raise gr.Error(f"Error getting column info: {e}") from e
    return json.dumps(data.get("dataset_info"))


def get_table_name(
    config: str | None,
    split: str | None,
    config_choices: list[str],
    split_choices: list[str],
) -> str:
    """Derive the SQL table name for the selected config/split.

    Missing ``config`` / ``split`` default to the first available choice.
    The name is ``{config}_{split}`` when there are several configs and
    several splits. It is the config when there is at most one split, and
    the split otherwise. Non-alphanumeric characters are dropped, except
    ``-``, ``_`` and ``/``, which become ``_``. A leading digit gets a
    ``_`` prefix, and the result is lower-cased.
    NOTE(review): presumably mirrors the Hub SQL console's table naming —
    verify.

    Raises:
        ValueError: if no usable name can be derived (e.g. no configs and no
            splits, or a name made entirely of dropped characters).
    """
    if config is None and config_choices:
        config = config_choices[0]
    if split is None and split_choices:
        split = split_choices[0]

    if len(config_choices) > 1 and len(split_choices) > 1:
        base_name = f"{config}_{split}"
    elif len(config_choices) >= 1 and len(split_choices) <= 1:
        base_name = config
    else:
        base_name = split

    def replace_char(c):
        if c.isalnum():
            return c
        if c in ["-", "_", "/"]:
            return "_"
        return ""

    table_name = "".join(replace_char(c) for c in (base_name or ""))
    if not table_name:
        raise ValueError(
            f"Could not derive a table name from config={config!r}, split={split!r}"
        )
    if table_name[0].isdigit():
        table_name = f"_{table_name}"
    return table_name.lower()


def get_system_prompt(
    card_data: dict[str, Any],
    config: str | None,
    split: str | None,
) -> str:
    """Build the LLM system prompt for the chosen config/split.

    A ``None`` config falls back to the first config. This is the same
    default ``get_table_name`` applies, so the table name and features
    always refer to the same config.
    """
    config_choices = get_config_choices(card_data)
    split_choices = get_split_choices(card_data)
    if config is None and config_choices:
        config = config_choices[0]
    table_name = get_table_name(config, split, config_choices, split_choices)
    features = card_data[config]["features"]
    return SYSTEM_PROMPT_TEMPLATE.format(
        table_name=table_name,
        features=features,
    )


def get_config_choices(card_data: dict[str, Any]) -> list[str]:
    """Return the config names (the top-level keys of ``card_data``)."""
    return list(card_data)


def get_split_choices(card_data: dict[str, Any]) -> list[str]:
    """Return every split name found across all configs, de-duplicated.

    Order is first-seen. A ``set`` of strings would order differently on
    each process start (hash randomization), making the default split
    unstable.
    """
    return list(
        dict.fromkeys(
            split_name
            for config in card_data.values()
            for split_name in config.get("splits", {})
        )
    )


def _build_messages(
    system_prompt: str, history: list, query: str
) -> list[dict[str, str]]:
    """Assemble the chat transcript: system prompt, prior turns, new query."""
    messages = [{"role": "system", "content": system_prompt}]
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": query})
    return messages


def _ask_llm(messages: list[dict[str, str]]) -> str:
    """Send ``messages`` to Together AI's chat-completions endpoint.

    Returns:
        The content of the first choice's message.

    Raises:
        gr.Error: if the API key is unset or the request/response fails.
    """
    api_key = os.environ.get("API_KEY_TOGETHER_AI", "").strip()
    if not api_key:
        raise gr.Error("API_KEY_TOGETHER_AI is not set; cannot query the LLM.")
    try:
        response = requests.post(
            "https://api.together.xyz/v1/chat/completions",
            json=dict(
                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                messages=messages,
                max_tokens=1000,
            ),
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=LLM_TIMEOUT,
        )
    except requests.RequestException as e:
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e
    if response.status_code != 200:
        logger.warning(response.text)
    try:
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e


def query_dataset(hub_repo_id, card_data, query, config, split, history):
    """Ask the LLM for a DuckDB query and refresh the embedded viewer.

    Args:
        hub_repo_id: Hub dataset id (``owner/name``).
        card_data: JSON string produced by ``get_table_info``.
        query: natural-language request from the user.
        config, split: currently selected config / split (may be ``None``).
        history: prior ``(user, assistant)`` turns from the chatbot.

    Returns:
        ``(sql, iframe_html, history)``. When there is no dataset info, the
        SQL is ``""``, the viewer has no query, and the history is reset.
    """
    if not card_data:
        return "", get_iframe(hub_repo_id), []
    parsed = json.loads(card_data)
    if not parsed:
        # get_table_info serializes a missing dataset_info as "null".
        return "", get_iframe(hub_repo_id), []

    system_prompt = get_system_prompt(parsed, config, split)
    history = list(history or [])
    messages = _build_messages(system_prompt, history, query)
    duck_query = _sanitize_duck_query(_ask_llm(messages))
    history.append((query, duck_query))
    return duck_query, get_iframe(hub_repo_id, duck_query), history


def _sanitize_duck_query(duck_query: str) -> str:
    """Strip Markdown code fences and a language tag from LLM output.

    The LLM sometimes wraps the query like this:

        ```sql
        select * from x;
        ```

    The text between the first and last fence is kept. A lone (unterminated)
    fence is simply removed instead of discarding the whole reply. A leading
    ``sql``/``duckdb`` tag line (any case) is dropped, and surrounding
    whitespace is trimmed.
    """
    fence = "```"
    if fence not in duck_query:
        return duck_query.strip()
    start = duck_query.index(fence)
    end = duck_query.rindex(fence)
    if start == end:
        # Only one fence: slicing start..end would yield an empty query.
        body = duck_query[:start] + duck_query[start + len(fence):]
    else:
        body = duck_query[start + len(fence):end]
    tag, newline, rest = body.partition("\n")
    if newline and tag.strip().lower() in _FENCE_LANGUAGE_TAGS:
        body = rest
    return body.strip()


with gr.Blocks() as demo:
    gr.Markdown("""# 🐥 🦙 🤗 Text To SQL Hub Datasets 🤗 🦙 🐥

This is a basic text to SQL tool that allows you to query datasets on Huggingface Hub.
It is built with [DuckDB](https://duckdb.org/), [Together AI's Inference API](https://www.together.ai/), and [Llama 3.1 70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct).
Also, it uses the [dataset-server API](https://redocly.github.io/redoc/?url=https://datasets-server.huggingface.co/openapi.json#operation/isValidDataset).
""")
    with gr.Row():
        search_in = HuggingfaceHubSearch(
            label="Search Huggingface Hub",
            placeholder="Search for datasets on Huggingface",
            search_type="dataset",
            # NOTE(review): keyword spelled exactly as in the original; check
            # the component's signature before "correcting" it.
            sumbit_on_select=True,
        )
    with gr.Row():
        show_btn = gr.Button("Show Dataset")
    with gr.Row():
        sql_out = gr.Code(
            label="DuckDB SQL Query",
            interactive=True,
            language="sql",
            lines=1,
            visible=False,
        )
    with gr.Row():
        card_data = gr.Code(label="Card data", language="json", visible=False)

    @gr.render(inputs=[card_data])
    def show_config_split_choices(data):
        """Rebuild the config/split pickers and query widgets for new card data."""
        try:
            data = json.loads(data.strip())
            config_choices = get_config_choices(data)
            split_choices = get_split_choices(data)
        except (AttributeError, TypeError, ValueError):
            # No data yet (None / ""), invalid JSON, or "null": show empty pickers.
            config_choices = []
            split_choices = []
        initial_config = config_choices[0] if config_choices else None
        initial_split = split_choices[0] if split_choices else None
        with gr.Row():
            with gr.Column():
                config_selection = gr.Dropdown(
                    label="Config Name", choices=config_choices, value=initial_config
                )
            with gr.Column():
                split_selection = gr.Dropdown(
                    label="Split Name", choices=split_choices, value=initial_split
                )
        with gr.Accordion("Query Suggestion History.", open=False) as accordion:
            chatbot = gr.Chatbot(height=200, layout="bubble")
        with gr.Row():
            query = gr.Textbox(
                label="Query Description",
                placeholder="Enter a natural language query to generate SQL",
            )
        with gr.Row():
            with gr.Column():
                query_btn = gr.Button("Get Suggested Query")
            with gr.Column():
                clear = gr.ClearButton([query, chatbot], value="Reset Query History")
        with gr.Row():
            search_out = gr.HTML(label="Search Results")

        gr.on(
            [show_btn.click, search_in.submit],
            fn=get_iframe,
            inputs=[search_in],
            outputs=[search_out],
        ).then(
            fn=get_table_info,
            inputs=[search_in],
            outputs=[card_data],
        )
        gr.on(
            [query_btn.click, query.submit],
            fn=query_dataset,
            inputs=[
                search_in,
                card_data,
                query,
                config_selection,
                split_selection,
                chatbot,
            ],
            outputs=[sql_out, search_out, chatbot],
        )
        # Reveal the history on every way of submitting a query, not just the button.
        gr.on(
            [query_btn.click, query.submit],
            fn=lambda: gr.update(open=True),
            outputs=[accordion],
        )


if __name__ == "__main__":
    demo.launch()