# Source: Hugging Face file view of src/ImageBase.py
# hhhwmws — "Update src/ImageBase.py" (commit 438c612, verified)
# Page chrome: raw · history · blame · contribute · delete · "No virus" · 5.46 kB
import pandas as pd
import os
from tqdm import tqdm
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
class Imagebase:
    """A small image database kept in a pandas DataFrame and a parquet file.

    Rows added through ``add_image`` carry ``image_name``, ``keyword``,
    ``translated_word`` and ``clip_feature`` (the feature returned by
    ``CLIPExtractor``). Rows can be searched by cosine similarity of their
    ``clip_feature`` against a query feature.
    """

    def __init__(self, parquet_path=None):
        """Create the database, loading rows from ``parquet_path`` if it exists.

        Args:
            parquet_path: Where the database is stored. Defaults to
                ``'datas/imagebase.parquet'``.
        """
        self.default_parquet_path = 'datas/imagebase.parquet'
        self.parquet_path = parquet_path or self.default_parquet_path
        self.datas = None  # pandas DataFrame, or None while nothing is loaded/added
        if os.path.exists(self.parquet_path):
            self.load_from_parquet(self.parquet_path)
        # The CLIP extractor is only constructed on first use (init_clip_extractor).
        self.clip_extractor = None

    def random_sample(self, num_samples=12):
        """Return up to ``num_samples`` randomly chosen rows as dicts.

        Returns ``[]`` when the database is empty. If fewer than
        ``num_samples`` rows exist, all rows are returned in random order
        instead of raising.
        """
        if self.datas is None:
            return []
        # DataFrame.sample samples without replacement by default and raises
        # ValueError when n exceeds the row count, so clamp n.
        n = max(0, min(num_samples, len(self.datas)))
        return self.datas.sample(n).to_dict(orient='records')

    def load_from_parquet(self, parquet_path):
        """Replace the in-memory rows with the contents of ``parquet_path``."""
        self.datas = pd.read_parquet(parquet_path)

    def save_to_parquet(self, parquet_path=None):
        """Write the in-memory rows to a parquet file.

        Args:
            parquet_path: Destination file. Defaults to the path this
                instance was created with (``self.parquet_path``), so a
                database loaded from a custom location is saved back there.

        Nothing is written while ``self.datas`` is None.
        """
        parquet_path = parquet_path or self.parquet_path
        if self.datas is None:
            return
        directory = os.path.dirname(parquet_path)
        if directory:
            # Make sure the parent directory exists before writing the file.
            os.makedirs(directory, exist_ok=True)
        self.datas.to_parquet(parquet_path)

    def init_clip_extractor(self):
        """Construct the CLIP feature extractor if it has not been created yet."""
        if self.clip_extractor is not None:
            return
        # Try the module as a top-level import first, then as ``src.CLIPExtractor``.
        try:
            from CLIPExtractor import CLIPExtractor
        except ImportError:
            from src.CLIPExtractor import CLIPExtractor
        cache_dir = "models"
        self.clip_extractor = CLIPExtractor(model_name="openai/clip-vit-large-patch14", cache_dir=cache_dir)

    def top_k_search(self, query_feature, top_k=15):
        """Return the ``top_k`` rows most similar to ``query_feature``.

        Args:
            query_feature: A feature vector comparable to the stored
                ``clip_feature`` values; it is reshaped to a single row.
            top_k: Maximum number of results. Values <= 0 yield ``[]``.

        Returns:
            A list of row dicts, most similar first. Each dict has an added
            ``similarity`` key (cosine similarity) and no ``clip_feature``.
            Rows whose ``clip_feature`` is missing are skipped.

        Raises:
            ValueError: if the data has no ``clip_feature`` column.
        """
        if self.datas is None:
            return []
        if 'clip_feature' not in self.datas.columns:
            raise ValueError("clip_feature column not found in the data.")
        # Filter the frame itself (not just the feature column) so that the
        # positions returned by argsort index the same rows the similarities
        # were computed for, even when some features are missing.
        candidates = self.datas[self.datas['clip_feature'].notna()]
        if top_k <= 0 or candidates.empty:
            return []
        query_feature = np.array(query_feature).reshape(1, -1)
        candidate_features = np.stack(candidates['clip_feature'].to_numpy())
        similarities = cosine_similarity(query_feature, candidate_features)[0]
        # Descending by similarity; slicing also handles top_k > row count.
        top_k_indices = np.argsort(similarities)[::-1][:top_k]
        top_k_results = candidates.iloc[top_k_indices].copy()
        top_k_results['similarity'] = similarities[top_k_indices]
        # Callers get the metadata plus the score, not the raw embedding.
        top_k_results = top_k_results.drop(columns=['clip_feature'])
        return top_k_results.to_dict(orient='records')

    def search_with_image_name(self, image_name):
        """Search the database using the image file at path ``image_name``."""
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image_from_file(image_name)
        return self.top_k_search(img_feature)

    def search_with_image(self, image, if_opencv=False):
        """Search the database using an in-memory image.

        ``if_opencv`` is forwarded to ``CLIPExtractor.extract_image``;
        presumably it marks an OpenCV-style image — confirm there.
        """
        self.init_clip_extractor()
        img_feature = self.clip_extractor.extract_image(image, if_opencv=if_opencv)
        return self.top_k_search(img_feature)

    def add_image(self, data, if_save=True, image_feature=None):
        """Add one image record to the database.

        Args:
            data: Mapping with at least ``image_name``, ``keyword`` and
                ``translated_word``. It is copied, never modified.
            if_save: Persist the database to parquet after adding.
            image_feature: Precomputed CLIP feature. When None, the feature
                is extracted from the file at ``data['image_name']``.

        Raises:
            ValueError: if a required field is missing from ``data``.
        """
        required_fields = ['image_name', 'keyword', 'translated_word']
        if not all(field in data for field in required_fields):
            raise ValueError(f"Data must contain the following fields: {required_fields}")
        # Work on a copy so the caller's dict does not gain a 'clip_feature' key.
        record = dict(data)
        if image_feature is None:
            self.init_clip_extractor()
            image_feature = self.clip_extractor.extract_image_from_file(record['image_name'])
        record['clip_feature'] = image_feature
        new_row = pd.DataFrame([record])
        if self.datas is None:
            self.datas = new_row
        else:
            self.datas = pd.concat([self.datas, new_row], ignore_index=True)
        if if_save:
            self.save_to_parquet()

    def add_images(self, datas):
        """Add several records (see ``add_image``) and save once at the end."""
        for data in datas:
            self.add_image(data, if_save=False)
        self.save_to_parquet()
import os
from glob import glob
def scan_and_update_imagebase(db, target_folder="temp_images", threshold=0.9):
    """Add every ``*.jpg`` in ``target_folder`` to ``db`` unless it is a near-duplicate.

    An image counts as a duplicate when its best cosine-similarity match in
    the database is above ``threshold``. The file name without extension is
    used as both ``keyword`` and ``translated_word``.

    Args:
        db: An ``Imagebase`` instance (or an object with the same API).
        target_folder: Directory to scan (non-recursive).
        threshold: Similarity above which an image is treated as a duplicate.

    Returns:
        ``(duplicate_count, added_count)``.
    """
    # glob returns files in arbitrary order; sort so runs are reproducible.
    image_files = sorted(glob(os.path.join(target_folder, "*.jpg")))
    duplicate_count = 0
    added_count = 0
    if image_files:
        db.init_clip_extractor()
    try:
        for image_path in image_files:
            # Use the file name (without extension) as the keyword.
            keyword = os.path.splitext(os.path.basename(image_path))[0]
            translated_word = keyword  # adjust if a real translation is available
            # Extract the feature once and reuse it for both the duplicate
            # check and the insert, instead of running CLIP twice per image.
            feature = db.clip_extractor.extract_image_from_file(image_path)
            # Only the best match matters for the duplicate test.
            results = db.top_k_search(feature, top_k=1)
            if results and results[0]['similarity'] > threshold:
                print(f"Image '{image_path}' is considered a duplicate.")
                duplicate_count += 1
            else:
                new_image_data = {
                    'image_name': image_path,
                    'keyword': keyword,
                    'translated_word': translated_word,
                }
                db.add_image(new_image_data, if_save=False, image_feature=feature)
                print(f"Image '{image_path}' added to the database.")
                added_count += 1
    finally:
        # Save once instead of rewriting the whole parquet file per image;
        # running this in `finally` keeps already-added images if a later
        # image raises.
        if added_count:
            db.save_to_parquet()
    print(f"Total duplicate images found: {duplicate_count}")
    print(f"Total new images added to the database: {added_count}")
    return duplicate_count, added_count
if __name__ == '__main__':
    # Folder whose images should be registered in the database.
    folder_to_scan = "temp_images"
    database = Imagebase()
    # Scan the folder and add any images that are not already present.
    scan_and_update_imagebase(database, folder_to_scan)

    # Usage example: add one image manually, then search with it.
    # img_db = Imagebase()
    # new_image_data = {
    #     'image_name': "datas/老虎.jpg",
    #     'keyword': 'tiger',
    #     'translated_word': '老虎'
    # }
    # img_db.add_image(new_image_data)
    # image_path = "datas/老虎.jpg"
    # results = img_db.search_with_image_name(image_path)
    # for result in results[:3]:
    #     print(result)