richardr1126 commited on
Commit
f26b7dd
1 Parent(s): af675f5

Change gpt3.5 prompt

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -129,7 +129,7 @@ def extract_db_code(text):
129
  return [match.strip() for match in matches]
130
 
131
  def generate_dummy_db(db_info, question):
132
- pre_prompt = "Generate a SQLite database with dummy data for this database, output the SQL code in a SQL code block. Make sure you add dummy data relevant to the question.\n\n"
133
  prompt = pre_prompt + db_info + "\n\nQuestion: " + question
134
 
135
  while True:
@@ -137,6 +137,7 @@ def generate_dummy_db(db_info, question):
137
  response = openai.ChatCompletion.create(
138
  model="gpt-3.5-turbo",
139
  messages=[
 
140
  {"role": "user", "content": prompt}
141
  ],
142
  #temperature=0.7,
@@ -184,7 +185,7 @@ def test_query_on_dummy_db(db_code, query):
184
 
185
 
186
  def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
187
- if num_return_sequences >= num_beams:
188
  gr.Warning("Num return sequences must be less than or equal to num beams.")
189
 
190
  stop_token_ids = tok.convert_tokens_to_ids(["###"])
 
129
  return [match.strip() for match in matches]
130
 
131
  def generate_dummy_db(db_info, question):
132
+ pre_prompt = "Generate a SQLite database with dummy data for this database. Make sure you add dummy data relevant to the question and don't write any SELECT statements or actual queries.\n\n"
133
  prompt = pre_prompt + db_info + "\n\nQuestion: " + question
134
 
135
  while True:
 
137
  response = openai.ChatCompletion.create(
138
  model="gpt-3.5-turbo",
139
  messages=[
140
+ {"role": "system", "content": "You are a SQLite dummy database generator. 1. You will create the specified dummy database. 2. Insert the dummy data. and 3. Output the only the code in a SQL code block."}
141
  {"role": "user", "content": prompt}
142
  ],
143
  #temperature=0.7,
 
185
 
186
 
187
  def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0, repetition_penalty=1.08, format_sql=True, log=False, num_return_sequences=1, num_beams=1, do_sample=False):
188
+ if num_return_sequences > num_beams:
189
  gr.Warning("Num return sequences must be less than or equal to num beams.")
190
 
191
  stop_token_ids = tok.convert_tokens_to_ids(["###"])