# Source: Hugging Face Hub file view (uploaded by Demo750, commit 569f484
# "Upload folder using huggingface_hub", verified; 5.97 kB).
import sys
sys.path.append('../XGBoost_Prediction_Model/')
import warnings
warnings.filterwarnings("ignore")
import Predict
import XGBoost_utils
import torch
import numpy as np
import os
from os.path import isfile, isdir, join
# ---------------------------------------------------------------------------
# Model artefacts. Paths are relative to the directory the script is run from.
# ---------------------------------------------------------------------------
text_detection_model_path = '../XGBoost_Prediction_Model/EAST-Text-Detection/frozen_east_text_detection.pb'
LDA_model_pth = '../XGBoost_Prediction_Model/LDA_Model_trained/lda_model_best_tot.model'
training_ad_text_dictionary_path = '../XGBoost_Prediction_Model/LDA_Model_trained/object_word_dictionary'
training_lang_preposition_path = '../XGBoost_Prediction_Model/LDA_Model_trained/dutch_preposition'

# ---------------------------------------------------------------------------
# Evaluation data.
# ---------------------------------------------------------------------------
# NOTE(review): not referenced by the processing loop, which discovers the
# sample types by listing `mypath`; kept for reference.
Adversarial_types = ['Newspaper', 'Banner', 'Multiple Ads', 'Outdoor', 'Others']
mypath = '../XGBoost_Prediction_Model/Adversarial Samples New'

# ad_location code per sample, one list per sample type. Entry i belongs to
# sample folder i + 1. The codes are forwarded unchanged as
# Predict.Ad_Gaze_Prediction(ad_location=...); their meaning is defined there.
ad_locations_total = {
    'Newspaper': [0] * 11,
    'Banner': [2] * 11,
    'Multiple Ads': [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
    'Outdoor': [2] * 11,
    'Others': [1, 0, 1, 0, 2, 0, 1, 1, 0, 0, 0],
}

# Result containers keyed by sample type. Predictions and ground truth for
# ad gaze (AG) and brand gaze (BG) are filled in by the processing loop.
Predicted_AG, Predicted_BG = {}, {}
GT_AG, GT_BG = {}, {}
def _find_sample_images(sample_dir):
    """Locate the ad and context images inside one numbered sample folder.

    A file is classified by the last space-separated word of its name with
    the extension stripped: ``'<anything> Ad.<ext>'`` is the ad image and
    ``'<anything> Context.<ext>'`` is the context page. Any other file is
    ignored. If several files match the same role, the last one listed wins.

    Args:
        sample_dir: path of the sample folder (e.g. ``.../Outdoor/3``).

    Returns:
        ``(ad_path, context_path)``. ``context_path`` is ``None`` when the
        folder holds no context image.

    Raises:
        FileNotFoundError: if the folder holds no ad image. This prevents an
            ad path from another sample from being reused.
    """
    ad_path = None
    context_path = None
    for img_name in os.listdir(sample_dir):
        # splitext handles any extension length ('.jpg', '.jpeg', '.webp', ...).
        words = os.path.splitext(img_name)[0].split(' ')
        if len(words) < 2:
            continue
        img_type = words[-1]
        if img_type == 'Ad':
            ad_path = join(sample_dir, img_name)
            print(ad_path)
        elif img_type == 'Context':
            context_path = join(sample_dir, img_name)
    if ad_path is None:
        raise FileNotFoundError("No '<name> Ad.<ext>' image found in " + sample_dir)
    return ad_path, context_path


# For every sample-type folder under `mypath` (e.g. 'Outdoor'), load the stored
# per-sample inputs and ground truth. Then, for each numbered sample folder
# ('1', '2', ...), predict both gaze times:
#   AG: Gaze_Time_Type='Ad'
#   BG: Gaze_Time_Type='Brand'
for f in sorted(os.listdir(mypath)):
    path_temp = join(mypath, f)
    if not isdir(path_temp):
        continue
    if f not in ad_locations_total:
        print('Skipping ' + f + ': no ad locations configured for this type')
        continue
    print('Currently processing samples of type '+f+'......')

    # NOTE(review): torch.load unpickles arbitrary objects. Only run this on
    # trusted sample folders.
    surfaces = torch.load(join(path_temp, f + '_Adversarial_Surfaces'))
    categories = torch.load(join(path_temp, f + '_Adversarial_Categories'))
    ad_locations_curr = ad_locations_total[f]
    ad_embeddings = torch.load(join(path_temp, f + '_ad_topic_embeddings'))
    ctpg_embeddings = torch.load(join(path_temp, f + '_ctpg_topic_embeddings'))
    GT_AG[f] = torch.load(join(path_temp, 'AGs'))
    GT_BG[f] = torch.load(join(path_temp, 'BGs'))

    # One slot per configured sample. A sample whose folder is missing keeps a
    # prediction of 0.0.
    # TODO(review): consider NaN so a missing sample cannot pass as a real
    # prediction.
    n_samples = len(ad_locations_curr)
    AG_predictions_per_type = np.zeros(n_samples)
    BG_predictions_per_type = np.zeros(n_samples)

    for sub_f in sorted(os.listdir(path_temp)):
        sample_dir = join(path_temp, sub_f)
        if not isdir(sample_dir):
            continue
        if not sub_f.isdigit():
            print('Skipping non-numeric sample folder ' + sample_dir)
            continue
        print('Sample Number '+sub_f+'...... ')
        ad_pth, context_pth = _find_sample_images(sample_dir)

        sample_num = int(sub_f) - 1  # sample folders are numbered from 1
        # Reject out-of-range folders instead of letting a negative index
        # silently overwrite the last slot.
        if not 0 <= sample_num < n_samples:
            raise ValueError('Sample folder ' + sample_dir + ' is outside 1..' + str(n_samples))

        for gaze_type, predictions in (('Ad', AG_predictions_per_type),
                                       ('Brand', BG_predictions_per_type)):
            # Arguments are rebuilt for each call, so each prediction gets
            # its own fresh lists, exactly as the two separate calls did.
            predictions[sample_num] = Predict.Ad_Gaze_Prediction(
                input_ad_path=ad_pth,
                input_ctpg_path=context_pth,
                text_detection_model_path=text_detection_model_path,
                LDA_model_pth=LDA_model_pth,
                training_ad_text_dictionary_path=training_ad_text_dictionary_path,
                training_lang_preposition_path=training_lang_preposition_path,
                training_language='dutch',
                ad_embeddings=ad_embeddings[sample_num].reshape(1, 768),
                ctpg_embeddings=ctpg_embeddings[sample_num].reshape(1, 768),
                surface_sizes=list(surfaces[sample_num]),
                Product_Group=list(categories[sample_num]),
                obj_detection_model_pth=None,
                ad_location=ad_locations_curr[sample_num],
                num_topic=20,
                Gaze_Time_Type=gaze_type,
            )

    Predicted_AG[f] = AG_predictions_per_type
    Predicted_BG[f] = BG_predictions_per_type
def _rmse(y_true, y_pred):
    """Return the root-mean-square error between two equal-length arrays."""
    return np.sqrt(np.mean((y_true - y_pred) ** 2))


def _as_flat_float(values):
    """Convert stored ground truth (list, ndarray or CPU tensor) to a flat float ndarray.

    Flattening keeps an (N, 1) ground truth from broadcasting against the
    (N,) prediction arrays into an (N, N) difference.
    """
    return np.asarray(values, dtype=float).ravel()


print("Final results: ")
if not Predicted_AG:
    # Without this check, np.concatenate([]) below fails with an opaque ValueError.
    raise SystemExit("No sample types were processed; nothing to evaluate.")

diffs_rmse = {}
diffs_rmsrpd = {}
GT_AG_tot = []
GT_BG_tot = []
Pred_AG_tot = []
Pred_BG_tot = []
for key in Predicted_AG:
    gt_ag = _as_flat_float(GT_AG[key])
    gt_bg = _as_flat_float(GT_BG[key])
    pred_ag = Predicted_AG[key]
    pred_bg = Predicted_BG[key]
    print(key)
    print('AGs', pred_ag)
    print('BGs', pred_bg)
    Pred_AG_tot.append(pred_ag)
    Pred_BG_tot.append(pred_bg)
    GT_AG_tot.append(gt_ag)
    GT_BG_tot.append(gt_bg)
    diffs_rmse[key] = (_rmse(gt_ag, pred_ag), _rmse(gt_bg, pred_bg))
    # RMSRPD is relative and so order-sensitive. This report always calls it
    # as (ground_truth, prediction).
    # NOTE(review): confirm against XGBoost_utils.RMSRPD's signature.
    diffs_rmsrpd[key] = (XGBoost_utils.RMSRPD(gt_ag, pred_ag),
                         XGBoost_utils.RMSRPD(gt_bg, pred_bg))
    print()

Pred_AG_tot = np.concatenate(Pred_AG_tot)
Pred_BG_tot = np.concatenate(Pred_BG_tot)
GT_AG_tot = np.concatenate(GT_AG_tot)
GT_BG_tot = np.concatenate(GT_BG_tot)

print("RMSE: ")
print("Total AG: ", _rmse(GT_AG_tot, Pred_AG_tot))
print("Total BG: ", _rmse(GT_BG_tot, Pred_BG_tot))
print()
for key, (rmse_ag, rmse_bg) in diffs_rmse.items():
    print(key, rmse_ag, rmse_bg)
print()

print("RMSRPD: ")
# Same (ground_truth, prediction) order as the per-type values above. The
# totals previously had the arguments swapped.
print("Total AG: ", XGBoost_utils.RMSRPD(GT_AG_tot, Pred_AG_tot))
print("Total BG: ", XGBoost_utils.RMSRPD(GT_BG_tot, Pred_BG_tot))
print()
for key, (rmsrpd_ag, rmsrpd_bg) in diffs_rmsrpd.items():
    print(key, rmsrpd_ag, rmsrpd_bg)
print()

print("Correlation: ")
# np.corrcoef returns the 2x2 correlation matrix; [0, 1] is the Pearson r
# between predictions and ground truth.
print("Total AG: ", np.corrcoef(Pred_AG_tot, GT_AG_tot)[0, 1])
print("Total BG: ", np.corrcoef(Pred_BG_tot, GT_BG_tot)[0, 1])