# test_Idiot-Cultivation-System / src/generate_cultivation.py
# hhhwmws's picture
# Upload 19 files
# 0319a9a verified
# raw
# history blame contribute delete
# No virus
# 6.77 kB
import json
def data2reference(top_k_items, output_n=3):
    """Format up to ``output_n`` distinct items as a "#Reference:" prompt section.

    Each kept item is rendered as one JSON object (non-ASCII characters kept
    verbatim) with the keys "name_in_life" (taken from the item's "keyword"),
    "name_in_cultivation" and "description_in_cultivation". Objects are
    separated by a blank line. An item whose "keyword" was already emitted is
    skipped, so the section holds at most ``output_n`` unique keywords.

    Args:
        top_k_items: iterable of dicts with "keyword", "name_in_cultivation" and
            "description_in_cultivation" keys (presumably ranked best-first by
            the retriever -- confirm against the caller).
        output_n: maximum number of distinct items to include. A value <= 0
            produces a header-only section.

    Returns:
        The section text with surrounding whitespace stripped.
    """
    emitted_keywords = set()
    entries = []
    # Guard up front: the limit is otherwise only checked after an append,
    # which would let one item through when output_n <= 0.
    if output_n > 0:
        for item in top_k_items:
            keyword = item["keyword"]
            if keyword in emitted_keywords:
                continue
            entry = {
                "name_in_life": keyword,
                "name_in_cultivation": item["name_in_cultivation"],
                "description_in_cultivation": item["description_in_cultivation"],
            }
            entries.append(json.dumps(entry, ensure_ascii=False))
            emitted_keywords.add(keyword)
            # Stop as soon as the quota is met so no extra item is pulled from
            # a lazy iterable.
            if len(entries) >= output_n:
                break
    return ("#Reference:\n" + "\n\n".join(entries)).strip()
def data2prompt(query_item , top_k_items):
    """Assemble the LLM prompt for rewriting one item into its cultivation-world form.

    The prompt concatenates, in order:
      1. a "#Reference:" section from data2reference (at most 3 references);
      2. a Chinese task instruction asking for the Input item to be rewritten;
      3. a "# Input:" section describing ``query_item``;
      4. a step-by-step instruction listing the JSON fields the model must output.

    Args:
        query_item: the item to rewrite (presumably a dict -- confirm with callers).
            If it contains "keyword", that value (and "description", when
            present) is written as labelled lines. Otherwise the whole object is
            dumped as JSON.
        top_k_items: retrieved reference items, forwarded to data2reference.

    Returns:
        The complete prompt string.
    """
    sections = [data2reference(top_k_items, 3)]
    sections.append("\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n")

    input_parts = ["# Input:\n"]
    if "keyword" not in query_item:
        # No structured name to pick out: hand the model the raw JSON instead.
        input_parts.append(json.dumps(query_item, ensure_ascii=False) + "\n")
    else:
        input_parts.append(f"input_name:{query_item['keyword']}\n")
        if "description" in query_item:
            input_parts.append(f"description_in_life:{query_item['description']}\n")
    sections.append("".join(input_parts))

    # Field-by-field reasoning schema the model is asked to fill in as JSON.
    reasoning_instructions = """Let's think it step by step,以json形式输出逐个字段。包含以下字段
- name_in_life: 进一步明确要生成描述的物品名称
- name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
- description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
- echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
- critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
- echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
- analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
- echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
- candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
- analysis_candidates: 分析各个candidates有什么优点
- echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
- final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
- echo_5: "我将分析根据final_description,是否简易将物品名称替换为新的名词"
- name_fit_analysis: 分析item_name是否还匹配final_description的描述,是否需要给input_name起一个更响亮的名字
- new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
"""
    sections.append(reasoning_instructions)
    return "".join(sections)
# Try the package-qualified module path first, then the flat module name.
# Only ImportError triggers the fallback, so real errors raised while loading
# ZhipuClient (or Ctrl-C) are not silently swallowed.
try:
    from src.ZhipuClient import ZhipuClient
except ImportError:
    from ZhipuClient import ZhipuClient

# Shared client, created lazily on first use by generate_cultivation_with_rag.
zhipu_client = None
import json
def markdown_to_json(markdown_str):
    """Parse an LLM reply that may wrap its JSON in a Markdown code fence.

    Surrounding whitespace is ignored. If the text starts with a ``` fence,
    the fence, an optional "json" language tag (any case), and everything from
    the last closing ``` onward are removed before parsing. Text that does not
    start with a fence is parsed as-is.

    Args:
        markdown_str: the raw reply string.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: if the remaining text is not valid JSON.
    """
    text = markdown_str.strip()
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
        # Cut at the last closing fence instead of blindly dropping the final
        # three characters: a trailing newline, trailing prose, or a missing
        # closing fence would otherwise corrupt the JSON body.
        closing = text.rfind("```")
        if closing != -1:
            text = text[:closing]
        text = text.strip()
    return json.loads(text)
import re
def forced_extract(input_str, keywords):
    """Best-effort extraction of string fields from malformed JSON-like LLM output.

    For each key in ``keywords``, the first ``"key": "value"`` pair in
    ``input_str`` is captured. Whitespace is allowed around the colon, escaped
    quotes inside the value are honoured, and the value may span lines. JSON
    escape sequences in the captured value are decoded. If the captured text is
    not a valid JSON string body, it is returned verbatim.

    Args:
        input_str: raw model output.
        keywords: field names to look for.

    Returns:
        dict mapping every key in ``keywords`` to the extracted string, or ""
        when no string value for that key was found.
    """
    result = {}
    for key in keywords:
        # re.escape makes keys match literally. The value group accepts runs of
        # non-quote/non-backslash characters or backslash escapes, so an escaped
        # quote does not end the value early.
        pattern = rf'"{re.escape(key)}"\s*:\s*"((?:[^"\\]|\\.)*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if match is None:
            result[key] = ""
            continue
        raw_value = match.group(1)
        try:
            # Safe to re-quote: the pattern guarantees no unescaped quote inside.
            # strict=False tolerates raw newlines/tabs that LLMs often emit.
            result[key] = json.loads(f'"{raw_value}"', strict=False)
        except json.JSONDecodeError:
            result[key] = raw_value
    return result
def _is_blank(value):
    """Return True for values treated as a missing field: None or a whitespace-only string."""
    return value is None or (isinstance(value, str) and not value.strip())


def _parse_llm_response(response):
    """Parse the model reply into a dict, falling back to regex extraction of key fields."""
    try:
        parsed = markdown_to_json(response)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError: the reply was not clean JSON.
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    # Either unparseable or valid JSON of the wrong shape (list, string, ...):
    # recover the fields the caller needs directly from the text.
    keyword_list = ["name_in_life", "name_in_cultivation_1", "description_in_cultivation_1",
                    "final_enhanced_description", "new_name"]
    return forced_extract(response, keyword_list)


def generate_cultivation_with_rag(query_item, search_result):
    """Generate a cultivation-world name and description for ``query_item`` via the LLM.

    Builds a RAG prompt from ``query_item`` and ``search_result`` (see
    data2prompt), sends it through the module-level ZhipuClient (created on
    first call), and parses the reply into a dict.

    The returned dict always contains:
      * "new_name": the model's value if non-blank, else
        "name_in_cultivation_1" if non-blank, else "".
      * "final_enhanced_description": the model's value if non-blank, else
        "description_in_cultivation_1" if non-blank, else "new_name".

    Args:
        query_item: the item to rewrite, forwarded to data2prompt.
        search_result: retrieved reference items, forwarded to data2prompt.

    Returns:
        dict of fields parsed from the model reply, with the fallbacks above applied.
    """
    global zhipu_client
    if zhipu_client is None:
        zhipu_client = ZhipuClient()
    prompt = data2prompt(query_item, search_result)
    response = zhipu_client.prompt2response(prompt)
    json_response = _parse_llm_response(response)

    if _is_blank(json_response.get("new_name")):
        fallback_name = json_response.get("name_in_cultivation_1")
        json_response["new_name"] = "" if _is_blank(fallback_name) else fallback_name

    if _is_blank(json_response.get("final_enhanced_description")):
        fallback_description = json_response.get("description_in_cultivation_1")
        # An empty first draft is no better than nothing: fall through to the name.
        json_response["final_enhanced_description"] = (
            json_response["new_name"] if _is_blank(fallback_description) else fallback_description
        )
    return json_response
if __name__ == '__main__':
    # Manual end-to-end run on one sample image: caption it, search the database
    # with it, pick the major object, and generate a cultivation-world rewrite
    # when verify_keyword_in_base returns alternative data.
    # Project modules are tried package-qualified first, then as flat modules.
    # Only ImportError triggers the fallback, so genuine errors are not hidden.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database
    db = Database()
    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner
    import os
    # Point HTTP(S)_PROXY at a local proxy (presumably needed by the API
    # clients). setdefault lets an already-configured environment take precedence.
    os.environ.setdefault('HTTP_PROXY', 'http://localhost:8234')
    os.environ.setdefault('HTTPS_PROXY', 'http://localhost:8234')
    captioner = Captioner()
    test_image = "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)
    search_result = db.search_with_image_name(test_image)
    # Unique translated words in first-seen order (dicts preserve insertion order).
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))
    try:
        from src.get_major_object import get_major_object, verify_keyword_in_base
    except ImportError:
        from get_major_object import get_major_object, verify_keyword_in_base
    json_response = get_major_object(caption_response, keywords)
    print(json_response)
    print()
    _in_base_data, alt_data = verify_keyword_in_base(json_response, db)
    if alt_data is not None:
        result = generate_cultivation_with_rag(alt_data, search_result)
        print(result)