augray commited on
Commit
e602593
1 Parent(s): 2ebc3cc
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -44,7 +44,12 @@ def get_table_info(hub_repo_id):
44
  gr.Error(f"Error getting column info: {e}")
45
 
46
 
47
- def get_table_name(config: str | None, split: str | None, config_choices: list[str], split_choices: list[str]):
 
 
 
 
 
48
  if len(config_choices) > 0 and config is None:
49
  config = config_choices[0]
50
  if len(split_choices) > 0 and split is None:
@@ -63,16 +68,19 @@ def get_table_name(config: str | None, split: str | None, config_choices: list[s
63
  if c in ["-", "_", "/"]:
64
  return "_"
65
  return ""
66
-
67
- table_name = "".join(
68
- replace_char(c) for c in base_name
69
- )
70
  if table_name[0].isdigit():
71
  table_name = f"_{table_name}"
72
  return table_name.lower()
73
 
74
 
75
- def get_prompt_messages(card_data: dict[str, Any], natural_language_query: str, config: str | None, split: str | None):
 
 
 
 
 
76
  config_choices = get_config_choices(card_data)
77
  split_choices = get_split_choices(card_data)
78
 
@@ -136,6 +144,7 @@ def query_dataset(hub_repo_id, card_data, query, config, split):
136
  duck_query = _sanitize_duck_query(duck_query)
137
  return duck_query, get_iframe(hub_repo_id, duck_query)
138
 
 
139
  def _sanitize_duck_query(duck_query: str) -> str:
140
  # Sometimes the LLM wraps the query like this:
141
  # ```sql
@@ -190,14 +199,18 @@ with gr.Blocks() as demo:
190
  except Exception:
191
  config_choices = []
192
  split_choices = []
193
-
194
  initial_config = config_choices[0] if len(config_choices) > 0 else None
195
  initial_split = split_choices[0] if len(split_choices) > 0 else None
196
  with gr.Row():
197
  with gr.Column():
198
- config_selection = gr.Dropdown(label="Config Name", choices=config_choices, value=initial_config)
 
 
199
  with gr.Column():
200
- split_selection = gr.Dropdown(label="Split Name", choices=split_choices, value=initial_split)
 
 
201
 
202
  with gr.Row():
203
  with gr.Column():
@@ -223,5 +236,6 @@ with gr.Blocks() as demo:
223
  outputs=[sql_out, search_out],
224
  )
225
 
 
226
  if __name__ == "__main__":
227
  demo.launch()
 
44
  gr.Error(f"Error getting column info: {e}")
45
 
46
 
47
+ def get_table_name(
48
+ config: str | None,
49
+ split: str | None,
50
+ config_choices: list[str],
51
+ split_choices: list[str],
52
+ ):
53
  if len(config_choices) > 0 and config is None:
54
  config = config_choices[0]
55
  if len(split_choices) > 0 and split is None:
 
68
  if c in ["-", "_", "/"]:
69
  return "_"
70
  return ""
71
+
72
+ table_name = "".join(replace_char(c) for c in base_name)
 
 
73
  if table_name[0].isdigit():
74
  table_name = f"_{table_name}"
75
  return table_name.lower()
76
 
77
 
78
+ def get_prompt_messages(
79
+ card_data: dict[str, Any],
80
+ natural_language_query: str,
81
+ config: str | None,
82
+ split: str | None,
83
+ ):
84
  config_choices = get_config_choices(card_data)
85
  split_choices = get_split_choices(card_data)
86
 
 
144
  duck_query = _sanitize_duck_query(duck_query)
145
  return duck_query, get_iframe(hub_repo_id, duck_query)
146
 
147
+
148
  def _sanitize_duck_query(duck_query: str) -> str:
149
  # Sometimes the LLM wraps the query like this:
150
  # ```sql
 
199
  except Exception:
200
  config_choices = []
201
  split_choices = []
202
+
203
  initial_config = config_choices[0] if len(config_choices) > 0 else None
204
  initial_split = split_choices[0] if len(split_choices) > 0 else None
205
  with gr.Row():
206
  with gr.Column():
207
+ config_selection = gr.Dropdown(
208
+ label="Config Name", choices=config_choices, value=initial_config
209
+ )
210
  with gr.Column():
211
+ split_selection = gr.Dropdown(
212
+ label="Split Name", choices=split_choices, value=initial_split
213
+ )
214
 
215
  with gr.Row():
216
  with gr.Column():
 
236
  outputs=[sql_out, search_out],
237
  )
238
 
239
+
240
  if __name__ == "__main__":
241
  demo.launch()