File size: 6,772 Bytes
0319a9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import json

def data2reference(top_k_items, output_n=3):
    """Render retrieved items as a ``#Reference:`` block for the LLM prompt.

    Each emitted item becomes one JSON object (non-ASCII characters kept
    verbatim) with the keys ``name_in_life``, ``name_in_cultivation`` and
    ``description_in_cultivation``. Objects are separated by a blank line.
    Items whose ``"keyword"`` was already emitted are skipped. At most
    *output_n* distinct keywords are emitted, in input order.

    Args:
        top_k_items: iterable of dicts with ``"keyword"``,
            ``"name_in_cultivation"`` and ``"description_in_cultivation"``
            keys.
        output_n: maximum number of distinct items to include. Values <= 0
            produce only the header.

    Returns:
        The reference block with surrounding whitespace stripped
        (just ``"#Reference:"`` when nothing is emitted).

    Raises:
        KeyError: if an item that gets emitted lacks one of the keys above.
    """
    seen_keywords = set()
    entries = []
    for item in top_k_items:
        # Check the limit before touching the item, so output_n <= 0 emits
        # nothing and items past the limit are never read.
        if len(entries) >= output_n:
            break
        keyword = item["keyword"]
        if keyword in seen_keywords:
            continue
        seen_keywords.add(keyword)
        entries.append(json.dumps({
            "name_in_life": keyword,
            "name_in_cultivation": item["name_in_cultivation"],
            "description_in_cultivation": item["description_in_cultivation"],
        }, ensure_ascii=False))
    return ("#Reference:\n" + "\n\n".join(entries)).strip()



def data2prompt(query_item, top_k_items, n_references=3):
    """Build the LLM prompt that rewrites an everyday item as a cultivation-world item.

    The prompt is the concatenation of, in order:

    1. a ``#Reference:`` section with up to *n_references* retrieved examples,
       rendered by ``data2reference``;
    2. the task instruction;
    3. an ``# Input:`` section describing *query_item*;
    4. a chain-of-thought spec listing the JSON fields the model must output.

    Args:
        query_item: dict describing the item to rewrite. When it has a
            ``"keyword"`` key, that value (plus ``"description"`` if present)
            is rendered as ``input_name`` / ``description_in_life``.
            Otherwise the whole dict is dumped as JSON.
        top_k_items: retrieved reference records passed to ``data2reference``.
        n_references: maximum number of distinct references to include.

    Returns:
        The complete prompt string.
    """
    reference_prompt = data2reference(top_k_items, n_references)

    # Task: "Using the item descriptions in Reference as examples, rewrite the
    # item given in Input as its counterpart in a cultivation (xianxia) world."
    task_prompt1 = "\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n"

    input_prompt = "# Input:\n"
    if "keyword" in query_item:
        input_prompt += f"input_name:{query_item['keyword']}\n"
        if "description" in query_item:
            input_prompt += f"description_in_life:{query_item['description']}\n"
    else:
        # No keyword field: hand the model the raw item as JSON.
        input_prompt += json.dumps(query_item, ensure_ascii=False) + "\n"

    # Field names referenced inside this spec must match the fields it asks
    # for. generate_cultivation_with_rag reads back name_in_life,
    # name_in_cultivation_1, description_in_cultivation_1,
    # final_enhanced_description and new_name. The echo_*/critique/analysis
    # fields are not read in this file and appear to serve only as reasoning
    # scaffolding for the model.
    CoT_prompt = \
"""Let's think step by step,以json形式输出逐个字段。包含以下字段
- name_in_life: 进一步明确要生成描述的物品名称
- name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
- description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
- echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
- critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
- echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
- analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
- echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
- candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
- analysis_candidates: 分析各个candidates有什么优点
- echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
- final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
- echo_5: "我将根据final_enhanced_description,分析是否建议将物品名称替换为新的名词"
- name_fit_analysis: 分析name_in_cultivation_1是否还匹配final_enhanced_description的描述,是否需要给input_name起一个更响亮的名字
- new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
"""

    return reference_prompt + task_prompt1 + input_prompt + CoT_prompt

try:
    from src.ZhipuClient import ZhipuClient
except ImportError:
    # Fall back to a top-level import when the ``src`` package is not
    # importable (presumably when run from inside the source directory).
    from ZhipuClient import ZhipuClient

# Module-level client, created lazily on first use by
# generate_cultivation_with_rag and reused afterwards.
zhipu_client = None


import json

def markdown_to_json(markdown_str):
    """Parse JSON that may be wrapped in a Markdown code fence.

    Surrounding whitespace is ignored. If the text starts with a ``` fence,
    that fence is removed, along with an optional ``json`` language tag (any
    case). A closing ``` fence is removed only when one is present. The
    remainder is parsed with ``json.loads``.

    Args:
        markdown_str: raw model output.

    Returns:
        The decoded JSON value (usually a dict).

    Raises:
        json.JSONDecodeError: if the remaining text is not valid JSON.
    """
    text = markdown_str.strip()
    if text.startswith("```"):
        text = text[3:]
        # Language tag right after the opening fence, e.g. ```json or ```JSON.
        if text[:4].lower() == "json":
            text = text[4:]
        # Drop a closing fence only if it exists. Blind slicing would eat
        # real JSON when the model omits the fence.
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
    return json.loads(text)

import re

def forced_extract(input_str, keywords):
    """Extract string fields from malformed JSON-like text on a best-effort basis.

    For each key in *keywords*, the first ``"key": "value"`` pair in
    *input_str* is located (whitespace is allowed around the colon).
    Backslash escapes inside the value (escaped quotes, newlines, unicode
    escapes) are decoded as JSON string escapes. If decoding fails, the raw
    matched text is kept. Keys that are absent, or whose value is not a
    string literal, map to ``""``.

    Args:
        input_str: raw model output that could not be parsed as JSON.
        keywords: field names to look for.

    Returns:
        dict mapping every key in *keywords* (in order) to the extracted string.
    """
    result = {key: "" for key in keywords}

    for key in keywords:
        # Value body = any run of non-quote/non-backslash chars or backslash
        # escapes, so an escaped quote does not end the match early.
        pattern = r'"' + re.escape(key) + r'"\s*:\s*"((?:[^"\\]|\\.)*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if match is None:
            continue
        raw_value = match.group(1)
        try:
            # strict=False tolerates raw newlines/tabs inside the string, a
            # common reason the model output failed to parse in the first place.
            result[key] = json.loads('"' + raw_value + '"', strict=False)
        except ValueError:
            result[key] = raw_value

    return result

def generate_cultivation_with_rag(query_item, search_result):
    """Ask the LLM to rewrite *query_item* as a cultivation-world item.

    The prompt is built with ``data2prompt``, using *search_result* as the
    retrieved references. It is sent through the module-level ``ZhipuClient``,
    which is created on first call and reused afterwards.

    The reply is first parsed with ``markdown_to_json``. If it is not valid
    JSON, or is valid JSON but not an object, the key fields are salvaged
    with ``forced_extract`` instead. Two fields are then back-filled when
    missing, ``None`` or empty:

    - ``new_name``: falls back to ``name_in_cultivation_1``, else ``""``.
    - ``final_enhanced_description``: falls back to
      ``description_in_cultivation_1``, else to ``new_name``.

    Returns:
        dict: the parsed and back-filled model output.
    """
    global zhipu_client
    if zhipu_client is None:
        zhipu_client = ZhipuClient()
    prompt = data2prompt(query_item, search_result)
    response = zhipu_client.prompt2response(prompt)

    try:
        json_response = markdown_to_json(response)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        json_response = None
    if not isinstance(json_response, dict):
        # Malformed or non-object output: pull out just the fields used below.
        keyword_list = ["name_in_life", "name_in_cultivation_1", "description_in_cultivation_1",
                        "final_enhanced_description", "new_name"]
        json_response = forced_extract(response, keyword_list)

    if not json_response.get("new_name"):
        json_response["new_name"] = json_response.get("name_in_cultivation_1", "")

    if not json_response.get("final_enhanced_description"):
        json_response["final_enhanced_description"] = json_response.get(
            "description_in_cultivation_1", json_response["new_name"])

    return json_response

if __name__ == '__main__':
    # Manual smoke test: caption a sample image, retrieve similar items,
    # pick the major object, and print its cultivation-world rewrite.
    import os

    try:
        from src.Database import Database
    except ImportError:
        from Database import Database

    db = Database()

    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner

    # Route HTTP(S) traffic through a local proxy. The port is hard-coded;
    # adjust it for your environment.
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    captioner = Captioner()

    test_image = "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)

    search_result = db.search_with_image_name(test_image)

    # Distinct translated words, in first-seen order.
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))

    try:
        from src.get_major_object import get_major_object, verify_keyword_in_base
    except ImportError:
        from get_major_object import get_major_object, verify_keyword_in_base

    json_response = get_major_object(caption_response, keywords)

    print(json_response)

    print()

    in_base_data, alt_data = verify_keyword_in_base(json_response, db)

    if alt_data is not None:
        result = generate_cultivation_with_rag(alt_data, search_result)
        print(result)