hhhwmws committed on
Commit
0319a9a
1 Parent(s): f495e87

Upload 19 files

src/CLIPExtractor.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import torch
+ from transformers import CLIPProcessor, CLIPModel
+ import cv2
+ from PIL import Image
+ import numpy as np
+
+
+ class CLIPExtractor:
+     def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
+
+         # Set proxy environment variables
+         os.environ['HTTP_PROXY'] = 'http://localhost:8234'
+         os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
+
+         # Set Hugging Face environment variables
+         os.environ["HF_ENDPOINT"] = "https://hf-api.gitee.com"
+         os.environ["HF_HOME"] = os.path.expanduser("models/")
+
+         if not cache_dir:
+             # Default cache directory
+             cache_dir = "models"
+             if not os.path.exists(cache_dir) and os.path.exists("../models"):
+                 cache_dir = "../models"
+
+         # Initialize the model and processor with the specified values
+         self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
+         self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+
+     def extract_image(self, frame, if_opencv=True):
+         # Convert an OpenCV BGR frame to a PIL Image; callers that already hold a
+         # PIL image (e.g. ImageBase.search_with_image) pass if_opencv=False.
+         if if_opencv:
+             image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+         else:
+             image = frame
+         images = [image]
+
+         # Process the image and extract features
+         inputs = self.processor(images=images, return_tensors="pt").to(self.device)
+         with torch.no_grad():
+             outputs = self.model.get_image_features(**inputs)
+
+         ans = outputs.cpu().numpy()
+         return ans[0]
+
+     def extract_image_from_file(self, file_name):
+         if not os.path.exists(file_name):
+             raise FileNotFoundError(f"File {file_name} not found.")
+
+         images = [Image.open(file_name).convert("RGB")]
+
+         # Process the image and extract features
+         inputs = self.processor(images=images, return_tensors="pt").to(self.device)
+         with torch.no_grad():
+             outputs = self.model.get_image_features(**inputs)
+
+         ans = outputs.cpu().numpy()
+         return ans[0]
+
+     def extract_text(self, text):
+         if not isinstance(text, str) or not text:
+             raise ValueError("Input text should be a non-empty string.")
+
+         # Tokenize the text
+         inputs = self.processor.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=77).to(self.device)
+
+         # Process the text and extract features
+         # inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model.get_text_features(**inputs)
+
+         ans = outputs.cpu().numpy()
+         return ans[0]
+
+
+ if __name__ == "__main__":
+
+     clip_extractor = CLIPExtractor()
+
+     sample_image = "images/狐狸.jpg"
+     # Extract image features
+     image_feature = clip_extractor.extract_image_from_file(sample_image)
+
+     # Extract text features
+     sample_text = "A photo of fox"
+     text_feature = clip_extractor.extract_text(sample_text)
+
+     # Cosine similarity
+     cosine_similarity = np.dot(image_feature, text_feature) / (np.linalg.norm(image_feature) * np.linalg.norm(text_feature))
+     print(cosine_similarity)
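
A minimal zero-shot check built on the extractor above — not part of the commit; the image path and label list are placeholders:

import numpy as np
from CLIPExtractor import CLIPExtractor

extractor = CLIPExtractor()
img = extractor.extract_image_from_file("images/狐狸.jpg")  # hypothetical local file
labels = ["A photo of fox", "A photo of dog", "A photo of piano"]
feats = np.stack([extractor.extract_text(t) for t in labels])

# Normalize, then rank labels by cosine similarity to the image feature
feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
img = img / np.linalg.norm(img)
print(labels[int(np.argmax(feats @ img))])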
src/Captioner.py ADDED
@@ -0,0 +1,148 @@
+ from PIL import Image
+ import base64
+ from io import BytesIO
+ import os
+ from openai import OpenAI
+ import json
+
+ class Captioner:
+     def __init__(self, api_key_path=None, proxy=None, api_base="https://api.lingyiwanwu.com/v1"):
+
+         if api_key_path is None:
+             # Try datas/01_key.txt and ../datas/01_key.txt
+             cand_paths = ['datas/01_key.txt', '../datas/01_key.txt']
+             flag = False
+             for path in cand_paths:
+                 if os.path.exists(path):
+                     api_key_path = path
+                     flag = True
+                     break
+
+             if not flag:
+                 raise ValueError("Please provide the path to the API key file.")
+
+         self.api_key = self.load_access_token(api_key_path)
+         self.api_base = api_base
+         if proxy:
+             os.environ['HTTP_PROXY'] = proxy
+             os.environ['HTTPS_PROXY'] = proxy
+         self.client = OpenAI(
+             api_key=self.api_key,
+             base_url=self.api_base
+         )
+
+         self.history = {}
+         self.history_file = None
+
+         self.load_history()
+
+     def load_access_token(self, file_path):
+         with open(file_path, 'r') as file:
+             return file.read().strip()
+
+     def image2base64(self, image_path):
+         # Open the image
+         with Image.open(image_path) as img:
+             # If the image is taller than 480 px, scale it down
+             if img.height > 480:
+                 # Compute the new width so the aspect ratio is preserved
+                 aspect_ratio = img.width / img.height
+                 new_height = 480
+                 new_width = int(new_height * aspect_ratio)
+                 img = img.resize((new_width, new_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
+
+             # Save the resized image to an in-memory buffer
+             buffered = BytesIO()
+             img.save(buffered, format="JPEG")
+             buffered.seek(0)
+
+             # Encode the image as a Base64 data URL
+             img_base64 = "data:image/jpeg;base64," + base64.b64encode(buffered.read()).decode('utf-8')
+
+         return img_base64
+
+     def load_history(self, jsonl_file_name=None):
+         if jsonl_file_name is None:
+             jsonl_file_name = "datas/caption_history.jsonl"
+
+         self.history_file = jsonl_file_name
+
+         if os.path.exists(jsonl_file_name):
+             with open(jsonl_file_name, 'r', encoding='utf-8') as f:
+                 for line in f:
+                     data = json.loads(line)
+                     self.history[data['file_name']] = data['response']
+
+     def search_from_history(self, file_name):
+         return self.history.get(file_name, None)
+
+     def save_history(self, jsonl_file_name=None):
+         if jsonl_file_name is None:
+             jsonl_file_name = self.history_file
+
+         if jsonl_file_name:
+             with open(jsonl_file_name, 'w', encoding='utf-8') as f:
+                 for file_name, response in self.history.items():
+                     json.dump({'file_name': file_name, 'response': response}, f, ensure_ascii=False)
+                     f.write('\n')
+
+             # print(f"History saved to {jsonl_file_name}")
+
+     def add_to_history(self, file_name, response):
+         self.history[file_name] = response
+
+     def caption(self, image_name):
+
+         # Check if the caption is already in the history
+         cached_response = self.search_from_history(image_name)
+         if cached_response:
+             # print("return the cache")
+             return cached_response
+
+         prompt = """Analyze the image and output in JSON format, including the following fields:
+ - "detailed_description": A detailed description of the image content.
+ - "major_object": Determine the main object/scene in the image based on the description, output with a simple word
+ - "Chinese_name": Give the Chinese name of the main object in the image
+ - "real_or_composite": Determine whether this image was taken with a camera or created/modified by a computer, output with real or composite."""
+
+         img_base64 = self.image2base64(image_name)
+
+         completion = self.client.chat.completions.create(
+             model="yi-vision",
+             messages=[
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": prompt
+                         },
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": img_base64
+                             }
+                         }
+                     ]
+                 }
+             ],
+             stream=False
+         )
+
+         response = completion.choices[0].message.content
+
+         # Add the new response to history
+         self.add_to_history(image_name, response)
+         # Save history after adding the new entry
+         self.save_history()
+
+         return response
+
+ if __name__ == "__main__":
+     import os
+     os.environ['HTTP_PROXY'] = 'http://localhost:8234'
+     os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
+     captioner = Captioner()
+     test_image = "temp_images/3zjz9b3l.jpg"
+     print(captioner.caption(test_image))
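
caption() returns the model's raw text, which downstream code treats as JSON with the four fields fixed by the prompt above. A minimal sketch of that parse, assuming the reply may arrive wrapped in a Markdown fence (mirroring markdown_to_json in get_major_object.py below):

import json

def parse_caption(raw):
    # Strip a Markdown code fence if the model wrapped its JSON in one
    if raw.startswith("```json"):
        raw = raw[7:-3].strip()
    elif raw.startswith("```"):
        raw = raw[3:-3].strip()
    return json.loads(raw)

caption = parse_caption(Captioner().caption("temp_images/3zjz9b3l.jpg"))
print(caption["major_object"], caption["real_or_composite"])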
src/Database.py ADDED
@@ -0,0 +1,278 @@
+ import pandas as pd
+ import os
+ from tqdm import tqdm
+
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+
+
+ class Database:
+     def __init__(self, parquet_path=None, customized_parquets=None):
+         self.default_parquet_path = 'datas/database_4000.parquet'
+         self.parquet_path = parquet_path or self.default_parquet_path
+
+         self.default_customized_parquets = ["datas/customized_database_0.parquet"]
+         self.customized_parquets = customized_parquets or self.default_customized_parquets
+
+         self.datas = None
+         self.last_save_table = None
+
+         if os.path.exists(self.parquet_path):
+             self.load_from_parquet(self.parquet_path)
+
+         self.load_from_customized(self.customized_parquets)
+
+         self.clip_extractor = None
+         self.bge_extractor = None
+
+         self.en_keyword2data = {}
+
+     def build_en_keyword2index(self):
+         # Build the index in lower case
+         self.en_keyword2data = {row['translated_word'].lower(): row for i, row in self.datas.iterrows()}
+
+     def search_by_en_keyword(self, keyword):
+         if len(self.en_keyword2data) == 0:
+             self.build_en_keyword2index()
+
+         keyword = keyword.lower()
+         if keyword in self.en_keyword2data:
+             ans = self.en_keyword2data[keyword].to_dict()
+             del ans["clip_feature"]
+             del ans["bge_feature"]
+             return ans
+         else:
+             return None
+
+     def load_from_parquet(self, parquet_path):
+         self.datas = pd.read_parquet(parquet_path)
+
+     def load_from_customized(self, customized_parquets=None):
+         customized_parquets = customized_parquets or self.customized_parquets
+
+         # Load each parquet file and concatenate them into the self.datas DataFrame
+         for index, parquet_file in enumerate(customized_parquets):
+             if os.path.exists(parquet_file):
+                 temp_df = pd.read_parquet(parquet_file)
+                 if self.datas is None:
+                     self.datas = temp_df
+                 else:
+                     self.datas = pd.concat([self.datas, temp_df], ignore_index=True)
+
+                 # Record the last parquet file's contents as self.last_save_table
+                 if index == len(customized_parquets) - 1:
+                     self.last_save_table = temp_df
+
+     def add_data(self, data, if_save=True):
+         required_columns = ['keyword', 'name_in_cultivation', 'description_in_cultivation', 'translated_word', 'description']
+         for column in required_columns:
+             if column not in data:
+                 raise ValueError(f"Missing required field: {column}")
+
+         # Optional field
+         if 'founder' not in data:
+             data['founder'] = ""
+
+         # Extract features
+         if self.clip_extractor is None:
+             self.init_clip_extractor()
+         if self.bge_extractor is None:
+             self.init_bge_extractor()
+
+         data['clip_feature'] = self.clip_extractor.extract_text(data['translated_word'] + '.' + data['description'])
+         data['bge_feature'] = self.bge_extractor.extract([data['keyword']])[0].tolist()
+
+         # Convert to DataFrame and add to self.datas
+         data_df = pd.DataFrame([data])
+         if self.datas is None:
+             self.datas = data_df
+         else:
+             self.datas = pd.concat([self.datas, data_df], ignore_index=True)
+
+         # Point the keyword index at the last row of self.datas
+         self.en_keyword2data[data['translated_word'].lower()] = self.datas.iloc[-1]
+
+         # Add to last_save_table
+         if self.last_save_table is None:
+             # self.last_save_table = data_df
+             # Create a new DataFrame with the same columns as self.datas
+             self.last_save_table = pd.DataFrame(columns=self.datas.columns)
+
+         self.last_save_table = pd.concat([self.last_save_table, data_df], ignore_index=True)
+
+         if if_save:
+             self.save_to_parquet(self.customized_parquets[-1], self.last_save_table)
+
+     def add_datas(self, datas, if_save=True):
+         for data in datas:
+             self.add_data(data, if_save=False)
+         if if_save:
+             self.save_to_parquet(self.customized_parquets[-1], self.last_save_table)
+
+     def init_from_excel(self, excel_path):
+         df = pd.read_excel(excel_path)
+
+         # Drop rows with any empty cell in the required columns
+         df.dropna(subset=['keyword', 'name_in_cultivation', 'description_in_cultivation', 'translated_word', 'description'], inplace=True)
+
+         # Add the new columns
+         df['clip_feature'] = None
+         df['bge_feature'] = None
+
+         self.datas = df
+
+         self.extract_clip()
+         self.extract_bge()
+
+     def save_to_parquet(self, parquet_path=None, df=None):
+         parquet_path = parquet_path or self.default_parquet_path
+         if df is None:
+             if self.datas is not None:
+                 self.datas.to_parquet(parquet_path)
+         else:
+             df.to_parquet(parquet_path)
+
+     def init_clip_extractor(self):
+         if self.clip_extractor is None:
+             try:
+                 from CLIPExtractor import CLIPExtractor
+             except ImportError:
+                 from src.CLIPExtractor import CLIPExtractor
+
+             cache_dir = "models"
+
+             self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)
+
+     def extract_clip(self):
+         if self.clip_extractor is None:
+             self.init_clip_extractor()
+
+         clip_features = []
+         # for text in tqdm(self.datas['keyword'], desc='Extracting CLIP features'):
+         for index, row in tqdm(self.datas.iterrows(), desc='Extracting CLIP features', total=len(self.datas)):
+             text = row['translated_word'] + '.' + row['description']
+             if text:
+                 feature = self.clip_extractor.extract_text(text)
+             else:
+                 feature = None
+             clip_features.append(feature)
+
+         self.datas['clip_feature'] = clip_features
+
+     def init_bge_extractor(self):
+         if self.bge_extractor is None:
+             try:
+                 from text_embedding import TextExtractor
+             except ImportError:
+                 from src.text_embedding import TextExtractor
+
+             self.bge_extractor = TextExtractor('BAAI/bge-small-zh-v1.5')
+
+     def top_k_search(self, query_feature, attribute, top_k=15):
+         # Ensure the attribute exists in the dataframe
+         if attribute not in self.datas.columns:
+             raise ValueError(f"Attribute {attribute} not found in the data.")
+
+         # Convert the query feature and attribute features to numpy arrays
+         # (assumes no missing features, so dropna() keeps row order aligned with self.datas)
+         query_feature = np.array(query_feature).reshape(1, -1)
+         attribute_features = np.stack(self.datas[attribute].dropna().values)
+
+         # Compute cosine similarity between the query and all attributes
+         similarities = cosine_similarity(query_feature, attribute_features)[0]
+
+         # Get the top_k indices based on similarity
+         top_k_indices = np.argsort(similarities)[-top_k:][::-1]
+
+         # Retrieve the top_k most similar items
+         top_k_results = self.datas.iloc[top_k_indices].copy()
+
+         top_k_results = top_k_results.drop(columns=['clip_feature', 'bge_feature'])
+
+         top_k_results['similarity'] = similarities[top_k_indices]
+
+         return top_k_results.to_dict(orient='records')
+
+     def search_with_image_name(self, image_name):
+         self.init_clip_extractor()
+
+         img_feature = self.clip_extractor.extract_image_from_file(image_name)
+
+         return self.top_k_search(img_feature, 'clip_feature')
+
+     def search_with_image(self, image, if_opencv=False):
+         if self.clip_extractor is None:
+             self.init_clip_extractor()
+
+         img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
+
+         return self.top_k_search(img_feature, 'clip_feature')
+
+     def search_with_chinese(self, text):
+         if self.bge_extractor is None:
+             self.init_bge_extractor()
+
+         text_feature = self.bge_extractor.extract([text])[0].tolist()
+
+         return self.top_k_search(text_feature, 'bge_feature')
+
+     def extract_bge(self):
+         if self.bge_extractor is None:
+             self.init_bge_extractor()
+
+         # Extract features for each row and store them in the bge_feature column
+         bge_features = []
+         for text in tqdm(self.datas['keyword'], desc='Extracting BGE features'):
+             if text:
+                 feature = self.bge_extractor.extract([text])[0].tolist()
+             else:
+                 feature = None
+             bge_features.append(feature)
+
+         self.datas['bge_feature'] = bge_features
+
+ if __name__ == '__main__':
+     # Usage example
+     db = Database()
+     re_extract = False
+     if db.datas is None or re_extract:
+         print("Rebuilding database from excel file")
+         db.init_from_excel('datas/database_4000.xlsx')
+         db.save_to_parquet()
+
+     # print(db.datas[0].keys())
+
+     query_text = "钢琴"
+
+     results = db.search_with_chinese(query_text)
+
+     print(results[0].keys())
+
+     for result in results[:3]:
+         print(result)
+
+     image_path = "datas/老虎.jpg"
+
+     results = db.search_with_image_name(image_path)
+
+     for result in results[:3]:
+         print(result)
+     # Example hit: 'keyword': '老虎狗', 'name_in_cultivation': '灵虎犬神', 'description_in_cultivation': '在九天灵脉汇聚的仙山之巅,灵虎犬神身披星图斑纹,汲取日月精华,以雷霆之力守护仙脉,其双眼中映照着轮回之道,是修仙者追寻天地真理的指引,也是象征极致灵性的神秘灵兽。', 'translated_word': 'Tiger Dog', 'description': 'A Tiger Dog is a term that might refer to a mythical creature or a breed of dog with a distinctive and unusual appearance, resembling the features of a tiger. It could be characterized by its striking coat with patterns similar to those of a tiger, or by having a demeanor that is fierce and majestic like a tiger. This term is not commonly used in conventional contexts and might be found in stories, folktales, or in the names of unique dog breeds that have been bred to exhibit such features.', 'founder': ''
+
+     # test_new_data = {
+     #     "keyword": "老虎狗2",
+     #     "name_in_cultivation": "灵虎犬神",
+     #     "description_in_cultivation": "在九天灵脉汇聚的仙山之巅,灵虎犬神身披星图斑纹,汲取日月精华,以雷霆之力守护仙脉,其双眼中映照着轮回之道,是修仙者追寻天地真理的指引,也是象征极致灵性的神秘灵兽。",
+     #     "translated_word": "Tiger Dog",
+     #     "description": "A Tiger Dog is a term that might refer to a mythical creature or a breed of dog with a distinctive and unusual appearance, resembling the features of a tiger. It could be characterized by its striking coat with patterns similar to those of a tiger, or by having a demeanor that is fierce and majestic like a tiger. This term is not commonly used in conventional contexts and might be found in stories, folktales, or in the names of unique dog breeds that have been bred to exhibit such features."
+     # }
+
+     # db.add_data(test_new_data)
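
A minimal sketch of the record add_data expects — not part of the commit; all field values below are placeholders:

from Database import Database  # or src.Database, matching the import fallbacks used above

db = Database()
db.add_data({
    "keyword": "钢琴",                                 # Chinese name
    "name_in_cultivation": "九霄鸣音台",               # placeholder in-game name
    "description_in_cultivation": "placeholder text",  # placeholder in-game description
    "translated_word": "piano",                        # English name; lowercased lookup key
    "description": "A large keyboard instrument.",     # English description; feeds the CLIP text feature
})
print(db.search_by_en_keyword("PIANO"))  # lookup is case-insensitive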
src/Founder.py ADDED
@@ -0,0 +1,69 @@
+ import json
+ from collections import defaultdict
+
+ class Founder:
+     def __init__(self, filepath='datas/founder.jsonl'):
+         self.filepath = filepath
+         self.datas = {}
+         self.founder2items = defaultdict(list)
+
+         try:
+             self.load_founder()
+         except FileNotFoundError:
+             self.datas = {}
+
+         # Initialize the reverse mapping
+         for word, founder in self.datas.items():
+             self.founder2items[founder].append(word)
+
+     def load_founder(self):
+         """Load founder data from a jsonl file."""
+         with open(self.filepath, 'r', encoding='utf-8') as file:
+             for line in file:
+                 data = json.loads(line.strip())
+                 self.datas.update(data)
+
+     def save_founder(self):
+         """Save founder data to a jsonl file."""
+         with open(self.filepath, 'w', encoding='utf-8') as file:
+             for word, founder in self.datas.items():
+                 file.write(json.dumps({word: founder}, ensure_ascii=False) + '\n')
+
+     def get_founder(self, word):
+         """Get the founder of a given word."""
+         return self.datas.get(word, None)
+
+     def set_founder(self, word, founder, enforce=False):
+         """Set the founder of a word if it's not already set or if enforce is True."""
+         if word in self.datas and not enforce:
+             print(f"Warning: {word} already has a founder: {self.datas[word]}. Use enforce=True to override.")
+         else:
+             self.datas[word] = founder
+             self.founder2items[founder].append(word)
+             self.save_founder()
+
+     def get_all_items_from_founder(self, founder):
+         """Get all words discovered by a specific founder."""
+         return self.founder2items.get(founder, [])
+
+     def get_top_rank(self, top_k=20):
+         """Get the top_k founders with the most discovered words."""
+         sorted_founders = sorted(self.founder2items.items(), key=lambda x: len(x[1]), reverse=True)
+         return sorted_founders[:top_k]
+
+ # Example usage:
+ # founder = Founder()
+ # founder.set_founder('apple', 'Alice')
+ # founder.set_founder('banana', 'Bob')
+ # print(founder.get_founder('apple'))
+ # print(founder.get_all_items_from_founder('Alice'))
+ # print(founder.get_top_rank())
+
+ if __name__ == '__main__':
+     founder = Founder()
+     founder.set_founder('test_apple', '鲁鲁道祖')
+     founder.set_founder('test_banana', '鲁鲁道祖')
+     founder.set_founder('test_orange', "文钊道祖")
+     print(founder.get_founder('test_apple'))
+     print(founder.get_all_items_from_founder('鲁鲁道祖'))
+     print(founder.get_top_rank())
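
save_founder writes one {word: founder} object per line, so after the run above datas/founder.jsonl ends with lines like:

{"test_apple": "鲁鲁道祖"}
{"test_banana": "鲁鲁道祖"}
{"test_orange": "文钊道祖"}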
src/GameMaster.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ from glob import glob
+
+ try:
+     from src.Database import Database
+     from src.Captioner import Captioner
+     from src.ImageBase import Imagebase
+     from src.get_major_object import get_major_object, verify_keyword_in_base
+     from src.generate_cultivation import generate_cultivation_with_rag
+ except ImportError:
+     from Database import Database
+     from Captioner import Captioner
+     from ImageBase import Imagebase
+     from get_major_object import get_major_object, verify_keyword_in_base
+     from generate_cultivation import generate_cultivation_with_rag
+
+
+ class GameMaster:
+     def __init__(self):
+         self.textdb = self.init_textdb()
+
+         self.clip_extractor = self.textdb.clip_extractor
+
+         self.imgdb = self.init_imgdb()
+
+         self.captioner = Captioner()
+
+         self.minimal_image_threshold = 0.9
+
+     def init_textdb(self):
+         text_db = Database()
+         text_db.init_bge_extractor()
+         text_db.init_clip_extractor()
+         return text_db
+
+     def init_imgdb(self):
+         img_db = Imagebase()
+         return img_db
+
+     def random_image_text_data(self, n=12):
+         random_img_datas = self.imgdb.random_sample(n)
+         # Keep image_name and keywords only
+         image_names = [img_data['image_name'] for img_data in random_img_datas]
+         blank_image_path = "datas/blank_item.jpg"
+         for i in range(len(image_names)):
+             if not os.path.exists(image_names[i]):
+                 image_names[i] = blank_image_path
+
+         keywords_zh = [img_data['keyword'] for img_data in random_img_datas]
+         keywords = [img_data['translated_word'] for img_data in random_img_datas]
+         descriptions = []
+
+         for keyword, keyword_zh in zip(keywords, keywords_zh):
+             result = self.textdb.search_by_en_keyword(keyword)
+             if result and "description_in_cultivation" in result:
+                 description = result['description_in_cultivation']
+                 if "name_in_cultivation" in result:
+                     description = result['name_in_cultivation'] + "--" + description
+                 descriptions.append(description)
+             else:
+                 descriptions.append("")
+
+         # Return (image path, description) pairs
+         return zip(image_names, descriptions)
+
+     def search_with_path(self, image_path, threshold=None):
+         # This is a relatively lightweight search
+         image_feature = self.clip_extractor.extract_image_from_file(image_path)
+
+         # image_search_result = img_db.search_with_image_name(image_path)
+         image_search_result = self.imgdb.top_k_search(image_feature, top_k=1)
+
+         search_result = None
+
+         if threshold is None:
+             threshold = self.minimal_image_threshold
+
+         if image_search_result and len(image_search_result) > 0 and image_search_result[0]['similarity'] > threshold:
+
+             # Try to find the entry via its translated_word
+             result = self.textdb.search_by_en_keyword(image_search_result[0]['translated_word'])
+             if result and "name_in_cultivation" in result:
+                 search_result = result
+                 search_result['similarity'] = image_search_result[0]['similarity']
+             else:
+                 print("Warning! Unfound keyword: ", image_search_result[0]['translated_word'])
+
+         # Always also search the text database as a backup
+         backup_results = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)
+
+         return search_result, backup_results, image_feature
+
+     def generate_cultivation_data(self, image_path, image_feature, text_search_result):
+         # This is very expensive
+
+         cultivation_data = None
+
+         try:
+             caption_response = self.captioner.caption(image_path)
+         except Exception:
+             print("Error occurred while captioning the image ", image_path)
+             return cultivation_data
+
+         if text_search_result is None:
+             # Complete the text search
+             text_search_result = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)
+
+         seen = set()
+         keywords = [res['translated_word'] for res in text_search_result if not (res['translated_word'] in seen or seen.add(res['translated_word']))]
+
+         try:
+             json_response = get_major_object(caption_response, keywords)
+         except Exception:
+             print("Error occurred while getting major object from caption ", caption_response)
+             return cultivation_data
+
+         in_base_data, alt_data = verify_keyword_in_base(json_response, self.textdb)
+
+         if in_base_data is not None:
+             cultivation_data = in_base_data
+
+             # A new image of an item already in the base was found; no new entry needs to be generated
+             # required_fields = ['image_name', 'keyword', 'translated_word']
+             image_data = {
+                 'image_name': image_path,
+                 'keyword': in_base_data['keyword'],
+                 'translated_word': in_base_data['translated_word']
+             }
+             self.imgdb.add_image(image_data, True, image_feature)
+         elif alt_data is not None:
+             try:
+                 cultivation_data = generate_cultivation_with_rag(alt_data, text_search_result)
+             except Exception:
+                 print("Error occurred while generating cultivation data")
+                 return cultivation_data
+
+             new_data = {
+                 "keyword": alt_data['keyword'],
+                 "name_in_cultivation": cultivation_data['new_name'],
+                 "description_in_cultivation": cultivation_data['final_enhanced_description'],
+                 "translated_word": alt_data['translated_word'],
+                 "description": alt_data['description']
+             }
+             self.textdb.add_data(new_data)
+             print("Added new data to textdb: ", new_data["name_in_cultivation"])
+
+             image_data = {
+                 'image_name': image_path,
+                 'keyword': new_data['keyword'],
+                 'translated_word': new_data['translated_word']
+             }
+             self.imgdb.add_image(image_data, True, image_feature)
+             print("Added new image to imgdb: ", image_data["keyword"])
+
+             cultivation_data = new_data
+
+         return cultivation_data
+
+
+ if __name__ == "__main__":
+     os.environ['HTTP_PROXY'] = 'http://localhost:8234'
+     os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
+
+     game_master = GameMaster()
+
+     target_folder = "temp_images"
+
+     image_files = glob(os.path.join(target_folder, "*.jpg"))
+
+     for index, image_path in enumerate(image_files):
+         print("index:", index)
+
+         search_result, backup_results, image_feature = game_master.search_with_path(image_path)
+
+         if search_result:
+             print(search_result)
+
+         break
+
+     test_image_path = "temp_images/向日葵.jpg"
+
+     search_result, backup_results, image_feature = game_master.search_with_path(test_image_path)
+     cultivation_data = game_master.generate_cultivation_data(
+         test_image_path, image_feature, backup_results)
+     print(cultivation_data)
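
The intended two-stage flow, as a minimal sketch — the image path is a placeholder:

gm = GameMaster()
hit, backups, feat = gm.search_with_path("temp_images/example.jpg")  # hypothetical file
if hit is not None:
    entry = hit  # a near-duplicate image is already in the image base
else:
    # Cache miss: caption the image, resolve its keyword, then reuse or generate an entry
    entry = gm.generate_cultivation_data("temp_images/example.jpg", feat, backups)
print(entry)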
src/ImageBase.py ADDED
@@ -0,0 +1,162 @@
+ import pandas as pd
+ import os
+ from tqdm import tqdm
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ class Imagebase:
+     def __init__(self, parquet_path=None):
+         self.default_parquet_path = 'datas/imagebase.parquet'
+         self.parquet_path = parquet_path or self.default_parquet_path
+         self.datas = None
+
+         if os.path.exists(self.parquet_path):
+             self.load_from_parquet(self.parquet_path)
+
+         self.clip_extractor = None
+
+     def random_sample(self, num_samples=12):
+         if self.datas is not None:
+             return self.datas.sample(num_samples).to_dict(orient='records')
+         else:
+             return []
+
+     def load_from_parquet(self, parquet_path):
+         self.datas = pd.read_parquet(parquet_path)
+
+     def save_to_parquet(self, parquet_path=None):
+         parquet_path = parquet_path or self.default_parquet_path
+         if self.datas is not None:
+             self.datas.to_parquet(parquet_path)
+
+     def init_clip_extractor(self):
+         if self.clip_extractor is None:
+             try:
+                 from CLIPExtractor import CLIPExtractor
+             except ImportError:
+                 from src.CLIPExtractor import CLIPExtractor
+
+             cache_dir = "D:\\aistudio\\LubaoGithub\\models"  # machine-specific cache path
+             self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)
+
+     def top_k_search(self, query_feature, top_k=15):
+         if self.datas is None:
+             return []
+         if 'clip_feature' not in self.datas.columns:
+             raise ValueError("clip_feature column not found in the data.")
+
+         query_feature = np.array(query_feature).reshape(1, -1)
+         attribute_features = np.stack(self.datas['clip_feature'].dropna().values)
+
+         similarities = cosine_similarity(query_feature, attribute_features)[0]
+
+         top_k_indices = np.argsort(similarities)[-top_k:][::-1]
+
+         top_k_results = self.datas.iloc[top_k_indices].copy()
+
+         top_k_results['similarity'] = similarities[top_k_indices]
+
+         # Drop the 'clip_feature' column
+         top_k_results = top_k_results.drop(columns=['clip_feature'])
+
+         return top_k_results.to_dict(orient='records')
+
+     def search_with_image_name(self, image_name):
+         self.init_clip_extractor()
+
+         img_feature = self.clip_extractor.extract_image_from_file(image_name)
+
+         return self.top_k_search(img_feature)
+
+     def search_with_image(self, image, if_opencv=False):
+         self.init_clip_extractor()
+
+         img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
+
+         return self.top_k_search(img_feature)
+
+     def add_image(self, data, if_save=True, image_feature=None):
+         required_fields = ['image_name', 'keyword', 'translated_word']
+         if not all(field in data for field in required_fields):
+             raise ValueError(f"Data must contain the following fields: {required_fields}")
+
+         image_name = data['image_name']
+         if image_feature is None:
+             self.init_clip_extractor()
+             data['clip_feature'] = self.clip_extractor.extract_image_from_file(image_name)
+         else:
+             data['clip_feature'] = image_feature
+
+         if self.datas is None:
+             self.datas = pd.DataFrame([data])
+         else:
+             self.datas = pd.concat([self.datas, pd.DataFrame([data])], ignore_index=True)
+         if if_save:
+             self.save_to_parquet()
+
+     def add_images(self, datas):
+         for data in datas:
+             self.add_image(data, if_save=False)
+         self.save_to_parquet()
+
+ from glob import glob
+
+ def scan_and_update_imagebase(db, target_folder="temp_images"):
+     # Collect all .jpg files under target_folder
+     image_files = glob(os.path.join(target_folder, "*.jpg"))
+
+     duplicate_count = 0
+     added_count = 0
+
+     for image_path in image_files:
+         # Use the file name as the keyword
+         keyword = os.path.basename(image_path).rsplit('.', 1)[0]
+         translated_word = keyword  # adjust translated_word as needed
+
+         # Search the database for similar images
+         results = db.search_with_image_name(image_path)
+
+         if results and results[0]['similarity'] > 0.9:
+             print(f"Image '{image_path}' is considered a duplicate.")
+             duplicate_count += 1
+         else:
+             new_image_data = {
+                 'image_name': image_path,
+                 'keyword': keyword,
+                 'translated_word': translated_word
+             }
+             db.add_image(new_image_data)
+             print(f"Image '{image_path}' added to the database.")
+             added_count += 1
+
+     print(f"Total duplicate images found: {duplicate_count}")
+     print(f"Total new images added to the database: {added_count}")
+
+ if __name__ == '__main__':
+     img_db = Imagebase()
+
+     # Target folder
+     target_folder = "temp_images"
+
+     # Scan and update the database
+     scan_and_update_imagebase(img_db, target_folder)
+
+     # Usage example
+     # img_db = Imagebase()
+
+     # new_image_data = {
+     #     'image_name': "datas/老虎.jpg",
+     #     'keyword': 'tiger',
+     #     'translated_word': '老虎'
+     # }
+
+     # img_db.add_image(new_image_data)
+
+     # image_path = "datas/老虎.jpg"
+     # results = img_db.search_with_image_name(image_path)
+     # for result in results[:3]:
+     #     print(result)
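
scan_and_update_imagebase treats similarity > 0.9 as a duplicate, the same cutoff as GameMaster.minimal_image_threshold. A quick way to eyeball borderline matches — the image path is a placeholder:

img_db = Imagebase()
for r in img_db.search_with_image_name("temp_images/example.jpg")[:3]:  # hypothetical file
    print(round(r["similarity"], 3), r["keyword"], r["image_name"])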
src/ZhipuClient.py ADDED
@@ -0,0 +1,35 @@
+ from zhipuai import ZhipuAI
+ import os
+
+ class ZhipuClient:
+     def __init__(self, api_key_file_path=None):
+         if api_key_file_path is None:
+             cands = ['./datas/zhipu_key.txt', '../datas/zhipu_key.txt']
+             flag = False
+             for cand in cands:
+                 if os.path.exists(cand):
+                     api_key_file_path = cand
+                     flag = True
+                     break
+             if not flag:
+                 raise ValueError("No valid api key file found.")
+
+         self.api_key = self._load_access_token(api_key_file_path)
+         self.client = ZhipuAI(api_key=self.api_key)
+
+     def _load_access_token(self, file_path):
+         with open(file_path, 'r') as file:
+             return file.read().strip()
+
+     def prompt2response(self, prompt):
+         response = self.client.chat.completions.create(
+             model="glm-4",  # name of the model to call
+             messages=[
+                 {"role": "user", "content": prompt}
+             ],
+         )
+         return response.choices[0].message.content
+
+ # Usage:
+ # zhipu_client = ZhipuClient('../datas/zhipu_key.txt')
+ # response = zhipu_client.prompt2response('Your prompt here')
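
A minimal usage sketch — assumes datas/zhipu_key.txt exists; the prompt is a placeholder:

client = ZhipuClient()
reply = client.prompt2response("用一句话介绍修仙世界")  # hypothetical prompt
print(reply)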
src/__pycache__/Captioner.cpython-310.pyc ADDED
Binary file (4.22 kB)
src/__pycache__/Database.cpython-310.pyc ADDED
Binary file (7 kB)
src/__pycache__/GameMaster.cpython-310.pyc ADDED
Binary file (4.83 kB)
src/__pycache__/ImageBase.cpython-310.pyc ADDED
Binary file (4.53 kB)
src/__pycache__/ZhipuClient.cpython-310.pyc ADDED
Binary file (1.34 kB)
src/__pycache__/generate_cultivation.cpython-310.pyc ADDED
Binary file (5 kB)
src/__pycache__/get_major_object.cpython-310.pyc ADDED
Binary file (4.22 kB)
src/__pycache__/text_embedding.cpython-310.pyc ADDED
Binary file (7.49 kB)
src/generate_cultivation.py ADDED
@@ -0,0 +1,183 @@
+ import json
+
+ def data2reference(top_k_items, output_n=3):
+     outputted_items = set()
+
+     output_str = "#Reference:\n"
+
+     for item in top_k_items:
+         item_in_life = item["keyword"]
+         if item_in_life in outputted_items:
+             continue
+         name_in_cultivation = item["name_in_cultivation"]
+         description_in_cultivation = item["description_in_cultivation"]
+         # output_str += f"name_in_life: {item_in_life}\n"
+         # output_str += f"name_in_cultivation: {name_in_cultivation}\n"
+         # output_str += f"description_in_cultivation: {description_in_cultivation}\n\n"
+         # Output each reference entry in JSON format instead
+         output_data = {
+             "name_in_life": item_in_life,
+             "name_in_cultivation": name_in_cultivation,
+             "description_in_cultivation": description_in_cultivation
+         }
+         output_str += json.dumps(output_data, ensure_ascii=False) + "\n\n"
+
+         outputted_items.add(item_in_life)
+         if len(outputted_items) >= output_n:
+             break
+     return output_str.strip()
+
+
+ def data2prompt(query_item, top_k_items):
+
+     reference_prompt = data2reference(top_k_items, 3)
+
+     task_prompt1 = "\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n"
+
+     input_prompt = "# Input:\n"
+     if "keyword" in query_item:
+         input_prompt += f"input_name:{query_item['keyword']}\n"
+     if "description" in query_item:
+         input_prompt += f"description_in_life:{query_item['description']}\n"
+     else:
+         # Directly dump query_item
+         input_prompt += json.dumps(query_item, ensure_ascii=False) + "\n"
+
+     CoT_prompt = \
+ """Let's think it step by step,以json形式输出逐个字段。包含以下字段
+ - name_in_life: 进一步明确要生成描述的物品名称
+ - name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
+ - description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
+ - echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
+ - critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
+ - echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
+ - analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
+ - echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
+ - candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
+ - analysis_candidates: 分析各个candidates有什么优点
+ - echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
+ - final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
+ - echo_5: "我将分析根据final_description,是否简易将物品名称替换为新的名词"
+ - name_fit_analysis: 分析item_name是否还匹配final_description的描述,是否需要给input_name起一个更响亮的名字
+ - new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
+ """
+
+     return reference_prompt + task_prompt1 + input_prompt + CoT_prompt
+
+ try:
+     from src.ZhipuClient import ZhipuClient
+ except ImportError:
+     from ZhipuClient import ZhipuClient
+
+ zhipu_client = None
+
+
+ def markdown_to_json(markdown_str):
+     # Strip Markdown markers such as code-block fences, if present
+     if markdown_str.startswith("```json"):
+         markdown_str = markdown_str[7:-3].strip()
+     elif markdown_str.startswith("```"):
+         markdown_str = markdown_str[3:-3].strip()
+
+     # Parse the string into a JSON dict
+     json_dict = json.loads(markdown_str)
+
+     return json_dict
+
+ import re
+
+ def forced_extract(input_str, keywords):
+     result = {key: "" for key in keywords}
+
+     for key in keywords:
+         # Use a regular expression to find the key-value pair
+         pattern = rf'"{key}":\s*"(.*?)"'  # raw string: \s is a regex escape
+         match = re.search(pattern, input_str)
+         if match:
+             result[key] = match.group(1)
+
+     return result
+
+ def generate_cultivation_with_rag(query_item, search_result):
+     global zhipu_client
+     if zhipu_client is None:
+         zhipu_client = ZhipuClient()
+     prompt = data2prompt(query_item, search_result)
+     response = zhipu_client.prompt2response(prompt)
+
+     try:
+         json_response = markdown_to_json(response)
+     except Exception:
+         keyword_list = ["name_in_life", "name_in_cultivation_1", "description_in_cultivation_1", "final_enhanced_description", "new_name"]
+         json_response = forced_extract(response, keyword_list)
+
+     if "new_name" not in json_response or json_response["new_name"] == "":
+         if "name_in_cultivation_1" in json_response:
+             json_response["new_name"] = json_response["name_in_cultivation_1"]
+         else:
+             json_response["new_name"] = ""
+
+     if "final_enhanced_description" not in json_response or json_response["final_enhanced_description"] == "":
+         if "description_in_cultivation_1" in json_response:
+             json_response["final_enhanced_description"] = json_response["description_in_cultivation_1"]
+         else:
+             json_response["final_enhanced_description"] = json_response["new_name"]
+
+     return json_response
+
+ if __name__ == '__main__':
+     try:
+         from src.Database import Database
+     except ImportError:
+         from Database import Database
+
+     db = Database()
+
+     try:
+         from src.Captioner import Captioner
+     except ImportError:
+         from Captioner import Captioner
+
+     import os
+     os.environ['HTTP_PROXY'] = 'http://localhost:8234'
+     os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
+
+     captioner = Captioner()
+
+     test_image = "temp_images/3or47vg0.jpg"
+     caption_response = captioner.caption(test_image)
+
+     # print(caption_response)
+
+     search_result = db.search_with_image_name(test_image)
+
+     # print(search_result[0].keys())
+     # reference_str = data2reference(search_result, output_n=3)
+     # print(reference_str)
+
+     seen = set()
+     keywords = [res['translated_word'] for res in search_result if not (res['translated_word'] in seen or seen.add(res['translated_word']))]
+     # print(keywords)
+
+     # prompt = data2prompt(caption_response, keywords)
+     # print(prompt)
+     from get_major_object import get_major_object, verify_keyword_in_base
+
+     json_response = get_major_object(caption_response, keywords)
+
+     print(json_response)
+
+     print()
+
+     in_base_data, alt_data = verify_keyword_in_base(json_response, db)
+
+     if alt_data is not None:
+         result = generate_cultivation_with_rag(alt_data, search_result)
+         print(result)
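
forced_extract is the fallback for replies that are not valid JSON; it regex-scans the raw text for quoted key-value pairs and leaves missing keys empty. A minimal sketch of that behavior on a hypothetical truncated reply:

from generate_cultivation import forced_extract

raw = 'reply truncated... "new_name": "九霄鸣音台", "final_enhanced_description": "悬于云海之上"'
fields = forced_extract(raw, ["new_name", "final_enhanced_description", "name_in_life"])
print(fields["new_name"])      # 九霄鸣音台
print(fields["name_in_life"])  # empty string: the key is absent from the reply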
src/get_comments_from_level.py ADDED
@@ -0,0 +1,29 @@
+ def get_comments_from_level(inbase_similarity_level, inlibrary_similarity_level):
+     name = "{name}"  # Placeholder for the item name
+
+     # Define comments based on the similarity levels
+     comments_mapping = {
+         (0, 0): f"道友,在下才疏学浅不太认识这个东西,我感觉这个东西有点像古籍上所说的{name},还要找经验更丰富的长老来确定下",
+         (0, 1): f"这个东西似乎有些眼熟,但天机阁尚未收录,古籍上的描述也有些模糊,可能需要长老们进一步鉴定,我猜是{name}",
+         (0, 2): f"这东西颇为罕见,天机阁未曾有过记录,但古籍中的描述与{name}颇为相似,我将呈给长老们以作鉴定",
+         (0, 3): f"想必这一定是{name}吧,虽然天机阁还没有收录过这个东西,倒是和修仙古籍上的记载非常相像。我赶紧拿给长老再鉴定下",
+         (1, 0): f"天机阁的记录中似乎没有这个东西,但我依稀记得古籍中提到过{name},还需长老进一步确认",
+         (1, 1): f"这个物品有些特别,天机阁的记录不多,古籍中的描述也只是一笔带过,可能是{name},还需长老鉴定",
+         (1, 2): f"此物颇为罕见,天机阁记录较少,但古籍中的描述与{name}有一定相似之处,长老们或能给出答案",
+         (1, 3): f"虽然古籍中对{name}的描述详细,但天机阁中却鲜有记录,或许这是一件稀世之宝",
+         (2, 0): f"天机阁中对此物知之甚少,但古籍中曾提到{name},这件物品或许不简单,需长老们鉴定",
+         (2, 1): f"天机阁中对此物的记录不多,古籍中对{name}的描述也有限,但似乎是一件非凡之物",
+         (2, 2): f"这件物品在古籍中有所记载,天机阁也有少量收录,看来是{name}无疑,但还需长老确认",
+         (2, 3): f"虽然在古籍中有记载,天机阁过往有一点点收录,但也算稀世珍宝,{name}确实非凡",
+         (3, 0): f"天机阁中没有记录,但古籍中对{name}的描述颇为详细,这件物品可能是个谜",
+         (3, 1): f"天机阁中记录较少,但古籍中对{name}的描述详尽,这件物品或许有着不同寻常的来历",
+         (3, 2): f"古籍中记载{name}颇多,天机阁中也有所收录,看来这东西并不罕见",
+         (3, 3): f"{name}这种东西很常见啊,天机阁的库房里面都有不少呢"
+     }
+
+     # Return the appropriate comment based on the similarity levels
+     return comments_mapping.get((inbase_similarity_level, inlibrary_similarity_level), "道友,我会给出初步的鉴定") + "。"
+
+ # Example usage:
+ # comments = get_comments_from_level(2, 3)
+ # print(comments)
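
The returned string keeps the literal {name} placeholder (name = "{name}" above), so the caller substitutes the identified item name afterwards, e.g. with str.format:

comment = get_comments_from_level(2, 3)
print(comment.format(name="灵虎犬神"))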
src/get_major_object.py ADDED
@@ -0,0 +1,195 @@
+ def data2prompt(caption_response, ref_words):
+
+     ref_word_str = ",".join(ref_words[:5])
+
+     task_prompt = "Based on the following Caption Response, you will output a description of the Major Object's name."
+
+     input_str = "# Caption Response:\n" + caption_response + "\n"
+
+     CoT_prompt = \
+ f"""
+ Let's think it step by step. Output each field in JSON format. Include the following fields:
+ - major_object: From the caption response, identify the major_object. If not present, extract it again from the detailed_description or caption_response.
+ - better_major_object: Reread the description in the caption response to see if there's a more suitable word for the major object. If not, still output major_object.
+ - echo_1: "I will generate a simple description in about 200 words in English for the better_major_object, introducing what the input object is."
+ - description: Generate a WIKI description for the better_major_object (explain what the better_major_object is).
+ - major_object_chinese: Translate the better_major_object into Chinese.
+ - echo_2: "I will check whether there is a synonym of the major_object_chinese in '{ref_word_str}'."
+ - synonym: If present, output the synonym directly; otherwise, output "NOT_INCLUDED".
+ - recheck: Based on the content of the Caption Response, determine whether the synonym is accurate. If accurate, output "ACCURATE"; otherwise, output "NOT_ACCURATE".
+ """
+     return task_prompt + input_str + CoT_prompt
+
+
+ # Earlier Chinese version of data2prompt, kept for reference:
+ # def data2prompt(caption_response , ref_words ):
+
+ #     ref_word_str = ",".join(ref_words[:5])
+
+ #     ref_str = "# Reference Word:\n"+ref_word_str+"\n\n"
+
+ #     task_prompt = "你将根据下面的Caption Response,输出Major Object的名称描述"
+
+ #     input_str = "# Caption Response:\n"+caption_response+"\n"
+
+ #     CoT_prompt = \
+ #     """
+ #     Let's think it step by step,以json形式输出逐个字段。包含以下字段
+ #     - major_object: 从caption response中,确认major_object,如果没有,则从detailed_description或者caption_response中重新抽取
+ #     - better_major_object: 重新阅读caption response中的描述,看看是否有更合适的major object的词语,如果没有则仍然输出major_object
+ #     - echo_1: "I will generate a simple description in about 200 words in English for the input word, introducing what the input object is"
+ #     - description: generate the description for the input object
+ #     - major_object_chinese: 将major_object翻译为中文
+ #     - echo_2: "我将判断reference word中,是否存在major_object的同义词"
+ #     - 同义词: 如果存在,则直接输出同义词,否则输出"NOT_INCLUDED"
+ #     - recheck: 结合Caption Response的内容,判断同义词是否准确,如果准确,则输出"ACCURATE",否则输出"NOT_ACCURATE"
+ #     """
+ #     return ref_str+task_prompt+input_str+CoT_prompt
+
+ try:
+     from src.ZhipuClient import ZhipuClient
+ except ImportError:
+     from ZhipuClient import ZhipuClient
+
+ zhipu_client = None
+
+ import json
+
+ def markdown_to_json(markdown_str):
+     # Strip Markdown markers such as code-block fences, if present
+     if markdown_str.startswith("```json"):
+         markdown_str = markdown_str[7:-3].strip()
+     elif markdown_str.startswith("```"):
+         markdown_str = markdown_str[3:-3].strip()
+
+     # Parse the string into a JSON dict
+     json_dict = json.loads(markdown_str)
+
+     return json_dict
+
+ import re
+
+ def forced_extract(input_str, keywords):
+     result = {key: "" for key in keywords}
+
+     for key in keywords:
+         # Use a regular expression to find the key-value pair
+         pattern = rf'"{key}":\s*"(.*?)"'
+         match = re.search(pattern, input_str)
+         if match:
+             result[key] = match.group(1)
+
+     return result
+
+ def get_major_object(caption_response, ref_words):
+     global zhipu_client
+     if zhipu_client is None:
+         zhipu_client = ZhipuClient()
+     prompt = data2prompt(caption_response, ref_words)
+     response = zhipu_client.prompt2response(prompt)
+
+     try:
+         json_response = markdown_to_json(response)
+     except Exception:
+         keyword_list = ["major_object", "better_major_object", "description", "major_object_chinese", "synonym", "recheck"]
+         json_response = forced_extract(response, keyword_list)
+
+     return json_response
+
+ def verify_keyword_in_base(json_response, database):
+
+     keyword2verify = []
+     if "better_major_object" in json_response:
+         keyword2verify.append(json_response["better_major_object"].lower())
+
+     if "major_object" in json_response:
+         keyword2verify.append(json_response["major_object"].lower())
+
+     if "recheck" in json_response and json_response["recheck"] == "ACCURATE":
+         if "synonym" in json_response and json_response["synonym"] != "NOT_INCLUDED":
+             keyword2verify.append(json_response["synonym"].lower())
+
+     ans = None
+
+     for keyword in keyword2verify:
+         res = database.search_by_en_keyword(keyword)
+         if res is None:
+             continue
+         ans = res
+         return ans, None
+
+     if len(keyword2verify) == 0:
+         return None, None
+
+     # Here we need a new data record: keyword is the Chinese name,
+     # translated_word is the English name, description is the English description
+     description = keyword2verify[0]
+     if "description" in json_response:
+         description = json_response["description"]
+
+     translated_word = keyword2verify[0]
+
+     keyword = translated_word
+     if "major_object_chinese" in json_response:
+         keyword = json_response["major_object_chinese"]
+
+     data = {
+         "keyword": keyword,
+         "translated_word": translated_word,
+         "description": description
+     }
+
+     return None, data
+
+
+ if __name__ == '__main__':
+     try:
+         from src.Database import Database
+     except ImportError:
+         from Database import Database
+
+     db = Database()
+
+     try:
+         from src.Captioner import Captioner
+     except ImportError:
+         from Captioner import Captioner
+
+     import os
+     os.environ['HTTP_PROXY'] = 'http://localhost:8234'
+     os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
+
+     captioner = Captioner()
+
+     test_image = "temp_images/3or47vg0.jpg"
+     caption_response = captioner.caption(test_image)
+
+     # print(caption_response)
+
+     search_result = db.search_with_image_name(test_image)
+
+     seen = set()
+     keywords = [res['translated_word'] for res in search_result if not (res['translated_word'] in seen or seen.add(res['translated_word']))]
+     # print(keywords)
+
+     # prompt = data2prompt(caption_response, keywords)
+     # print(prompt)
+
+     json_response = get_major_object(caption_response, keywords)
+
+     print(json_response)
+
+     print()
+
+     in_base_data, alt_data = verify_keyword_in_base(json_response, db)
+
+     if in_base_data is not None:
+         print(in_base_data)
+
+     if alt_data is not None:
+         print(alt_data)
+
+ # Example result for a photo of the Eiffel Tower:
+ # {'keyword': '埃菲尔铁塔', 'translated_word': 'eiffel tower', 'description': "The Eiffel Tower is an iconic symbol of Paris and one of the most recognizable structures in the world. Designed and constructed by the engineer Gustave Eiffel and his company for the 1889 Exposition Universelle (World's Fair) to celebrate the 100th anniversary of the French Revolution, the tower was initially criticized by some of France's leading artists and intellectuals. However, it quickly became a beloved landmark and a symbol of French pride. Standing 324 meters tall, the tower is made of wrought iron and consists of thousands of metal parts, including over 18,000 individual iron rivets. It is renowned for its architectural and engineering design, and it is visited by millions of people each year, making it one of the most visited paid monuments in the world."}
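
verify_keyword_in_base fills exactly one of its two return slots: (record, None) when a candidate keyword already exists in the text database, or (None, draft) with a {keyword, translated_word, description} draft for a new item. Callers branch on it as GameMaster does; a minimal sketch, with names as in the __main__ blocks above:

in_base, draft = verify_keyword_in_base(json_response, db)
if in_base is not None:
    entry = in_base  # known item: reuse the stored entry
elif draft is not None:
    entry = generate_cultivation_with_rag(draft, search_result)  # new item: generate one
else:
    entry = None  # the model produced no usable keyword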
src/text_embedding.py ADDED
@@ -0,0 +1,289 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ import os
+
+ class TextExtractor:
+     def __init__(self, model_name, proxy=None):
+         """
+         Initialize the TextExtractor with a specified model and optional proxy settings.
+
+         Parameters:
+         - model_name (str): The name of the pre-trained model to load from HuggingFace Hub.
+         - proxy (str, optional): The proxy address to use for HTTP and HTTPS requests.
+         """
+         if proxy is None:
+             proxy = 'http://localhost:8234'
+
+         if proxy:
+             os.environ['HTTP_PROXY'] = proxy
+             os.environ['HTTPS_PROXY'] = proxy
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+             self.model = AutoModel.from_pretrained(model_name)
+         except Exception:
+             print('try switch on local_files_only')
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
+             self.model = AutoModel.from_pretrained(model_name, local_files_only=True)
+
+         self.model.eval()
+
+     def extract(self, sentences):
+         """
+         Extract sentence embeddings for the provided sentences.
+
+         Parameters:
+         - sentences (list of str): A list of sentences to extract embeddings for.
+
+         Returns:
+         - torch.Tensor: The normalized sentence embeddings.
+         """
+         encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+
+         with torch.no_grad():
+             model_output = self.model(**encoded_input)
+             sentence_embeddings = model_output[0][:, 0]
+
+         sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+         return sentence_embeddings
+
+ import pandas as pd
+
+ def get_qas(excel_file=None):
+
+     default_excel_file = 'data/output_fixid.xlsx'
+     if excel_file is None:
+         excel_file = default_excel_file
+
+     # Read the Excel file
+     df = pd.read_excel(excel_file)
+
+     df = df[df["question"].notna()]
+     df = df[df["summary"].notna()]
+
+     datas = []
+
+     # Iterate over each row of the DataFrame
+     for index, row in df.iterrows():
+         id = row['id']
+         question = row['question']
+         short_answer = row['summary']
+         category = row['category']
+
+         texts = [question, short_answer]
+
+         data_value = {
+             "texts": texts,
+         }
+
+         data = {
+             "id": id,
+             "value": data_value
+         }
+
+         datas.append(data)
+
+     return datas
+
+ from tqdm import tqdm
+
+ def extract_embedding(datas, text_extractor):
+     """
+     Extract embeddings for each item in the provided data.
+
+     Parameters:
+     - datas (list of dict): A list of dictionaries containing text data.
+
+     Returns:
+     - list of dict: The input data with added embeddings.
+     """
+     for data in tqdm(datas):
+         texts = data["value"]["texts"]
+         text = "。".join(texts)
+         embeddings = text_extractor.extract(text)
+         embeddings_list = embeddings.tolist()  # Convert tensor to list of lists
+         data["value"]["embedding"] = embeddings_list
+     return datas
+
+ def save_parquet(datas, file_path):
+     """
+     Save the provided data to a Parquet file.
+
+     Parameters:
+     - datas (list of dict): A list of dictionaries containing text data and embeddings.
+     - file_path (str): The path to the output Parquet file.
+     """
+     # Flatten the data for easier conversion to DataFrame
+     flattened_data = []
+     for data in datas:
+         id = data["id"]
+         texts = data["value"]["texts"]
+         text = "。".join(texts)
+         embedding = data["value"]["embedding"]
+         flattened_data.append({
+             "id": id,
+             "text": text,
+             "embedding": embedding
+         })
+
+     # Create DataFrame
+     df = pd.DataFrame(flattened_data)
+
+     # Save DataFrame to Parquet
+     df.to_parquet(file_path, index=False)
+
+ def get_id2embedding(regen=False, parquet_file='datas/qa_with_embedding.parquet'):
+     """
+     Get a dictionary mapping IDs to embeddings. Regenerate embeddings if specified.
+
+     Parameters:
+     - parquet_file (str): The path to the Parquet file.
+     - regen (bool): Whether to regenerate embeddings.
+
+     Returns:
+     - dict: A dictionary mapping IDs to lists of float embeddings.
+     """
+     if regen or not os.path.exists(parquet_file):
+         print("Regenerating embeddings...")
+         # Example usage:
+         model_name = 'BAAI/bge-small-zh-v1.5'
+         text_extractor = TextExtractor(model_name)
+
+         datas = get_qas()
+         print("Extracting embeddings for", len(datas), "data items")
+
+         datas = extract_embedding(datas, text_extractor)
+         save_parquet(datas, parquet_file)
+
+     df = pd.read_parquet(parquet_file)
+
+     id2embedding = {}
+     for index, row in df.iterrows():
+         id = row['id']
+         embedding = row['embedding']
+         id2embedding[id] = embedding[0]
+
+     return id2embedding
+
+ from sklearn.metrics.pairwise import cosine_similarity
+ import heapq
+
+ def __get_id2top30map(id2embedding):
+     """
+     Get a dictionary mapping IDs to their top 30 nearest neighbors based on cosine similarity.
+
+     Parameters:
+     - id2embedding (dict): A dictionary mapping IDs to lists of float embeddings.
+
+     Returns:
+     - dict: A dictionary mapping each ID to a list of the top 30 nearest neighbor IDs.
+     """
+     ids = list(id2embedding.keys())
+     embeddings = torch.tensor([id2embedding[id] for id in ids])
+
+     # Compute the cosine similarity matrix
+     cos_sim_matrix = cosine_similarity(embeddings)
+
+     id2top30map = {}
+     for i, id in enumerate(ids):
+         # Get the similarity scores for the current ID
+         sim_scores = cos_sim_matrix[i]
+
+         # Get the top 30 indices (excluding the current ID itself)
+         top_indices = heapq.nlargest(31, range(len(sim_scores)), key=lambda x: sim_scores[x])
+         top_indices.remove(i)  # Remove the index of the current ID
+
+         # Map the indices back to IDs
+         top_30_ids = [ids[idx] for idx in top_indices[:30]]
+
+         id2top30map[id] = top_30_ids
+
+     return id2top30map
+
+ import pickle
+
+ def get_id2top30map(id2embedding=None):
+     default_save_pkl = "data/id2top30map.pkl"
+     if id2embedding is None:
+         if os.path.exists(default_save_pkl):
+             with open(default_save_pkl, 'rb') as f:
+                 id2top30map = pickle.load(f)
+         else:
+             print("No embedding found, generating new one...")
+             id2embedding = get_id2embedding(regen=False)
+             id2top30map = __get_id2top30map(id2embedding)
+             with open(default_save_pkl, 'wb') as f:
+                 pickle.dump(id2top30map, f)
+     else:
+         id2top30map = __get_id2top30map(id2embedding)
+
+     return id2top30map
+
+
+ if __name__ == '__main__':
+     if False:
+         # Example usage:
+         model_name = 'BAAI/bge-small-zh-v1.5'
+         sentences = ["样例数据-1", "样例数据-2"]
+
+         text_extractor = TextExtractor(model_name)
+         embeddings = text_extractor.extract(sentences)
+         print("Sentence embeddings:", embeddings)
+
+         datas = get_qas()
+
+         print("extract embedding for ", len(datas), " datas")
+
+         datas = extract_embedding(datas, text_extractor)
+
+         default_parquet_save_name = "data/qa_with_embedding.parquet"
+
+         save_parquet(datas, default_parquet_save_name)
+     if True:
+         id2embedding = get_id2embedding(regen=False)
+         print(len(id2embedding[4]))
+         id2top30map = get_id2top30map(None)
+         print("ID to Top 30 Neighbors dictionary:", id2top30map[4])
+
+     if True:
+         # Breadth-first expansion over the top-30 neighbor graph, starting from one QA id
+         start_id = 332
+         visited_ids = [start_id]
+         current_queue = [start_id]
+
+         expend_num = 5
+
+         for iteration in range(10):
+             current_node = current_queue.pop(0)
+             top30 = id2top30map[current_node]
+             current_expend = []
+             for id in top30:
+                 if id not in visited_ids:
+                     visited_ids.append(id)
+                     current_queue.append(id)
+                     current_expend.append(id)
+                 if len(current_expend) >= expend_num:
+                     break
+             display_text = f"{current_node} | ->" + ",".join([str(i) for i in current_expend])
+             print(display_text)
+
+         from get_qa_and_image import get_qa_and_image
+         image_datas = get_qa_and_image()
+
+         id2index = {}
+
+         for i, data in enumerate(image_datas):
+             id2index[data['id']] = i
+
+         indexes = [id2index[i] for i in visited_ids if i in id2index]
+         image_names = [image_datas[index]['value']['image'] for index in indexes]
+
+         target_copy_folder = "data/asso_collection"
+
+         import shutil
+         # Copy each visited image into target_copy_folder
+         for image_name in image_names:
+             shutil.copy(image_name, target_copy_folder)
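
A minimal standalone check of TextExtractor — pass proxy="" to skip the default localhost proxy; the model name matches the one used by Database.init_bge_extractor:

from text_embedding import TextExtractor

extractor = TextExtractor('BAAI/bge-small-zh-v1.5', proxy="")
emb = extractor.extract(["钢琴", "piano"])
print(emb.shape)               # (2, 512) for this model
print(float(emb[0] @ emb[1]))  # rows are L2-normalized, so the dot product is cosine similarity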