augray commited on
Commit
ca78baa
·
1 Parent(s): e30a182

Correct table name

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. app.py +104 -23
.gitignore CHANGED
@@ -1 +1,2 @@
1
- .env
 
 
1
+ .env
2
+ .venv
app.py CHANGED
@@ -1,18 +1,17 @@
1
  import json
 
2
  import os
3
  import urllib.parse
 
4
 
5
  import gradio as gr
6
  import requests
7
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
- from huggingface_hub import InferenceClient
9
 
10
- example = HuggingfaceHubSearch().example_value()
11
 
12
- client = InferenceClient(
13
- "meta-llama/Meta-Llama-3.1-70B-Instruct",
14
- token=os.environ["HF_TOKEN"],
15
- )
16
 
17
 
18
  def get_iframe(hub_repo_id, sql_query=None):
@@ -34,20 +33,53 @@ def get_iframe(hub_repo_id, sql_query=None):
34
  return iframe
35
 
36
 
37
- def get_column_info(hub_repo_id):
38
  url: str = f"https://datasets-server.huggingface.co/info?dataset={hub_repo_id}"
39
  response = requests.get(url)
40
  try:
41
  data = response.json()
42
  data = data.get("dataset_info")
43
- key = list(data.keys())[0]
44
- features: str = json.dumps(data.get(key).get("features"))
45
  except Exception as e:
46
  gr.Error(f"Error getting column info: {e}")
47
- return features
48
 
49
 
50
- def query_dataset(hub_repo_id, features, query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  messages = [
52
  {
53
  "role": "system",
@@ -55,22 +87,71 @@ def query_dataset(hub_repo_id, features, query):
55
  },
56
  {
57
  "role": "user",
58
- "content": f"""table train
59
  # Features
60
  {features}
61
 
62
  # Query
63
- {query}
64
  """,
65
  },
66
  ]
67
- response = client.chat_completion(
68
- messages=messages,
69
- max_tokens=1000,
70
- stream=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  )
72
- query = response.choices[0].message.content
73
- return query, get_iframe(hub_repo_id, query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
  with gr.Blocks() as demo:
@@ -107,21 +188,21 @@ with gr.Blocks() as demo:
107
  with gr.Row():
108
  search_out = gr.HTML(label="Search Results")
109
  with gr.Row():
110
- features = gr.Code(label="Features", language="json", visible=False)
111
  gr.on(
112
  [btn.click, search_in.submit],
113
  fn=get_iframe,
114
  inputs=[search_in],
115
  outputs=[search_out],
116
  ).then(
117
- fn=get_column_info,
118
  inputs=[search_in],
119
- outputs=[features],
120
  )
121
  gr.on(
122
  [btn2.click, query.submit],
123
  fn=query_dataset,
124
- inputs=[search_in, features, query],
125
  outputs=[sql_out, search_out],
126
  )
127
 
 
1
  import json
2
+ import logging
3
  import os
4
  import urllib.parse
5
+ from typing import Any
6
 
7
  import gradio as gr
8
  import requests
9
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
10
+ from huggingface_hub.repocard import CardData, RepoCard
11
 
 
12
 
13
+ logger = logging.getLogger(__name__)
14
+ example = HuggingfaceHubSearch().example_value()
 
 
15
 
16
 
17
  def get_iframe(hub_repo_id, sql_query=None):
 
33
  return iframe
34
 
35
 
36
+ def get_table_info(hub_repo_id):
37
  url: str = f"https://datasets-server.huggingface.co/info?dataset={hub_repo_id}"
38
  response = requests.get(url)
39
  try:
40
  data = response.json()
41
  data = data.get("dataset_info")
42
+ return json.dumps(data)
 
43
  except Exception as e:
44
  gr.Error(f"Error getting column info: {e}")
 
45
 
46
 
47
+ def get_table_name(config: str | None, split: str | None, config_choices: list[str], split_choices: list[str]):
48
+ if len(config_choices) > 0 and config is None:
49
+ config = config_choices[0]
50
+ if len(split_choices) > 0 and split is None:
51
+ split = split_choices[0]
52
+
53
+ if len(config_choices) > 1 and len(split_choices) > 1:
54
+ base_name = f"{config}_{split}"
55
+ elif len(config_choices) >= 1 and len(split_choices) <= 1:
56
+ base_name = config
57
+ else:
58
+ base_name = split
59
+
60
+ def replace_char(c):
61
+ if c.isalnum():
62
+ return c
63
+ if c in ["-", "_", "/"]:
64
+ return "_"
65
+ return ""
66
+
67
+ table_name = "".join(
68
+ replace_char(c) for c in base_name
69
+ )
70
+ if table_name[0].isdigit():
71
+ table_name = f"_{table_name}"
72
+ return table_name.lower()
73
+
74
+
75
+ def get_prompt_messages(card_data: dict[str, Any], natural_language_query: str):
76
+ config_choices = get_config_choices(card_data)
77
+ split_choices = get_split_choices(card_data)
78
+
79
+ chosen_config = config_choices[0] if len(config_choices) > 0 else None
80
+ chosen_split = split_choices[0] if len(split_choices) > 0 else None
81
+ table_name = get_table_name(chosen_config, chosen_split, config_choices, split_choices)
82
+ features = card_data[chosen_config]["features"]
83
  messages = [
84
  {
85
  "role": "system",
 
87
  },
88
  {
89
  "role": "user",
90
+ "content": f"""table {table_name}
91
  # Features
92
  {features}
93
 
94
  # Query
95
+ {natural_language_query}
96
  """,
97
  },
98
  ]
99
+ return messages
100
+
101
+
102
+ def get_config_choices(card_data: dict[str, Any]) -> list[str]:
103
+ return list(card_data.keys())
104
+
105
+
106
+ def get_split_choices(card_data: dict[str, Any]) -> list[str]:
107
+ splits = set()
108
+ for config in card_data.values():
109
+ splits.update(config.get("splits", {}).keys())
110
+
111
+ return list(splits)
112
+
113
+
114
+ def query_dataset(hub_repo_id, card_data, query):
115
+ card_data = json.loads(card_data)
116
+ messages = get_prompt_messages(card_data, query)
117
+ api_key = os.environ["API_KEY_TOGETHER_AI"].strip()
118
+ response = requests.post(
119
+ "https://api.together.xyz/v1/chat/completions",
120
+ json=dict(
121
+ model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
122
+ messages=messages,
123
+ max_tokens=1000,
124
+ ),
125
+ headers={"Authorization": f"Bearer {api_key}"},
126
  )
127
+
128
+ if response.status_code != 200:
129
+ logger.warning(response.text)
130
+
131
+ try:
132
+ response.raise_for_status()
133
+ except Exception as e:
134
+ gr.Error(f"Could not query LLM for suggestion: {e}")
135
+
136
+ response_dict = response.json()
137
+ duck_query = response_dict["choices"][0]["message"]["content"]
138
+ duck_query = _sanitize_duck_query(duck_query)
139
+ return duck_query, get_iframe(hub_repo_id, duck_query)
140
+
141
+ def _sanitize_duck_query(duck_query: str) -> str:
142
+ # Sometimes the LLM wraps the query like this:
143
+ # ```sql
144
+ # select * from x;
145
+ # ```
146
+ # This removes that wrapping if present.
147
+ if "```" not in duck_query:
148
+ return duck_query
149
+ start_idx = duck_query.index("```") + len("```")
150
+ end_idx = duck_query.rindex("```")
151
+ duck_query = duck_query[start_idx:end_idx]
152
+ if duck_query.startswith("sql\n"):
153
+ duck_query = duck_query.replace("sql\n", "", 1)
154
+ return duck_query
155
 
156
 
157
  with gr.Blocks() as demo:
 
188
  with gr.Row():
189
  search_out = gr.HTML(label="Search Results")
190
  with gr.Row():
191
+ card_data = gr.Code(label="Card data", language="json", visible=False)
192
  gr.on(
193
  [btn.click, search_in.submit],
194
  fn=get_iframe,
195
  inputs=[search_in],
196
  outputs=[search_out],
197
  ).then(
198
+ fn=get_table_info,
199
  inputs=[search_in],
200
+ outputs=[card_data],
201
  )
202
  gr.on(
203
  [btn2.click, query.submit],
204
  fn=query_dataset,
205
+ inputs=[search_in, card_data, query],
206
  outputs=[sql_out, search_out],
207
  )
208