richardr1126 commited on
Commit
097b91f
β€’
1 Parent(s): 819c99b

testing new generation strategy

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +181 -59
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.37.0
8
- app_file: app-ngrok.py
9
  pinned: true
10
  license: bigcode-openrail-m
11
  tags:
 
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.37.0
8
+ app_file: app.py
9
  pinned: true
10
  license: bigcode-openrail-m
11
  tags:
app.py CHANGED
@@ -1,15 +1,109 @@
1
  import os
2
- from threading import Event, Thread
 
 
 
 
 
3
  from transformers import (
4
  AutoModelForCausalLM,
5
  AutoTokenizer,
6
  StoppingCriteria,
7
  StoppingCriteriaList,
8
- TextIteratorStreamer,
9
  )
10
- import gradio as gr
 
 
 
 
 
11
  import torch
12
- import sqlparse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  model_name = os.getenv("HF_MODEL_NAME", None)
15
  tok = AutoTokenizer.from_pretrained(model_name)
@@ -24,22 +118,20 @@ m = AutoModelForCausalLM.from_pretrained(
24
  #load_in_8bit=True,
25
  )
26
 
27
- m.config.pad_token_id = m.config.eos_token_id
28
- m.generation_config.pad_token_id = m.config.eos_token_id
29
-
30
- stop_tokens = [";", "###", "Result"]
31
- stop_token_ids = tok.convert_tokens_to_ids(stop_tokens)
32
 
33
  print(f"Successfully loaded the model {model_name} into memory")
34
 
35
- class StopOnTokens(StoppingCriteria):
36
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
37
- for stop_id in stop_token_ids:
38
- if input_ids[0][-1] == stop_id:
39
- return True
40
- return False
41
 
42
- def bot(input_message: str, db_info="", temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
 
 
 
 
 
 
 
43
  stop = StopOnTokens()
44
 
45
  # Format the user's input message
@@ -52,12 +144,14 @@ def bot(input_message: str, db_info="", temperature=0.1, top_p=0.9, top_k=0, rep
52
  input_ids=input_ids,
53
  max_new_tokens=max_new_tokens,
54
  temperature=temperature,
55
- do_sample=temperature > 0.0,
56
  top_p=top_p,
57
  top_k=top_k,
58
  repetition_penalty=repetition_penalty,
59
  streamer=streamer,
60
  stopping_criteria=StoppingCriteriaList([stop]),
 
 
 
61
  )
62
 
63
  stream_complete = Event()
@@ -73,64 +167,92 @@ def bot(input_message: str, db_info="", temperature=0.1, top_p=0.9, top_k=0, rep
73
  for new_text in streamer:
74
  partial_text += new_text
75
 
76
- # Split the text by "|", and get the last element in the list which should be the final query
77
- try:
78
- final_query = partial_text.split("|")[1].strip()
79
- except Exception:
80
- final_query = partial_text
81
 
82
- try:
83
- # Attempt to format SQL query using sqlparse
84
- formatted_query = sqlparse.format(final_query, reindent=True, keyword_case='upper')
85
- except Exception:
86
- # If formatting fails, use the original, unformatted query
87
- formatted_query = final_query
88
 
89
- # Convert SQL to markdown (not required, but just to show how to use the markdown module)
90
- final_query_markdown = f"{formatted_query}"
91
- return final_query_markdown
92
 
 
93
  with gr.Blocks(theme='gradio/soft') as demo:
 
94
  header = gr.HTML("""
95
- <h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
96
- <h3 style="text-align: center">πŸ§™β€β™‚οΈ Generate SQL queries from Natural Language πŸ§™β€β™‚οΈ</h3>
 
 
 
97
  """)
98
 
99
- output_box = gr.Code(label="Generated SQL", lines=2, interactive=True)
 
 
 
 
 
100
  input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
101
- db_info = gr.Textbox(lines=4, placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')
 
 
 
 
 
102
 
103
- with gr.Accordion("Hyperparameters", open=False):
104
- temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
105
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
106
  top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
107
  repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
 
 
 
 
 
108
 
109
- run_button = gr.Button("Generate SQL", variant="primary")
110
-
111
- with gr.Accordion("Examples", open=True):
 
 
 
 
 
 
 
 
 
 
112
  examples = gr.Examples([
113
- ["What is the average, minimum, and maximum age for all French singers?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
 
 
114
  ["Show location and name for all stadiums with a capacity between 5000 and 10000.", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
115
  ["What are the number of concerts that occurred in the stadium with the largest capacity ?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
116
- ["How many male singers performed in concerts in the year 2023?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
117
- ["List the names of all singers who performed in a concert with the theme 'Rock'", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"]
118
- ], inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty], fn=bot)
119
-
120
- bitsandbytes_model = "richardr1126/spider-skeleton-wizard-coder-8bit"
121
- merged_model = "richardr1126/spider-skeleton-wizard-coder-merged"
122
- initial_model = "WizardLM/WizardCoder-15B-V1.0"
123
- finetuned_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
124
- dataset = "richardr1126/spider-skeleton-context-instruct"
125
-
126
- footer = gr.HTML(f"""
127
- <p>πŸ› οΈ If you want you can <strong>duplicate this Space</strong>, then change the HF_MODEL_REPO spaces env varaible to use any Transformers model.</p>
128
- <p>🌐 Leveraging the <a href='https://huggingface.co/{bitsandbytes_model}'><strong>bitsandbytes 8-bit version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
129
- <p>πŸ”— How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{finetuned_model}'><strong>{finetuned_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
130
- <p>πŸ“‰ Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{finetuned_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
131
- """)
132
 
133
 
134
- run_button.click(fn=bot, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty], outputs=output_box, api_name="txt2sql")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- demo.queue(concurrency_count=1, max_size=10).launch()
 
1
  import os
2
+ import gradio as gr
3
+ import sqlparse
4
+ import requests
5
+ from time import sleep
6
+ import re
7
+ import platform
8
  from transformers import (
9
  AutoModelForCausalLM,
10
  AutoTokenizer,
11
  StoppingCriteria,
12
  StoppingCriteriaList,
13
+ TextIteratorStreamer
14
  )
15
+ from threading import Event, Thread
16
+ # Additional Firebase imports
17
+ import firebase_admin
18
+ from firebase_admin import credentials, firestore
19
+ import json
20
+ import base64
21
  import torch
22
+
23
+
24
+ print(f"Running on {platform.system()}")
25
+
26
+ if platform.system() == "Windows" or platform.system() == "Darwin":
27
+ from dotenv import load_dotenv
28
+ load_dotenv()
29
+
30
+ quantized_model = "richardr1126/spider-skeleton-wizard-coder-8bit"
31
+ merged_model = "richardr1126/spider-skeleton-wizard-coder-merged"
32
+ initial_model = "WizardLM/WizardCoder-15B-V1.0"
33
+ lora_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
34
+ dataset = "richardr1126/spider-skeleton-context-instruct"
35
+
36
+ # Firebase code
37
+ # Initialize Firebase
38
+ base64_string = os.getenv('FIREBASE')
39
+ base64_bytes = base64_string.encode('utf-8')
40
+ json_bytes = base64.b64decode(base64_bytes)
41
+ json_data = json_bytes.decode('utf-8')
42
+
43
+ firebase_auth = json.loads(json_data)
44
+
45
+ # Load credentials and initialize Firestore
46
+ cred = credentials.Certificate(firebase_auth)
47
+ firebase_admin.initialize_app(cred)
48
+ db = firestore.client()
49
+
50
+ def log_message_to_firestore(input_message, db_info, temperature, response_text):
51
+ doc_ref = db.collection('codellama-logs').document()
52
+ log_data = {
53
+ 'timestamp': firestore.SERVER_TIMESTAMP,
54
+ 'temperature': temperature,
55
+ 'db_info': db_info,
56
+ 'input': input_message,
57
+ 'output': response_text,
58
+ }
59
+ doc_ref.set(log_data)
60
+
61
+ rated_outputs = set() # set to store already rated outputs
62
+
63
+ def log_rating_to_firestore(input_message, db_info, temperature, response_text, rating):
64
+ global rated_outputs
65
+ output_id = f"{input_message} {db_info} {response_text} {temperature}"
66
+
67
+ if output_id in rated_outputs:
68
+ gr.Warning("You've already rated this output!")
69
+ return
70
+ if not input_message or not response_text or not rating:
71
+ gr.Info("You haven't asked a question yet!")
72
+ return
73
+
74
+ rated_outputs.add(output_id)
75
+
76
+ doc_ref = db.collection('codellama-ratings').document()
77
+ log_data = {
78
+ 'timestamp': firestore.SERVER_TIMESTAMP,
79
+ 'temperature': temperature,
80
+ 'db_info': db_info,
81
+ 'input': input_message,
82
+ 'output': response_text,
83
+ 'rating': rating,
84
+ }
85
+ doc_ref.set(log_data)
86
+ gr.Info("Thanks for your feedback!")
87
+ # End Firebase code
88
+
89
+ def format(text):
90
+ # Split the text by "|", and get the last element in the list which should be the final query
91
+ try:
92
+ final_query = text.split("|")[1].strip()
93
+ except Exception:
94
+ final_query = text
95
+
96
+ try:
97
+ # Attempt to format SQL query using sqlparse
98
+ formatted_query = sqlparse.format(final_query, reindent=True, keyword_case='upper')
99
+ except Exception:
100
+ # If formatting fails, use the original, unformatted query
101
+ formatted_query = final_query
102
+
103
+ # Convert SQL to markdown (not required, but just to show how to use the markdown module)
104
+ final_query_markdown = f"{formatted_query}"
105
+
106
+ return final_query_markdown
107
 
108
  model_name = os.getenv("HF_MODEL_NAME", None)
109
  tok = AutoTokenizer.from_pretrained(model_name)
 
118
  #load_in_8bit=True,
119
  )
120
 
121
+ # m.config.pad_token_id = m.config.eos_token_id
122
+ # m.generation_config.pad_token_id = m.config.eos_token_id
 
 
 
123
 
124
  print(f"Successfully loaded the model {model_name} into memory")
125
 
 
 
 
 
 
 
126
 
127
+ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
128
+ stop_token_ids = tok.convert_tokens_to_ids(["###"])
129
+ class StopOnTokens(StoppingCriteria):
130
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
131
+ for stop_id in stop_token_ids:
132
+ if input_ids[0][-1] == stop_id:
133
+ return True
134
+ return False
135
  stop = StopOnTokens()
136
 
137
  # Format the user's input message
 
144
  input_ids=input_ids,
145
  max_new_tokens=max_new_tokens,
146
  temperature=temperature,
 
147
  top_p=top_p,
148
  top_k=top_k,
149
  repetition_penalty=repetition_penalty,
150
  streamer=streamer,
151
  stopping_criteria=StoppingCriteriaList([stop]),
152
+ num_return_sequences=num_return_sequences,
153
+ num_beams=num_beams,
154
+ do_sample=do_sample,
155
  )
156
 
157
  stream_complete = Event()
 
167
  for new_text in streamer:
168
  partial_text += new_text
169
 
170
+ output = format(partial_text) if format_sql else partial_text
 
 
 
 
171
 
172
+ if log:
173
+ # Log the request to Firestore
174
+ log_message_to_firestore(input_message, db_info, temperature, output)
 
 
 
175
 
176
+ return output
 
 
177
 
178
+ # Gradio UI Code
179
  with gr.Blocks(theme='gradio/soft') as demo:
180
+ # Elements stack vertically by default just define elements in order you want them to stack
181
  header = gr.HTML("""
182
+ <h1 style="text-align: center">SQL CodeLlama Demo</h1>
183
+ <h3 style="text-align: center">πŸ•·οΈβ˜ οΈπŸ¦™ Generate SQL queries from Natural Language πŸ•·οΈβ˜ οΈπŸ§™πŸ¦™</h3>
184
+ <div style="max-width: 450px; margin: auto; text-align: center">
185
+ <p style="font-size: 12px; text-align: center">⚠️ Should take 30-60s to generate. Please rate the response, it helps a lot. If you get a blank output, the model server is currently down, please try again another time.</p>
186
+ </div>
187
  """)
188
 
189
+ output_box = gr.Code(label="Generated SQL", lines=2, interactive=False)
190
+
191
+ with gr.Row():
192
+ rate_up = gr.Button("πŸ‘", variant="secondary")
193
+ rate_down = gr.Button("πŸ‘Ž", variant="secondary")
194
+
195
  input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
196
+ db_info = gr.Textbox(lines=4, placeholder='Make sure to place your tables information inside || for better results. Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')
197
+ format_sql = gr.Checkbox(label="Format SQL + Remove Skeleton", value=True, interactive=True)
198
+
199
+ with gr.Row():
200
+ run_button = gr.Button("Generate SQL", variant="primary")
201
+ clear_button = gr.ClearButton(variant="secondary")
202
 
203
+ with gr.Accordion("Options", open=False):
204
+ temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.2, step=0.1)
205
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
206
  top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
207
  repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)
208
+
209
+ with gr.Accordion("Generation strategies", open=False):
210
+ num_return_sequences = gr.Slider(label="Num Return Sequences", minimum=1, maximum=5, value=1, step=1)
211
+ num_beams = gr.Slider(label="Num Beams", minimum=1, maximum=5, value=1, step=1)
212
+ do_sample = gr.Checkbox(label="Do Sample", value=False, interactive=True)
213
 
214
+ info = gr.HTML(f"""
215
+ <p>🌐 Leveraging the <a href='https://huggingface.co/{quantized_model}'><strong>bitsandbytes 8-bit version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
216
+ <p>πŸ”— How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{lora_model}'><strong>{lora_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
217
+ <p>πŸ“‰ Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{lora_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
218
+ <p>πŸ“Š All inputs/outputs are logged to Firebase to see how the model is doing. You can also leave a rating for each generated SQL the model produces, which gets sent to the database as well.</a></p>
219
+ """)
220
+
221
+ examples = gr.Examples([
222
+ ["What is the average, minimum, and maximum age of all singers from France?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
223
+ ["How many students have dogs?", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid | pets.pettype = 'Dog' |"],
224
+ ], inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql], fn=generate, cache_examples=False if platform.system() == "Windows" or platform.system() == "Darwin" else True, outputs=output_box)
225
+
226
+ with gr.Accordion("More Examples", open=False):
227
  examples = gr.Examples([
228
+ ["What is the average weight of pets of all students?", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
229
+ ["How many male singers performed in concerts in the year 2023?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
230
+ ["For students who have pets, how many pets does each student have? List their ids instead of names.", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
231
  ["Show location and name for all stadiums with a capacity between 5000 and 10000.", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
232
  ["What are the number of concerts that occurred in the stadium with the largest capacity ?", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
233
+ ["Which student has the oldest pet?", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
234
+ ["List the names of all singers who performed in a concert with the theme 'Rock'", "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"],
235
+ ["List all students who don't have pets.", "| student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype , pet_age , weight | has_pet.stuid = student.stuid | has_pet.petid = pets.petid |"],
236
+ ], inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql], fn=generate, cache_examples=False, outputs=output_box)
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
 
239
+ readme_content = requests.get(f"https://huggingface.co/{merged_model}/raw/main/README.md").text
240
+ readme_content = re.sub('---.*?---', '', readme_content, flags=re.DOTALL) #Remove YAML front matter
241
+
242
+ with gr.Accordion("πŸ“– Model Readme", open=True):
243
+ readme = gr.Markdown(
244
+ readme_content,
245
+ )
246
+
247
+ with gr.Accordion("Disabled Options:", open=False):
248
+ log = gr.Checkbox(label="Log to Firebase", value=True, interactive=False)
249
+
250
+ # When the button is clicked, call the generate function, inputs are taken from the UI elements, outputs are sent to outputs elements
251
+ run_button.click(fn=generate, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty, format_sql, log, num_return_sequences, num_beams, do_sample], outputs=output_box, api_name="txt2sql")
252
+ clear_button.add([input_text, db_info, output_box])
253
+
254
+ # Firebase code - for rating the generated SQL (remove if you don't want to use Firebase)
255
+ rate_up.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_up])
256
+ rate_down.click(fn=log_rating_to_firestore, inputs=[input_text, db_info, temperature, output_box, rate_down])
257
 
258
+ demo.queue(concurrency_count=1, max_size=20).launch(debug=True)