richardr1126
commited on
Commit
•
c3f3890
1
Parent(s):
1b1fce9
Chooses best query from chatgpt
Browse files
app.py
CHANGED
@@ -131,6 +131,11 @@ def extract_db_code(text):
|
|
131 |
matches = re.findall(pattern, text, re.DOTALL)
|
132 |
return [match.strip() for match in matches]
|
133 |
|
|
|
|
|
|
|
|
|
|
|
134 |
def generate_dummy_db(db_info, question):
|
135 |
pre_prompt = """
|
136 |
Generate a SQLite database with dummy data for this database from the DB Layout. Your task is to generate just a database, no queries. For each input do the following:
|
@@ -188,6 +193,36 @@ def test_query_on_dummy_db(db_code, query):
|
|
188 |
print(f"Query: {query}\tError encountered: {e}")
|
189 |
return False
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
193 |
if num_return_sequences > num_beams:
|
@@ -246,15 +281,18 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
|
|
246 |
query = query.replace("\n", " ").replace("\t", " ").strip()
|
247 |
# Test against dummy database
|
248 |
success = test_query_on_dummy_db(db_code, query)
|
249 |
-
|
250 |
-
query = format(query) if format_sql else query
|
251 |
if success:
|
252 |
responses.append(query)
|
253 |
else:
|
254 |
responses.append(query)
|
255 |
|
256 |
-
# Choose
|
257 |
-
|
|
|
|
|
|
|
|
|
258 |
|
259 |
if log:
|
260 |
# Log the request to Firestore
|
|
|
131 |
matches = re.findall(pattern, text, re.DOTALL)
|
132 |
return [match.strip() for match in matches]
|
133 |
|
134 |
+
def extract_from_code_block(text):
|
135 |
+
pattern = r'```(?:\w+)?\s?(.*?)```'
|
136 |
+
match = re.search(pattern, text, re.DOTALL)
|
137 |
+
return match.group(1).strip() if match else ''
|
138 |
+
|
139 |
def generate_dummy_db(db_info, question):
|
140 |
pre_prompt = """
|
141 |
Generate a SQLite database with dummy data for this database from the DB Layout. Your task is to generate just a database, no queries. For each input do the following:
|
|
|
193 |
print(f"Query: {query}\tError encountered: {e}")
|
194 |
return False
|
195 |
|
196 |
+
def choose_best_query(queries, question):
|
197 |
+
pre_prompt = """
|
198 |
+
Given a list of queries. Your task is to choose just a single query which satisfies the question the most with the least amount of filters, groupings, and conditions. For each input do the following:
|
199 |
+
1. Breakdown the list of queries into small pieces and explain what each query is doing.
|
200 |
+
2. Explain why each query is relevant to the question.
|
201 |
+
3. Choose the most relevant query from your explanation that aligns to the question best with the least amount of unnecessary filters or conditions. Output the best query in a single code block ``````.
|
202 |
+
"""
|
203 |
+
prompt = pre_prompt + "\n\nQuestion: " + question + "\n\nQueries:" + "\n\n".join(queries)
|
204 |
+
|
205 |
+
while True:
|
206 |
+
try:
|
207 |
+
response = openai.ChatCompletion.create(
|
208 |
+
model="gpt-3.5-turbo",
|
209 |
+
messages=[
|
210 |
+
{"role": "user", "content": prompt}
|
211 |
+
],
|
212 |
+
#temperature=0.7,
|
213 |
+
)
|
214 |
+
response_text = response['choices'][0]['message']['content']
|
215 |
+
print(response_text)
|
216 |
+
|
217 |
+
query = extract_from_code_block(response_text)
|
218 |
+
|
219 |
+
return query
|
220 |
+
|
221 |
+
except Exception as e:
|
222 |
+
print(f'Error occurred: {str(e)}')
|
223 |
+
print('Waiting for 10 seconds before retrying...')
|
224 |
+
time.sleep(10)
|
225 |
+
|
226 |
|
227 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
228 |
if num_return_sequences > num_beams:
|
|
|
281 |
query = query.replace("\n", " ").replace("\t", " ").strip()
|
282 |
# Test against dummy database
|
283 |
success = test_query_on_dummy_db(db_code, query)
|
284 |
+
|
|
|
285 |
if success:
|
286 |
responses.append(query)
|
287 |
else:
|
288 |
responses.append(query)
|
289 |
|
290 |
+
# Choose the best query if num_return_sequences > 1
|
291 |
+
if num_return_sequences > 1:
|
292 |
+
query = choose_best_query(responses, input_message)
|
293 |
+
# Format again
|
294 |
+
query = format(query) if format_sql else query
|
295 |
+
responses = [query]
|
296 |
|
297 |
if log:
|
298 |
# Log the request to Firestore
|