File size: 6,432 Bytes
e9ce3e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import json
import re
from tqdm import tqdm
import os
import asyncio
from openai import AsyncOpenAI

from utils.api_utils import generate_from_openai_chat_completion, generate_from_claude_chat_completion


def construct_prompt_textonly(question: str, options: list, answer: str, answer_analysis: str) -> str:
    """Build the distractor-generation prompt for one multiple-choice question.

    The prompt asks the model to keep the original question and options, add
    three new distractors (E, F, G) plus a Chinese analysis of them, and reply
    using the JSON template embedded in the prompt.

    Args:
        question: The question text.
        options: The original options; the first four fill slots A-D of the
            JSON template, and all of them are listed in the "Options:" input
            section with letters starting at "A".
        answer: The correct answer as given by the dataset.
        answer_analysis: Explanation of why the answer is correct.

    Returns:
        The complete prompt string.

    Raises:
        IndexError: If fewer than four options are supplied.
    """

    def _quote(value) -> str:
        # JSON-escape interpolated values so quotes, backslashes, or newlines
        # in the data cannot corrupt the template shown to the model.
        # ensure_ascii=False keeps non-ASCII text (e.g. Chinese) readable.
        return json.dumps(str(value), ensure_ascii=False)

    optionized_str = "\n".join(
        f"{chr(65 + i)}. {option}" for i, option in enumerate(options)
    )

    # The template below must itself be valid JSON (no trailing commas):
    # the prompt explicitly tells the model to avoid extra commas.
    prompt = f"""
Generate a multiple-choice question with additional distractors that increase the complexity of answer selection. Follow these instructions:
1. **Retain Original Structure**: Retain the original question and options.
2. **Add Three Distractors**: Add three new distractors that are **plausible and maintain professional validity**. These should increase the difficulty but still be incorrect, based on the original question and answer analysis.
3. **Use Answer Analysis**: Reference the **correct answer analysis** when creating distractors to ensure they challenge **subject-matter experts**.
4. **Expert-Level Difficulty**: Keep the distractors **challenging and hard to distinguish** from the correct answer, requiring **advanced knowledge** to avoid the correct answer being too obvious.
5. **Balanced Length**: Ensure all options have **similar lengths** to prevent any one option from standing out.
6. **Distractors Analysis**: Provide a **distractor analysis in Chinese**, explaining why the distractors are **incorrect** but **challenging and hard to distinguish**.

Please output the result in valid JSON format using the structure below. Make sure there are no extra commas, missing commas, extra quotation marks or missing quotation marks:
{{
  "question": {_quote(question)},
  "options": {{
    "A": {_quote(options[0])},
    "B": {_quote(options[1])},
    "C": {_quote(options[2])},
    "D": {_quote(options[3])}
  }},
  "distractors": {{
    "E": "New distractor 1",
    "F": "New distractor 2",
    "G": "New distractor 3",
    "analysis_of_distractors": "Use Chinese to explain why the distractors are **incorrect** but **challenging and hard to distinguish**, based on the question, options, and answer analysis."
  }},
  "correct_answer": {_quote(answer)}
}}

Input:
Question: {question}
Options:
{optionized_str}
Answer: {answer}
Answer Analysis: {answer_analysis}
"""

    return prompt


def prepare_q_text_input(query, prompt_func=construct_prompt_textonly):
    """Render the text prompt for a single query record.

    Reads 'question', 'option_1'..'option_4', 'answer' and 'answer_analysis'
    from ``query`` and forwards them as keyword arguments to ``prompt_func``.

    Args:
        query: Mapping holding the keys listed above.
        prompt_func: Callable accepting ``question``, ``options``, ``answer``
            and ``answer_analysis`` keyword arguments and returning the prompt.

    Returns:
        Whatever ``prompt_func`` returns (the prompt text).
    """
    option_keys = ('option_1', 'option_2', 'option_3', 'option_4')
    return prompt_func(
        question=query['question'],
        options=[query[key] for key in option_keys],
        answer=query['answer'],
        answer_analysis=query['answer_analysis'],
    )


def prepare_q_inputs(queries):
    """Turn each query into a single-turn chat message list.

    Args:
        queries: Iterable of query records accepted by ``prepare_q_text_input``.

    Returns:
        A list with one entry per query; each entry is a one-element list
        containing a ``{"role": "user", "content": <prompt>}`` message.
    """
    return [
        [{"role": "user", "content": prepare_q_text_input(query)}]
        for query in queries
    ]



# def extract_json_from_text(text):
#     text = json.dumps(text)
#     # Remove escape characters and newlines
#     text = text.replace('\\n', '').replace('\\"', '"')

#     # Define the regular-expression pattern that matches the JSON object
#     json_pattern = re.compile(
#         r'\{\s*"question":\s*"([^"]*)",\s*"options":\s*\{\s*"A":\s*"([^"]*)",\s*"B":\s*"([^"]*)",\s*"C":\s*"([^"]*)",\s*"D":\s*"([^"]*)"\s*\},'
#         r'\s*"distractors":\s*\{\s*"E":\s*"([^"]*)",\s*"F":\s*"([^"]*)",\s*"G":\s*"([^"]*)"\s*\},\s*"correct_answer":\s*"([^"]*)"\s*\}', 
#         re.DOTALL
#     )
    
#     # Match the JSON structure
#     match = json_pattern.search(text)
    
#     if match:
#         # Captured match groups
#         question = match.group(1)
#         option_a = match.group(2)
#         option_b = match.group(3)
#         option_c = match.group(4)
#         option_d = match.group(5)
#         distractor_e = match.group(6)
#         distractor_f = match.group(7)
#         distractor_g = match.group(8)
#         correct_answer = match.group(9)
        
#         # Build the JSON object
#         json_data = {
#             "question": question,
#             "options": {
#                 "A": option_a,
#                 "B": option_b,
#                 "C": option_c,
#                 "D": option_d
#             },
#             "distractors": {
#                 "E": distractor_e,
#                 "F": distractor_f,
#                 "G": distractor_g
#             },
#             "correct_answer": correct_answer
#         }
        
#         return json_data
#     else:
#         print("No JSON object found in the text.")
#         return None
    

def generate_distractors(model_name: str, 
                      queries: list,
                      n: int=1,
                      max_tokens: int=4096):
    """Request three extra distractors per query and attach them to the queries.

    Builds one prompt per query, sends all prompts through
    ``generate_from_openai_chat_completion`` and writes the results back onto
    each query dict in place under 'option_5', 'option_6', 'option_7' and
    'distractor_analysis'. A field is set to "" when the corresponding
    response is missing, is not a dict, has no dict under "distractors", or
    lacks that particular key.

    Args:
        model_name: One of the allowed OpenAI model names.
        queries: List of query dicts (see ``prepare_q_text_input`` for the
            keys that are read). Mutated in place.
        n: Forwarded to the completion helper as ``n``.
        max_tokens: Forwarded to the completion helper as ``max_tokens``.

    Returns:
        The same ``queries`` list, with the four fields above filled in.

    Raises:
        AssertionError: If ``model_name`` is not an allowed model.
    """
    assert model_name in ["gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06"], "Invalid model name"

    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"),base_url="https://yanlp.zeabur.app/v1")
    messages = prepare_q_inputs(queries)

    # NOTE(review): responses are presumably already-parsed JSON objects
    # (dicts) because json_format=True is passed — confirm in utils.api_utils.
    responses = asyncio.run(
        generate_from_openai_chat_completion(
            client,
            messages=messages, 
            engine_name=model_name,
            n = n,
            max_tokens=max_tokens,
            requests_per_minute=30,
            json_format=True
        )
    )

    # Field written on the query -> key read from the response's "distractors".
    field_to_key = {
        "option_5": "E",
        "option_6": "F",
        "option_7": "G",
        "distractor_analysis": "analysis_of_distractors",
    }

    for query, response in zip(queries, responses):
        # Type-check both levels so one malformed reply (e.g. a raw string, or
        # "distractors": null) yields empty fields instead of aborting the batch.
        distractors = response.get("distractors") if isinstance(response, dict) else None
        if not isinstance(distractors, dict):
            distractors = {}
        for field, key in field_to_key.items():
            query[field] = distractors.get(key, "")

    return queries