|
"""Generate answers with GPT-3.5""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
import time |
|
import concurrent.futures |
|
|
|
import openai |
|
import tqdm |
|
import shortuuid |
|
|
|
# Model name passed as `model=` to the OpenAI chat-completion call.
MODEL = "gpt-3.5-turbo"
# Value stored under 'model_id' in every answer record written out.
# NOTE(review): the ':20230327' suffix presumably tags the model snapshot
# date these answers came from — confirm before changing MODEL.
MODEL_ID = "gpt-3.5-turbo:20230327"
|
|
|
def get_answer(question_id: int, question: str, max_tokens: int,
               max_retries: int = 3, retry_delay: float = 1.0):
    """Ask the chat model one question and return an answer record.

    The request is attempted up to ``max_retries`` times (always at least
    once). After a failed attempt the function sleeps ``retry_delay`` seconds,
    doubling the delay after each further failure; it does not sleep after
    the final attempt.

    Args:
        question_id: Identifier copied into the returned record.
        question: Text sent as the user message.
        max_tokens: Passed through as ``max_tokens`` to the API call.
        max_retries: Maximum number of attempts.
        retry_delay: Initial pause, in seconds, between attempts.

    Returns:
        A dict with keys ``answer_id`` (fresh shortuuid), ``question_id``,
        ``model_id`` and ``text``. ``text`` is ``'#ERROR#'`` when every
        attempt raised.
    """
    ans = {
        'answer_id': shortuuid.uuid(),
        'question_id': question_id,
        'model_id': MODEL_ID,
        'text': '#ERROR#',  # overwritten on success
    }
    attempts = max(1, max_retries)
    delay = retry_delay
    for attempt in range(attempts):
        try:
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=[{
                    'role': 'system',
                    'content': 'You are a helpful assistant.'
                }, {
                    'role': 'user',
                    'content': question,
                }],
                max_tokens=max_tokens,
            )
            # Kept inside the try so a malformed response is retried
            # instead of crashing the worker.
            ans['text'] = response['choices'][0]['message']['content']
            return ans
        except Exception as e:  # deliberately broad: any failure is retried
            print(f'[ERROR] question {question_id} '
                  f'(attempt {attempt + 1}/{attempts}):', e)
            # Back off between attempts (helps under rate limiting), but
            # don't waste time sleeping after the last one.
            if attempt + 1 < attempts:
                time.sleep(delay)
                delay *= 2
    return ans
|
|
|
|
|
def load_questions(path: str) -> dict:
    """Read a JSON Lines question file into ``{question_id: text}``.

    Each non-blank line must be a JSON object with at least the keys
    ``question_id`` and ``text``. A later line with the same ``question_id``
    overwrites an earlier one. A leading ``~`` in ``path`` is expanded.
    """
    questions = {}
    with open(os.path.expanduser(path), encoding='utf-8') as f:
        for line in f:
            # Lines read from a file keep their '\n', so a blank line is
            # '\n' (truthy); strip before testing so it is really skipped.
            if not line.strip():
                continue
            record = json.loads(line)
            questions[record['question_id']] = record['text']
    return questions


def write_answers(path: str, answers) -> None:
    """Write answer dicts to ``path`` as JSON Lines, one newline-terminated
    object per line. A leading ``~`` in ``path`` is expanded."""
    with open(os.path.expanduser(path), 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(ans) + '\n' for ans in answers)


def main():
    """Parse CLI arguments, answer every question concurrently, save results."""
    parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
    parser.add_argument('-q', '--question', required=True,
                        help='input JSONL file of {"question_id", "text"} records')
    parser.add_argument('-o', '--output', required=True,
                        help='output JSONL file for the answers')
    parser.add_argument('--max-tokens', type=int, default=1024,
                        help='maximum number of tokens produced in the output')
    parser.add_argument('--parallel', type=int, default=32,
                        help='number of concurrent API requests')
    args = parser.parse_args()
    if args.parallel < 1:
        parser.error('--parallel must be at least 1')

    questions_dict = load_questions(args.question)

    answers = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.parallel) as executor:
        futures = [
            executor.submit(get_answer, qid, question, args.max_tokens)
            for qid, question in questions_dict.items()
        ]
        # as_completed yields in completion order; results are sorted below.
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures),
                                total=len(futures)):
            answers.append(future.result())

    answers.sort(key=lambda x: x['question_id'])
    write_answers(args.output, answers)


if __name__ == '__main__':
    main()
|
|