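# Provenance (copied from the hosting page header):
# uploaded by hhhwmws in commit 0319a9a ("Upload 19 files", verified);
# original file size 7.2 kB; the host's scan reported "No virus".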
import os
from glob import glob

try:
    # Package-style imports (presumably when run from the project root).
    from src.Database import Database
    from src.Captioner import Captioner
    from src.ImageBase import Imagebase
    from src.get_major_object import get_major_object, verify_keyword_in_base
    from src.generate_cultivation import generate_cultivation_with_rag
except ImportError:
    # Fallback to flat imports (presumably when run from inside src/).
    from Database import Database
    from Captioner import Captioner
    from ImageBase import Imagebase
    from get_major_object import get_major_object, verify_keyword_in_base
    from generate_cultivation import generate_cultivation_with_rag

class GameMaster:
    """Coordinate image lookup and cultivation-entry generation.

    Holds a text database (``Database``), whose CLIP extractor is reused for
    image features, an image database (``Imagebase``) and an image
    ``Captioner``.
    """

    def __init__(self, minimal_image_threshold=0.9):
        """Create the databases and the captioner.

        Args:
            minimal_image_threshold: default similarity that the best
                image-database hit must exceed in ``search_with_path``.
        """
        self.textdb = self.init_textdb()
        # Reuse the CLIP extractor initialised on the text database
        # (see init_textdb) for image feature extraction.
        self.clip_extractor = self.textdb.clip_extractor
        self.imgdb = self.init_imgdb()
        self.captioner = Captioner()
        self.minimal_image_threshold = minimal_image_threshold

    def init_textdb(self):
        """Return a new ``Database`` with its BGE and CLIP extractors initialised."""
        text_db = Database()
        text_db.init_bge_extractor()
        text_db.init_clip_extractor()
        return text_db

    def init_imgdb(self):
        """Return a new ``Imagebase`` instance."""
        img_db = Imagebase()
        return img_db

    def random_image_text_data(self, n=12, blank_image_path="datas/blank_item.jpg"):
        """Sample ``n`` images and pair each with its cultivation description.

        Args:
            n: number of images to request from ``imgdb.random_sample``.
            blank_image_path: substituted for any sampled image path that
                does not exist on disk.

        Returns:
            A list of ``(image_path, description)`` tuples. A list, unlike a
            ``zip`` iterator, can be iterated more than once and supports
            ``len``. See ``_cultivation_description`` for the description
            format.
        """
        pairs = []
        for img_data in self.imgdb.random_sample(n):
            image_name = img_data['image_name']
            if not os.path.exists(image_name):
                image_name = blank_image_path
            description = self._cultivation_description(img_data['translated_word'])
            pairs.append((image_name, description))
        return pairs

    def _cultivation_description(self, keyword):
        """Return the display text for an English keyword from the text database.

        Returns ``"<name_in_cultivation>--<description_in_cultivation>"``
        when both fields exist. Returns just the description when only it
        exists, and ``""`` when there is no record or no description.
        """
        result = self.textdb.search_by_en_keyword(keyword)
        if not result or "description_in_cultivation" not in result:
            return ""
        description = result['description_in_cultivation']
        if "name_in_cultivation" in result:
            description = result['name_in_cultivation'] + "--" + description
        return description

    def search_with_path(self, image_path, threshold=None):
        """Look up an image file by CLIP similarity (the cheap path).

        Performs one feature extraction and two top-k database searches.
        There is no captioning or generation.

        Args:
            image_path: path of the image file to query.
            threshold: similarity the best image hit must strictly exceed.
                Defaults to ``self.minimal_image_threshold``.

        Returns:
            ``(search_result, backup_results, image_feature)``:

            - ``search_result`` is a copy of the matching text-database
              record with an added ``'similarity'`` key. It is ``None`` when
              no image hit passes the threshold or the hit's keyword has no
              record with ``name_in_cultivation``.
            - ``backup_results`` are the top-5 text-database hits on
              ``'clip_feature'``, always computed.
            - ``image_feature`` is the extracted CLIP feature.
        """
        if threshold is None:
            threshold = self.minimal_image_threshold
        image_feature = self.clip_extractor.extract_image_from_file(image_path)
        image_search_result = self.imgdb.top_k_search(image_feature, top_k=1)

        search_result = None
        if image_search_result and image_search_result[0]['similarity'] > threshold:
            best_hit = image_search_result[0]
            result = self.textdb.search_by_en_keyword(best_hit['translated_word'])
            if result and "name_in_cultivation" in result:
                # Annotate a copy: writing 'similarity' into `result` itself
                # would mutate the record the text database handed back.
                search_result = dict(result)
                search_result['similarity'] = best_hit['similarity']
            else:
                print("Warning! Unfound keyword: ", best_hit['translated_word'])

        # Computed unconditionally; callers can pass these to
        # generate_cultivation_data as its text_search_result.
        backup_results = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)
        return search_result, backup_results, image_feature

    def generate_cultivation_data(self, image_path, image_feature, text_search_result):
        """Caption an image and resolve or generate its cultivation entry.

        This is the expensive path (captioning plus generation). Steps:

        1. Caption ``image_path``.
        2. If ``text_search_result`` is None, fetch the top-5 text-database
           hits for ``image_feature``.
        3. Call ``get_major_object`` with the caption and the de-duplicated
           ``translated_word`` values of those hits.
        4. If ``verify_keyword_in_base`` returns an existing entry, register
           the image under it and return that entry. Otherwise, if it
           returned alternative data, generate a new entry with
           ``generate_cultivation_with_rag``, add it to the text and image
           databases, and return it.

        Returns:
            The cultivation record dict, or ``None`` if captioning, object
            extraction or generation fails, or if neither branch applies.
        """
        try:
            caption_response = self.captioner.caption(image_path)
        except Exception as exc:
            print("Error occurred while captioning the image ", image_path, exc)
            return None

        if text_search_result is None:
            text_search_result = self.textdb.top_k_search(image_feature, 'clip_feature', top_k=5)
        # dict.fromkeys de-duplicates while keeping first-seen order.
        keywords = list(dict.fromkeys(res['translated_word'] for res in text_search_result))

        try:
            json_response = get_major_object(caption_response, keywords)
        except Exception as exc:
            print("Error occurred while getting major object from caption ", caption_response, exc)
            return None

        in_base_data, alt_data = verify_keyword_in_base(json_response, self.textdb)
        if in_base_data is not None:
            # The object already has an entry, so only the image is new.
            # Register it without generating another text entry.
            self._add_image(image_path, in_base_data, image_feature)
            return in_base_data
        if alt_data is None:
            return None

        try:
            generated = generate_cultivation_with_rag(alt_data, text_search_result)
        except Exception as exc:
            print("Error occurred while generating cultivation data", exc)
            return None

        new_data = {
            "keyword": alt_data['keyword'],
            "name_in_cultivation": generated['new_name'],
            "description_in_cultivation": generated['final_enhanced_description'],
            "translated_word": alt_data['translated_word'],
            "description": alt_data['description'],
        }
        self.textdb.add_data(new_data)
        print("Added new data to textdb: ", new_data["name_in_cultivation"])
        self._add_image(image_path, new_data, image_feature)
        print("Added new image to imgdb: ", new_data["keyword"])
        return new_data

    def _add_image(self, image_path, record, image_feature):
        """Register ``image_path`` in the image database under ``record``'s keyword."""
        image_data = {
            'image_name': image_path,
            'keyword': record['keyword'],
            'translated_word': record['translated_word'],
        }
        # NOTE(review): the positional True flag is interpreted by
        # Imagebase.add_image; confirm its meaning there.
        self.imgdb.add_image(image_data, True, image_feature)

if __name__ == "__main__":
    # Route HTTP(S) traffic through a local proxy on port 8234, but do not
    # clobber a proxy the environment already configures.
    os.environ.setdefault('HTTP_PROXY', 'http://localhost:8234')
    os.environ.setdefault('HTTPS_PROXY', 'http://localhost:8234')
    game_master = GameMaster()

    # Smoke test 1: search each .jpg in temp_images and stop at the first
    # one whose lookup returns a search_result, printing that result.
    target_folder = "temp_images"
    image_files = glob(os.path.join(target_folder, "*.jpg"))
    for index, image_path in enumerate(image_files):
        print("index:", index)
        search_result, _backup_results, _image_feature = game_master.search_with_path(image_path)
        if search_result:
            print(search_result)
            break

    # Smoke test 2: run the full (expensive) generation path on one image.
    # 向日葵 = "sunflower".
    test_image_path = "temp_images/向日葵.jpg"
    search_result, backup_results, image_feature = game_master.search_with_path(test_image_path)
    cultivation_data = game_master.generate_cultivation_data(
        test_image_path, image_feature, backup_results
    )
    print(cultivation_data)