richardr1126 commited on
Commit
c3f3890
1 Parent(s): 1b1fce9

Chooses best query from chatgpt

Browse files
Files changed (1) hide show
  1. app.py +42 -4
app.py CHANGED
@@ -131,6 +131,11 @@ def extract_db_code(text):
131
  matches = re.findall(pattern, text, re.DOTALL)
132
  return [match.strip() for match in matches]
133
 
 
 
 
 
 
134
  def generate_dummy_db(db_info, question):
135
  pre_prompt = """
136
  Generate a SQLite database with dummy data for this database from the DB Layout. Your task is to generate just a database, no queries. For each input do the following:
@@ -188,6 +193,36 @@ def test_query_on_dummy_db(db_code, query):
188
  print(f"Query: {query}\tError encountered: {e}")
189
  return False
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
193
  if num_return_sequences > num_beams:
@@ -246,15 +281,18 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
246
  query = query.replace("\n", " ").replace("\t", " ").strip()
247
  # Test against dummy database
248
  success = test_query_on_dummy_db(db_code, query)
249
- # Format again
250
- query = format(query) if format_sql else query
251
  if success:
252
  responses.append(query)
253
  else:
254
  responses.append(query)
255
 
256
- # Choose a random response from responses
257
- output = responses[0] if len(responses) > 0 else "###"
 
 
 
 
258
 
259
  if log:
260
  # Log the request to Firestore
 
131
  matches = re.findall(pattern, text, re.DOTALL)
132
  return [match.strip() for match in matches]
133
 
134
+ def extract_from_code_block(text):
135
+ pattern = r'```(?:\w+)?\s?(.*?)```'
136
+ match = re.search(pattern, text, re.DOTALL)
137
+ return match.group(1).strip() if match else ''
138
+
139
  def generate_dummy_db(db_info, question):
140
  pre_prompt = """
141
  Generate a SQLite database with dummy data for this database from the DB Layout. Your task is to generate just a database, no queries. For each input do the following:
 
193
  print(f"Query: {query}\tError encountered: {e}")
194
  return False
195
 
196
+ def choose_best_query(queries, question):
197
+ pre_prompt = """
198
+ Given a list of queries. Your task is to choose just a single query which satisfies the question the most with the least amount of filters, groupings, and conditions. For each input do the following:
199
+ 1. Breakdown the list of queries into small pieces and explain what each query is doing.
200
+ 2. Explain why each query is relevant to the question.
201
+ 3. Choose the most relevant query from your explanation that aligns to the question best with the least amount of unnecessary filters or conditions. Output the best query in a single code block ``````.
202
+ """
203
+ prompt = pre_prompt + "\n\nQuestion: " + question + "\n\nQueries:" + "\n\n".join(queries)
204
+
205
+ while True:
206
+ try:
207
+ response = openai.ChatCompletion.create(
208
+ model="gpt-3.5-turbo",
209
+ messages=[
210
+ {"role": "user", "content": prompt}
211
+ ],
212
+ #temperature=0.7,
213
+ )
214
+ response_text = response['choices'][0]['message']['content']
215
+ print(response_text)
216
+
217
+ query = extract_from_code_block(response_text)
218
+
219
+ return query
220
+
221
+ except Exception as e:
222
+ print(f'Error occurred: {str(e)}')
223
+ print('Waiting for 10 seconds before retrying...')
224
+ time.sleep(10)
225
+
226
 
227
  def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
228
  if num_return_sequences > num_beams:
 
281
  query = query.replace("\n", " ").replace("\t", " ").strip()
282
  # Test against dummy database
283
  success = test_query_on_dummy_db(db_code, query)
284
+
 
285
  if success:
286
  responses.append(query)
287
  else:
288
  responses.append(query)
289
 
290
+ # Choose the best query if num_return_sequences > 1
291
+ if num_return_sequences > 1:
292
+ query = choose_best_query(responses, input_message)
293
+ # Format again
294
+ query = format(query) if format_sql else query
295
+ responses = [query]
296
 
297
  if log:
298
  # Log the request to Firestore