File size: 5,461 Bytes
0319a9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438c612
0319a9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import pandas as pd
import os
from tqdm import tqdm
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class Imagebase:
    """A small image database backed by a parquet file.

    Each row describes one image (``image_name``, ``keyword``,
    ``translated_word``) together with its ``clip_feature`` embedding, which
    is used for cosine-similarity search.
    """

    def __init__(self, parquet_path=None):
        """Create the database, loading ``parquet_path`` if that file exists.

        Args:
            parquet_path: File backing this database. Defaults to
                ``datas/imagebase.parquet``.
        """
        self.default_parquet_path = 'datas/imagebase.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path
        # DataFrame of image records, or None until something is loaded/added.
        self.datas = None

        if os.path.exists(self.parquet_path):
            self.load_from_parquet(self.parquet_path)

        # Created lazily by init_clip_extractor() on first use.
        self.clip_extractor = None

    def random_sample(self, num_samples=12):
        """Return up to ``num_samples`` random records as a list of dicts.

        Returns an empty list when no data is loaded. If fewer than
        ``num_samples`` records exist, all of them are returned (in random
        order) instead of raising.
        """
        if self.datas is None:
            return []
        # DataFrame.sample without replacement raises ValueError when asked
        # for more rows than exist, so clamp to the available row count.
        count = min(num_samples, len(self.datas))
        return self.datas.sample(count).to_dict(orient='records')

    def load_from_parquet(self, parquet_path):
        """Replace the in-memory records with the contents of ``parquet_path``."""
        self.datas = pd.read_parquet(parquet_path)

    def save_to_parquet(self, parquet_path=None):
        """Write the current records to ``parquet_path``.

        Defaults to the path this instance was created with, so a database
        opened from a custom file is saved back to that same file. Does
        nothing when there is no data.
        """
        parquet_path = parquet_path or self.parquet_path
        if self.datas is None:
            return
        directory = os.path.dirname(parquet_path)
        if directory:
            # Make sure the parent directory exists before writing.
            os.makedirs(directory, exist_ok=True)
        self.datas.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Create the CLIP feature extractor on first use (no-op afterwards)."""
        if self.clip_extractor is not None:
            return
        try:
            from CLIPExtractor import CLIPExtractor
        except ImportError:
            # Fall back to the ``src.`` package path when the module is not
            # importable directly.
            from src.CLIPExtractor import CLIPExtractor

        cache_dir = "models"
        self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)

    def top_k_search(self, query_feature, top_k=15):
        """Return the ``top_k`` records most similar to ``query_feature``.

        Similarity is cosine similarity against each row's ``clip_feature``.
        Rows whose feature is missing are skipped. Results are ordered from
        most to least similar, carry an added ``similarity`` key, and omit
        ``clip_feature``.

        Raises:
            ValueError: If the data has no ``clip_feature`` column.
        """
        if self.datas is None:
            return []
        if 'clip_feature' not in self.datas.columns:
            raise ValueError("clip_feature column not found in the data.")

        # Filter the frame itself (not just the feature column) so that the
        # positional indices from argsort line up with the rows they describe.
        candidates = self.datas[self.datas['clip_feature'].notna()]
        # top_k <= 0 must be handled explicitly: the slice [-0:] would select
        # every row. np.stack also fails on an empty sequence.
        if top_k <= 0 or candidates.empty:
            return []

        query_feature = np.array(query_feature).reshape(1, -1)
        candidate_features = np.stack(candidates['clip_feature'].values)

        similarities = cosine_similarity(query_feature, candidate_features)[0]

        # argsort is ascending: take the last top_k and reverse them.
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]

        top_k_results = candidates.iloc[top_k_indices].copy()
        top_k_results['similarity'] = similarities[top_k_indices]
        return top_k_results.drop(columns=['clip_feature']).to_dict(orient='records')

    def search_with_image_name(self, image_name):
        """Extract a CLIP feature from the file ``image_name`` and search with it."""
        self.init_clip_extractor()

        img_feature = self.clip_extractor.extract_image_from_file(image_name)

        return self.top_k_search(img_feature)

    def search_with_image(self, image, if_opencv=False):
        """Extract a CLIP feature from an in-memory ``image`` and search with it.

        ``if_opencv`` is forwarded to the extractor unchanged.
        """
        self.init_clip_extractor()

        img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)

        return self.top_k_search(img_feature)

    def add_image(self, data, if_save=True, image_feature=None):
        """Append one image record to the database.

        Args:
            data: Mapping with at least ``image_name``, ``keyword`` and
                ``translated_word``. It is copied, not modified.
            if_save: Persist the database to parquet after adding.
            image_feature: Precomputed CLIP feature for the image. When None,
                it is extracted from the file ``data['image_name']``.

        Raises:
            ValueError: If a required field is missing.
        """
        required_fields = ['image_name', 'keyword', 'translated_word']
        if not all(field in data for field in required_fields):
            raise ValueError(f"Data must contain the following fields: {required_fields}")

        # Work on a copy so that the caller's dict does not gain a 'clip_feature' key.
        record = dict(data)
        if image_feature is None:
            self.init_clip_extractor()
            record['clip_feature'] = self.clip_extractor.extract_image_from_file(record['image_name'])
        else:
            record['clip_feature'] = image_feature

        new_row = pd.DataFrame([record])
        if self.datas is None:
            self.datas = new_row
        else:
            self.datas = pd.concat([self.datas, new_row], ignore_index=True)
        if if_save:
            self.save_to_parquet()

    def add_images(self, datas):
        """Add several records (see ``add_image``) and save once at the end."""
        for data in datas:
            self.add_image(data, if_save=False)
        self.save_to_parquet()

import os
from glob import glob

def scan_and_update_imagebase(db, target_folder="temp_images", threshold=0.9):
    """Add every ``*.jpg`` in ``target_folder`` to ``db`` unless it is a near-duplicate.

    The file name (without extension) is used as both ``keyword`` and
    ``translated_word``. An image counts as a duplicate when its best cosine
    similarity against the images already in ``db`` exceeds ``threshold``.
    Images added earlier in the same scan are included in that comparison.

    Args:
        db: An ``Imagebase``-like object.
        target_folder: Directory to scan (non-recursive).
        threshold: Similarity above which an image is treated as a duplicate.

    Returns:
        ``(added_count, duplicate_count)``.
    """
    # Collect all .jpg files in target_folder. Sorting makes runs reproducible
    # regardless of filesystem listing order.
    image_files = sorted(glob(os.path.join(target_folder, "*.jpg")))

    duplicate_count = 0
    added_count = 0

    try:
        for image_path in image_files:
            # Use the file name (without extension) as the keyword.
            keyword = os.path.splitext(os.path.basename(image_path))[0]
            translated_word = keyword  # adjust translated_word as needed

            # Extract the CLIP feature once and reuse it for both the
            # duplicate search and the insert.
            db.init_clip_extractor()
            feature = db.clip_extractor.extract_image_from_file(image_path)
            results = db.top_k_search(feature, top_k=1)

            if results and results[0]['similarity'] > threshold:
                print(f"Image '{image_path}' is considered a duplicate.")
                duplicate_count += 1
                continue

            new_image_data = {
                'image_name': image_path,
                'keyword': keyword,
                'translated_word': translated_word
            }
            db.add_image(new_image_data, if_save=False, image_feature=feature)
            print(f"Image '{image_path}' added to the database.")
            added_count += 1
    finally:
        # Write the parquet file once instead of after every image. The
        # finally clause keeps already-added images if a later one fails.
        if added_count:
            db.save_to_parquet()

    print(f"Total duplicate images found: {duplicate_count}")
    print(f"Total new images added to the database: {added_count}")
    return added_count, duplicate_count

if __name__ == '__main__':
    database = Imagebase()

    # Folder whose .jpg files should be merged into the database.
    folder = "temp_images"

    # Scan the folder and add every image that is not a near-duplicate.
    scan_and_update_imagebase(database, target_folder=folder)

    # Manual usage example:
    #   db = Imagebase()
    #   db.add_image({
    #       'image_name': "datas/老虎.jpg",
    #       'keyword': 'tiger',
    #       'translated_word': '老虎',
    #   })
    #   for result in db.search_with_image_name("datas/老虎.jpg")[:3]:
    #       print(result)