richardr1126
committed on
Commit
•
f26b7dd
1
Parent(s):
af675f5
Change gpt3.5 prompt
Browse files
app.py
CHANGED
@@ -129,7 +129,7 @@ def extract_db_code(text):
|
|
129 |
return [match.strip() for match in matches]
|
130 |
|
131 |
def generate_dummy_db(db_info, question):
|
132 |
-
pre_prompt = "Generate a SQLite database with dummy data for this database
|
133 |
prompt = pre_prompt + db_info + "\n\nQuestion: " + question
|
134 |
|
135 |
while True:
|
@@ -137,6 +137,7 @@ def generate_dummy_db(db_info, question):
|
|
137 |
response = openai.ChatCompletion.create(
|
138 |
model="gpt-3.5-turbo",
|
139 |
messages=[
|
|
|
140 |
{"role": "user", "content": prompt}
|
141 |
],
|
142 |
#temperature=0.7,
|
@@ -184,7 +185,7 @@ def test_query_on_dummy_db(db_code, query):
|
|
184 |
|
185 |
|
186 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
187 |
-
if num_return_sequences
|
188 |
gr.Warning("Num return sequences must be less than or equal to num beams.")
|
189 |
|
190 |
stop_token_ids = tok.convert_tokens_to_ids(["###"])
|
|
|
129 |
return [match.strip() for match in matches]
|
130 |
|
131 |
def generate_dummy_db(db_info, question):
|
132 |
+
pre_prompt = "Generate a SQLite database with dummy data for this database. Make sure you add dummy data relevant to the question and don't write any SELECT statements or actual queries.\n\n"
|
133 |
prompt = pre_prompt + db_info + "\n\nQuestion: " + question
|
134 |
|
135 |
while True:
|
|
|
137 |
response = openai.ChatCompletion.create(
|
138 |
model="gpt-3.5-turbo",
|
139 |
messages=[
|
140 |
+
{"role": "system", "content": "You are a SQLite dummy database generator. 1. You will create the specified dummy database. 2. Insert the dummy data. and 3. Output the only the code in a SQL code block."}
|
141 |
{"role": "user", "content": prompt}
|
142 |
],
|
143 |
#temperature=0.7,
|
|
|
185 |
|
186 |
|
187 |
def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
|
188 |
+
if num_return_sequences > num_beams:
|
189 |
gr.Warning("Num return sequences must be less than or equal to num beams.")
|
190 |
|
191 |
stop_token_ids = tok.convert_tokens_to_ids(["###"])
|