import os
import logging

import matplotlib.pyplot as plt
from PIL import Image
import nltk


def logging_handler(verbose, save_name, idx=0):
    """Return a logger named ``str(idx)`` that emits bare messages.

    Messages go to a ``StreamHandler`` (stderr by default) and, when
    ``save_name`` is given, also to ``results/<save_name>/<idx>.log``.

    Args:
        verbose: Accepted for call-site compatibility; not used here.
        save_name: Sub-directory under ``results/`` for the log file, or
            ``None`` to log to the stream only.
        idx: Used both as the logger name and as the log-file stem.

    Returns:
        The configured ``logging.Logger`` at level INFO.
    """
    logger = logging.getLogger(str(idx))
    logger.setLevel(logging.INFO)

    # getLogger returns the same object for the same name, so calling this
    # twice with one idx used to stack duplicate handlers (each message
    # emitted N times) and leak open log files. Reset before re-attaching.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter("%(message)s")

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if save_name is not None:
        savepath = f"results/{save_name}"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(savepath, exist_ok=True)
        file_handler = logging.FileHandler(f"{savepath}/{idx}.log")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


def image_saver(images, save_name, idx=0, interactive=True):
    """Save up to ten images as a borderless 2x5 grid PNG.

    Args:
        images: Indexable sequence of images accepted by ``Axes.imshow``.
            The first ten are plotted row by row; if fewer are given, the
            remaining grid cells are left blank.
        save_name: In non-interactive mode, the sub-directory under
            ``results/``; in interactive mode, the output file stem.
        idx: Output file stem in non-interactive mode.
        interactive: If True, write ``<save_name>.png`` and leave the figure
            open; otherwise write ``results/<save_name>/<idx>.png`` and close
            the figure.
    """
    fig, axes = plt.subplots(2, 5)
    fig.set_size_inches(30, 15)
    # axes.flat walks the grid row-major, i.e. cell i is axes[i // 5][i % 5].
    for i, ax in enumerate(axes.flat):
        if i < len(images):
            ax.imshow(images[i])
            ax.set_aspect('equal')
        ax.axis('off')
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0)

    if not interactive:
        savepath = f"results/{save_name}"
        # Don't rely on logging_handler having created the directory first.
        os.makedirs(savepath, exist_ok=True)
        fig.savefig(f"{savepath}/{idx}.png")
        # Batch callers invoke this once per idx; pyplot would otherwise keep
        # every figure alive for the life of the process.
        plt.close(fig)
    else:
        # Figure is left open so an interactive session can still display it.
        fig.savefig(f"{save_name}.png")


def assert_checks(args):
    """Validate option combinations on ``args``.

    ``args`` is presumably an argparse ``Namespace``; confirm at the call site.

    Raises:
        AssertionError: if ``args.question_strategy == "gpt3"`` but
            ``args.include_what`` is falsy.
    """
    # Explicit raise instead of a bare ``assert`` so the check still runs
    # under ``python -O``; AssertionError is kept for existing callers.
    if args.question_strategy == "gpt3" and not args.include_what:
        raise AssertionError(
            "question_strategy='gpt3' requires include_what to be set"
        )


def extract_nouns(sents, tags=("NN", "NNS")):
    """Return, for each sentence, the lower-cased nouns it contains.

    Each sentence is split on whitespace and tagged with ``nltk.pos_tag``.
    Tokens whose tag is in ``tags`` are kept, with every ``.`` removed and
    the result lower-cased.

    Args:
        sents: Iterable of sentence strings.
        tags: POS tags to keep. The default keeps singular/plural common
            nouns (``NN``, ``NNS``), so proper nouns (``NNP``/``NNPS``) are
            excluded.

    Returns:
        One list of nouns per input sentence, in input order.
    """
    wanted = frozenset(tags)
    return [
        [
            word.replace('.', '').lower()
            for word, tag in nltk.pos_tag(sent.split())
            if tag in wanted
        ]
        for sent in sents
    ]