Spaces:
Sleeping
Sleeping
import json | |
import re | |
from tqdm import tqdm | |
import os | |
import asyncio | |
from openai import AsyncOpenAI | |
from utils.api_utils import generate_from_openai_chat_completion, generate_from_claude_chat_completion | |
def construct_prompt_textonly(question: str, options: list, answer: str, answer_analysis: str) -> str:
    """Build the distractor-generation prompt for one text-only multiple-choice question.

    The prompt asks the model to keep the original question and options, add
    three new distractors (keys E, F, G) plus a distractor analysis written in
    Chinese, and reply with a JSON object shaped like the template embedded in
    the prompt.

    Args:
        question: The question stem.
        options: The original answer options. The first four fill keys A-D of
            the JSON template. Every option is also listed as "A. ...",
            "B. ...", ... in the input section.
        answer: The correct answer, echoed into the template and the input.
        answer_analysis: Explanation of the correct answer, given to the model
            so that it can craft harder distractors.

    Returns:
        The full prompt string.

    Raises:
        IndexError: If ``options`` has fewer than four entries.
    """
    optionized_str = "\n".join(f"{chr(65 + i)}. {option}" for i, option in enumerate(options))

    def _as_json_string(value) -> str:
        # Serialize as a JSON string literal (quotes included). Embedded quotes,
        # backslashes and newlines are then escaped, so the example stays valid
        # JSON; the prompt explicitly demands valid JSON back. str() keeps
        # non-string inputs rendered as quoted strings, as the old "{value}"
        # form did. ensure_ascii=False keeps non-ASCII (e.g. Chinese) text
        # readable instead of \uXXXX escapes.
        return json.dumps(str(value), ensure_ascii=False)

    question_js = _as_json_string(question)
    option_a_js = _as_json_string(options[0])
    option_b_js = _as_json_string(options[1])
    option_c_js = _as_json_string(options[2])
    option_d_js = _as_json_string(options[3])
    answer_js = _as_json_string(answer)

    # The template must itself be valid JSON: no trailing commas before a
    # closing brace, which the previous version had.
    prompt = f"""
Generate a multiple-choice question with additional distractors that increase the complexity of answer selection. Follow these instructions:
1. **Retain Original Structure**: Retain the original question and options.
2. **Add Three Distractors**: Add three new distractors that are **plausible and maintain professional validity**. These should increase the difficulty but still be incorrect, based on the original question and answer analysis.
3. **Use Answer Analysis**: Reference the **correct answer analysis** when creating distractors to ensure they challenge **subject-matter experts**.
4. **Expert-Level Difficulty**: Keep the distractors **challenging and hard to distinguish** from the correct answer, requiring **advanced knowledge** to avoid the correct answer being too obvious.
5. **Balanced Length**: Ensure all options have **similar lengths** to prevent any one option from standing out.
6. **Distractors Analysis**: Provide a **distractor analysis in Chinese**, explaining why the distractors are **incorrect** but **challenging and hard to distinguish**.
Please output the result in valid JSON format using the structure below. Make sure there are no extra commas, missing commas, extra quotation marks or missing quotation marks:
{{
"question": {question_js},
"options": {{
"A": {option_a_js},
"B": {option_b_js},
"C": {option_c_js},
"D": {option_d_js}
}},
"distractors": {{
"E": "New distractor 1",
"F": "New distractor 2",
"G": "New distractor 3",
"analysis_of_distractors": "Use Chinese to explain why the distractors are **incorrect** but **challenging and hard to distinguish**, based on the question, options, and answer analysis."
}},
"correct_answer": {answer_js}
}}
Input:
Question: {question}
Options:
{optionized_str}
Answer: {answer}
Answer Analysis: {answer_analysis}
"""
    return prompt
def prepare_q_text_input(query, prompt_func=construct_prompt_textonly):
    """Render the prompt for a single question record.

    Args:
        query: Mapping providing the keys 'question', 'option_1' through
            'option_4', 'answer' and 'answer_analysis'. A missing key raises
            KeyError.
        prompt_func: Callable invoked with the keyword arguments ``question``,
            ``options`` (a list of the four options in order), ``answer`` and
            ``answer_analysis``. Defaults to ``construct_prompt_textonly``.

    Returns:
        Whatever ``prompt_func`` returns; a prompt string for the default.
    """
    # Keyword arguments are evaluated left to right, so the record is read in
    # the order: question, option_1..option_4, answer, answer_analysis.
    return prompt_func(
        question=query['question'],
        options=[query[f'option_{idx}'] for idx in range(1, 5)],
        answer=query['answer'],
        answer_analysis=query['answer_analysis'],
    )
def prepare_q_inputs(queries):
    """Wrap each query's prompt in a single-turn chat message list.

    Args:
        queries: Iterable of question records accepted by
            ``prepare_q_text_input``.

    Returns:
        A list parallel to ``queries``. Each element is a one-message
        conversation ``[{"role": "user", "content": <prompt>}]``.
    """
    return [
        [{"role": "user", "content": prepare_q_text_input(record)}]
        for record in queries
    ]
# def extract_json_from_text(text): | |
# text = json.dumps(text) | |
# # 移除转义符和换行符 | |
# text = text.replace('\\n', '').replace('\\"', '"') | |
# # 定义匹配 JSON 对象的正则表达式模式 | |
# json_pattern = re.compile( | |
# r'\{\s*"question":\s*"([^"]*)",\s*"options":\s*\{\s*"A":\s*"([^"]*)",\s*"B":\s*"([^"]*)",\s*"C":\s*"([^"]*)",\s*"D":\s*"([^"]*)"\s*\},' | |
# r'\s*"distractors":\s*\{\s*"E":\s*"([^"]*)",\s*"F":\s*"([^"]*)",\s*"G":\s*"([^"]*)"\s*\},\s*"correct_answer":\s*"([^"]*)"\s*\}', | |
# re.DOTALL | |
# ) | |
# # 匹配 JSON 结构 | |
# match = json_pattern.search(text) | |
# if match: | |
# # 捕获到的匹配组 | |
# question = match.group(1) | |
# option_a = match.group(2) | |
# option_b = match.group(3) | |
# option_c = match.group(4) | |
# option_d = match.group(5) | |
# distractor_e = match.group(6) | |
# distractor_f = match.group(7) | |
# distractor_g = match.group(8) | |
# correct_answer = match.group(9) | |
# # 构建 JSON 对象 | |
# json_data = { | |
# "question": question, | |
# "options": { | |
# "A": option_a, | |
# "B": option_b, | |
# "C": option_c, | |
# "D": option_d | |
# }, | |
# "distractors": { | |
# "E": distractor_e, | |
# "F": distractor_f, | |
# "G": distractor_g | |
# }, | |
# "correct_answer": correct_answer | |
# } | |
# return json_data | |
# else: | |
# print("No JSON object found in the text.") | |
# return None | |
def generate_distractors(model_name: str,
                         queries: list,
                         n: int = 1,
                         max_tokens: int = 4096,
                         base_url: str = "https://yanlp.zeabur.app/v1",
                         requests_per_minute: int = 30):
    """Ask an OpenAI chat model for three extra distractors per question.

    Each query is turned into a chat prompt (see ``prepare_q_inputs``) and sent
    through ``generate_from_openai_chat_completion`` with ``json_format=True``.
    The returned distractors are written back into the query dicts in place.

    Args:
        model_name: One of the allowed model identifiers checked below.
        queries: Question records (dicts). They are mutated in place.
        n: Forwarded unchanged to ``generate_from_openai_chat_completion``.
        max_tokens: Token limit forwarded to the completion helper.
        base_url: Endpoint for the ``AsyncOpenAI`` client. The default keeps
            the previously hard-coded value; presumably a relay/proxy. Note
            that the ``OPENAI_API_KEY`` value is sent to whatever URL this is.
        requests_per_minute: Rate limit forwarded to the completion helper;
            defaults to the previously hard-coded 30.

    Returns:
        The same ``queries`` list. Each processed query gains the keys
        'option_5', 'option_6', 'option_7' (distractors E, F, G) and
        'distractor_analysis'. Each of these is "" when the response is
        missing or malformed. Queries are paired with responses via ``zip``,
        so if fewer responses come back, the trailing queries are left
        untouched.

    Raises:
        AssertionError: If ``model_name`` is not in the allow-list (note that
            ``assert`` is skipped when Python runs with ``-O``).
    """
    assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"
    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=base_url)
    messages = prepare_q_inputs(queries)
    # asyncio.run starts a fresh event loop, so this function cannot be called
    # from code that is already running inside an event loop.
    responses = asyncio.run(
        generate_from_openai_chat_completion(
            client,
            messages=messages,
            engine_name=model_name,
            n=n,
            max_tokens=max_tokens,
            requests_per_minute=requests_per_minute,
            json_format=True,
        )
    )
    for query, response in zip(queries, responses):
        # NOTE(review): assumes the helper yields parsed dicts (or a falsy
        # value on failure) — confirm. Any other shape, or a non-dict
        # "distractors" entry, is treated as "no distractors" instead of
        # raising TypeError/AttributeError as before.
        distractors = response.get("distractors") if isinstance(response, dict) else None
        if not isinstance(distractors, dict):
            distractors = {}
        query["option_5"] = distractors.get("E", "")
        query["option_6"] = distractors.get("F", "")
        query["option_7"] = distractors.get("G", "")
        query["distractor_analysis"] = distractors.get("analysis_of_distractors", "")
    return queries