"""Evaluate ad- and brand-gaze predictions on adversarial ad samples.

For every adversarial sample type (Newspaper, Banner, Multiple Ads,
Outdoor, Others) this script runs the XGBoost gaze predictor on each
numbered sample folder, collects the predicted ad-gaze (AG) and
brand-gaze (BG) times, and reports RMSE, RMSRPD and correlation against
the ground truth stored in the per-type ``AGs`` / ``BGs`` files.
"""

import sys
sys.path.append('../XGBoost_Prediction_Model/')

import warnings
warnings.filterwarnings("ignore")

import os
from os.path import isfile, isdir, join

import numpy as np
import torch

import Predict
import XGBoost_utils

# Pre-trained model / resource paths consumed by Predict.Ad_Gaze_Prediction.
text_detection_model_path = '../XGBoost_Prediction_Model/EAST-Text-Detection/frozen_east_text_detection.pb'
LDA_model_pth = '../XGBoost_Prediction_Model/LDA_Model_trained/lda_model_best_tot.model'
training_ad_text_dictionary_path = '../XGBoost_Prediction_Model/LDA_Model_trained/object_word_dictionary'
training_lang_preposition_path = '../XGBoost_Prediction_Model/LDA_Model_trained/dutch_preposition'

Adversarial_types = ['Newspaper', 'Banner', 'Multiple Ads', 'Outdoor', 'Others']

# Root folder holding one sub-directory per adversarial type.
mypath = '../XGBoost_Prediction_Model/Adversarial Samples New'

# Per-type ad-location codes, one entry per sample (index = sample number - 1).
ad_locations_total = {
    'Newspaper':    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'Banner':       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    'Multiple Ads': [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
    'Outdoor':      [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    'Others':       [1, 0, 1, 0, 2, 0, 1, 1, 0, 0, 0],
}


def _find_sample_images(sample_dir):
    """Return ``(ad_path, context_path)`` for the images inside *sample_dir*.

    Image files are named ``<anything> Ad.<ext>`` / ``<anything> Context.<ext>``
    (4-character extension assumed, e.g. ``.jpg``).  Either element is None
    when the corresponding image is absent.
    """
    ad_pth = None
    context_pth = None
    for img in os.listdir(sample_dir):
        name_parts = img[:-4].split(' ')  # drop the 4-char extension
        if len(name_parts) > 1:
            if name_parts[-1] == 'Ad':
                ad_pth = join(sample_dir, img)
            elif name_parts[-1] == 'Context':
                context_pth = join(sample_dir, img)
    return ad_pth, context_pth


def _predict_gaze(gaze_type, ad_pth, context_pth, ad_emb, ctpg_emb,
                  surface, category, ad_location):
    """Run one gaze prediction; *gaze_type* is 'Ad' (AG) or 'Brand' (BG).

    Wraps Predict.Ad_Gaze_Prediction so the two otherwise-identical AG/BG
    calls are not duplicated at the call site.
    """
    return Predict.Ad_Gaze_Prediction(
        input_ad_path=ad_pth,
        input_ctpg_path=context_pth,
        text_detection_model_path=text_detection_model_path,
        LDA_model_pth=LDA_model_pth,
        training_ad_text_dictionary_path=training_ad_text_dictionary_path,
        training_lang_preposition_path=training_lang_preposition_path,
        training_language='dutch',
        ad_embeddings=ad_emb.reshape(1, 768),
        ctpg_embeddings=ctpg_emb.reshape(1, 768),
        surface_sizes=list(surface),
        Product_Group=list(category),
        obj_detection_model_pth=None,
        ad_location=ad_location,
        num_topic=20,
        Gaze_Time_Type=gaze_type,
    )


def _rmse(gt, pred):
    """Root-mean-square error between ground truth and predictions."""
    return np.sqrt(np.mean((gt - pred) ** 2))


def _report(Predicted_AG, Predicted_BG, GT_AG, GT_BG):
    """Print per-type and overall RMSE, RMSRPD and correlation tables."""
    print("Final results: ")
    diffs_rmse = {}
    diffs_rmsrpd = {}
    GT_AG_tot, GT_BG_tot = [], []
    Pred_AG_tot, Pred_BG_tot = [], []

    for key in Predicted_AG:
        print(key)
        print('AGs', Predicted_AG[key])
        print('BGs', Predicted_BG[key])
        Pred_AG_tot.append(Predicted_AG[key])
        Pred_BG_tot.append(Predicted_BG[key])
        GT_AG_tot.append(GT_AG[key])
        GT_BG_tot.append(GT_BG[key])
        diffs_rmse[key] = (_rmse(GT_AG[key], Predicted_AG[key]),
                           _rmse(GT_BG[key], Predicted_BG[key]))
        diffs_rmsrpd[key] = (XGBoost_utils.RMSRPD(GT_AG[key], Predicted_AG[key]),
                             XGBoost_utils.RMSRPD(GT_BG[key], Predicted_BG[key]))
        print()

    # Flatten the per-type arrays for the overall metrics.
    Pred_AG_tot = np.concatenate(Pred_AG_tot)
    Pred_BG_tot = np.concatenate(Pred_BG_tot)
    GT_AG_tot = np.concatenate(GT_AG_tot)
    GT_BG_tot = np.concatenate(GT_BG_tot)

    print("RMSE: ")
    print("Total AG: ", _rmse(Pred_AG_tot, GT_AG_tot))
    print("Total BG: ", _rmse(Pred_BG_tot, GT_BG_tot))
    print()
    for key, (ag, bg) in diffs_rmse.items():
        print(key, ag, bg)
    print()

    print("RMSRPD: ")
    print("Total AG: ", XGBoost_utils.RMSRPD(Pred_AG_tot, GT_AG_tot))
    print("Total BG: ", XGBoost_utils.RMSRPD(Pred_BG_tot, GT_BG_tot))
    print()
    for key, (ag, bg) in diffs_rmsrpd.items():
        print(key, ag, bg)
    print()

    print("Correlation: ")
    print("Total AG: ", np.corrcoef(Pred_AG_tot, GT_AG_tot))
    print("Total BG: ", np.corrcoef(Pred_BG_tot, GT_BG_tot))


def main():
    """Predict AG/BG for every sample of every adversarial type, then report."""
    Predicted_AG, Predicted_BG = {}, {}
    GT_AG, GT_BG = {}, {}

    for f in os.listdir(mypath):
        path_temp = join(mypath, f)
        if not isdir(path_temp):
            continue
        print('Currently processing samples of type ' + f + '......')

        # Per-type metadata tensors, indexed by sample number - 1.
        surfaces = torch.load(join(path_temp, f + '_Adversarial_Surfaces'))
        print(surfaces)
        categories = torch.load(join(path_temp, f + '_Adversarial_Categories'))
        ad_locations_curr = ad_locations_total[f]
        ad_embeddings = torch.load(join(path_temp, f + '_ad_topic_embeddings'))
        ctpg_embeddings = torch.load(join(path_temp, f + '_ctpg_topic_embeddings'))
        GT_AG[f] = torch.load(join(path_temp, 'AGs'))
        GT_BG[f] = torch.load(join(path_temp, 'BGs'))

        # One slot per numbered sample sub-folder (sub_f "1".."N" -> index 0..N-1).
        num_samples = len(ad_locations_curr)
        AG_predictions_per_type = np.zeros(num_samples)
        BG_predictions_per_type = np.zeros(num_samples)

        for sub_f in os.listdir(path_temp):
            sample_dir = join(path_temp, sub_f)
            if not isdir(sample_dir):
                continue
            print('Sample Number ' + sub_f + '......\n')

            ad_pth, context_pth = _find_sample_images(sample_dir)
            if ad_pth is None:
                # The original never initialized ad_pth, so a folder without
                # an Ad image crashed with NameError (or silently reused the
                # previous sample's path).  Fail loudly instead.
                raise FileNotFoundError('No "... Ad" image found in ' + sample_dir)
            print(ad_pth)

            sample_num = int(sub_f) - 1
            AG_predictions_per_type[sample_num] = _predict_gaze(
                'Ad', ad_pth, context_pth,
                ad_embeddings[sample_num], ctpg_embeddings[sample_num],
                surfaces[sample_num], categories[sample_num],
                ad_locations_curr[sample_num])
            BG_predictions_per_type[sample_num] = _predict_gaze(
                'Brand', ad_pth, context_pth,
                ad_embeddings[sample_num], ctpg_embeddings[sample_num],
                surfaces[sample_num], categories[sample_num],
                ad_locations_curr[sample_num])

        Predicted_AG[f] = AG_predictions_per_type
        Predicted_BG[f] = BG_predictions_per_type

    _report(Predicted_AG, Predicted_BG, GT_AG, GT_BG)


if __name__ == '__main__':
    main()