File size: 1,625 Bytes
016285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import logging
import matplotlib.pyplot as plt
from PIL import Image
import nltk

def logging_handler(verbose, save_name, idx=0):
    """Create (or reconfigure) the logger named ``str(idx)``.

    The logger is set to INFO level and always gets a stream handler that
    prints the bare message. When *save_name* is given, messages are also
    written to ``results/{save_name}/{idx}.log`` (the directory is created
    if needed).

    Args:
        verbose: Accepted for interface compatibility; not used by this
            function.
        save_name: Name of the sub-directory under ``results/`` that holds
            the log file, or ``None`` to log to the stream only.
        idx: Identifier used both as the logger name and as the log file's
            base name.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(str(idx))
    logger.setLevel(logging.INFO)

    # getLogger() hands back the same object for the same name, so handlers
    # from an earlier call would otherwise accumulate and every message
    # would be emitted once per call. Drop (and close) them first.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter("%(message)s")

    stream_logger = logging.StreamHandler()
    stream_logger.setFormatter(formatter)
    logger.addHandler(stream_logger)

    if save_name is not None:
        savepath = f"results/{save_name}"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(savepath, exist_ok=True)
        file_logger = logging.FileHandler(f"{savepath}/{idx}.log")
        file_logger.setFormatter(formatter)
        logger.addHandler(file_logger)

    return logger


def image_saver(images, save_name, idx=0, interactive=True):
    """Tile up to ten images in a 2x5 grid and save the figure as a PNG.

    Args:
        images: Indexable, sized collection of images that ``Axes.imshow``
            can draw. Only the first ten are used; if fewer are given the
            remaining cells are left blank.
        save_name: Base name of the output. See *interactive* for how it is
            turned into a path.
        idx: File stem used when *interactive* is false.
        interactive: When true the figure is written to
            ``{save_name}.png``; otherwise to
            ``results/{save_name}/{idx}.png``.
    """
    rows, cols = 2, 5
    fig, axes = plt.subplots(rows, cols)
    try:
        fig.set_size_inches(30, 15)
        shown = min(len(images), rows * cols)
        # axes.flat walks the grid row by row, i.e. cell (i // 5, i % 5).
        for i, ax in enumerate(axes.flat):
            if i < shown:
                ax.imshow(images[i])
                ax.set_aspect('equal')
            # Hide ticks/frames on every cell, including unused ones.
            ax.axis('off')
        fig.tight_layout()
        fig.subplots_adjust(wspace=0, hspace=0)
        if not interactive:
            out_dir = f"results/{save_name}"
            # Do not rely on another function having created the directory.
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f"{out_dir}/{idx}.png")
        else:
            fig.savefig(f"{save_name}.png")
    finally:
        # pyplot keeps a reference to every figure until it is closed;
        # without this, repeated calls leak one figure each.
        plt.close(fig)

def assert_checks(args):
    """Validate constraints between options on the parsed arguments.

    Args:
        args: Object exposing ``question_strategy`` and ``include_what``
            attributes (``include_what`` is only read when
            ``question_strategy`` equals ``"gpt3"``).

    Raises:
        AssertionError: If ``question_strategy`` is ``"gpt3"`` but
            ``include_what`` is falsy.
    """
    if args.question_strategy == "gpt3" and not args.include_what:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped when Python runs with -O, and so the failure is explained.
        raise AssertionError(
            'question_strategy="gpt3" requires include_what to be set'
        )

def extract_nouns(sents):
    """Collect the lower-cased nouns of each sentence.

    Every sentence is split on whitespace and tagged with
    ``nltk.pos_tag``; tokens tagged ``"NN"`` or ``"NNS"`` are kept, with
    all ``'.'`` characters removed and the result lower-cased.

    Args:
        sents: Iterable of sentence strings.

    Returns:
        A list with one entry per sentence, each a list of the extracted
        noun strings in their original order.
    """
    wanted_tags = ("NN", "NNS")
    per_sentence = []
    for sentence in sents:
        tagged = nltk.pos_tag(sentence.split())
        nouns = [
            token.replace('.', '').lower()
            for token, tag in tagged
            if tag in wanted_tags
        ]
        per_sentence.append(nouns)
    return per_sentence