File size: 7,361 Bytes
024304f
 
 
 
 
 
 
 
 
8f243be
 
 
 
024304f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f243be
024304f
8f243be
024304f
 
 
 
 
8f243be
 
 
 
 
 
 
024304f
8f243be
 
 
 
024304f
 
 
8f243be
024304f
8f243be
024304f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f243be
 
024304f
 
 
8f243be
024304f
 
8f243be
024304f
 
8f243be
 
 
 
 
 
024304f
8f243be
 
 
024304f
 
8f243be
024304f
 
 
 
8f243be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
024304f
 
8f243be
024304f
 
 
8f243be
 
 
 
 
 
 
 
 
024304f
 
 
 
8f243be
024304f
8f243be
 
 
 
 
024304f
 
 
 
 
 
8f243be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
024304f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from transformers import AutoTokenizer, AutoModel
import torch
from PIL import Image
from config import get_inference_config
from models import build_model
from torch.autograd import Variable
from torchvision.transforms import transforms
import numpy as np
import argparse
from pycocotools.coco import COCO
import requests
import os
from tqdm.auto import tqdm

try:
    from apex import amp
except ImportError:
    amp = None

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def model_config(config_path):
    args = Namespace(cfg=config_path)
    config = get_inference_config(args)
    return config


def read_class_names(file_path):
    file = open(file_path, 'r')
    lines = file.readlines()
    class_list = []

    for l in lines:
        line = l.strip()
        # class_list.append(line[0])
        class_list.append(line)

    classes = tuple(class_list)
    return classes


def read_class_names_coco(file_path):
    dataset = COCO(file_path)
    classes =  [dataset.cats[k]['name'] for k in sorted(dataset.cats.keys())]

    with open("names.txt", 'w') as fp:
        for c in classes:
            fp.write(f"{c}\n")

    return classes

class GenerateEmbedding:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.model = AutoModel.from_pretrained("bert-base-uncased")

    def generate(self, text_file):
        text_list = []
        with open(text_file, 'r') as f_text:
            for line in f_text:
                line = line.encode(encoding='UTF-8', errors='strict')
                line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
                line = line.decode('UTF-8', 'strict')
                text_list.append(line)
            # data = f_text.read()
        select_index = np.random.randint(len(text_list))
        inputs = self.tokenizer(text_list[select_index], return_tensors="pt", padding="max_length",
                                truncation=True, max_length=32)
        outputs = self.model(**inputs)
        embedding_mean = outputs[1].mean(dim=0).reshape(1, -1).detach().numpy()
        embedding_full = outputs[1].detach().numpy()
        embedding_words = outputs[0] # outputs[0].detach().numpy()
        return None, None, embedding_words


class Inference:
    def __init__(self, config_path, model_path, names_path):

        self.config_path = config_path
        self.model_path = model_path
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.classes = read_class_names(names_path)

        self.config = model_config(self.config_path)

        self.model = build_model(self.config)
        self.checkpoint = torch.load(self.model_path, map_location='cpu')

        if 'model' in self.checkpoint:
            self.model.load_state_dict(self.checkpoint['model'], strict=False)
        else:
            self.model.load_state_dict(self.checkpoint, strict=False)
            
        self.model.eval()
        self.model.to(self.device)
        self.topk = 10
        self.embedding_gen = GenerateEmbedding()

        self.transform_img = transforms.Compose([
            transforms.Resize((self.config.DATA.IMG_SIZE, self.config.DATA.IMG_SIZE), interpolation=Image.BILINEAR),
            transforms.ToTensor(), # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        ])

    def infer(self, img_path, meta_data_path, topk=None):

        if isinstance(img_path, str):
            if img_path.startswith("http"):
                img = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
            else:
                img = Image.open(img_path).convert('RGB')
        else:
            img = img_path
        
        """

        _, _, meta = self.embedding_gen(meta_data_path)

        meta = meta.to(self.device)

        """
        meta = None

        img = self.transform_img(img)
        img.unsqueeze_(0)
        img = img.to(self.device)
        img = Variable(img).to(self.device)
        out = self.model(img, meta)

        f = torch.nn.Softmax(dim=1)
        y_pred = f(out)
        indices = reversed(torch.argsort(y_pred, dim=1).squeeze().tolist())

        if topk is not None:
            predict = [{self.classes[idx] : y_pred.squeeze()[idx].cpu().item() for idx in indices[:topk]}]
            return predict
        else:
            return {self.classes[idx] : y_pred.squeeze()[idx].cpu().item() for idx in indices}


def parse_option():
    parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', default="configs/MetaFG_2_224.yaml")
    # easy config modification
    parser.add_argument('--model-path', type=str, help="path to model data", default="ckpt_epoch_12.pth")
    parser.add_argument('--img-path', type=str, help='path to image')
    parser.add_argument('--img-folder', type=str, help='path to image')
    parser.add_argument('--meta-path', default="meta.txt", type=str, help='path to meta data')
    parser.add_argument('--names-path', default="names_mf2.txt", type=str, help='path to meta data')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_option()
    model = Inference(config_path=args.cfg,
                       model_path=args.model_path,
                       names_path=args.names_path)
    
    from glob import glob
    glob_imgs = glob(os.path.join(args.img_folder, "*.jpg"))
    out_dir = f"results_{os.path.splitext(os.path.basename(args.model_path))[0]}"
    os.makedirs(out_dir, exist_ok=True)

    for img in tqdm(glob_imgs):
        try:
            res = model.infer(img_path=img, meta_data_path=args.meta_path)
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(e)
            continue
        
        out = {}
        out['preds'] = res

        """

        # Out is a list of (class, score). Return true/false if the top1 class is correct

        out['top1_correct'] = '_'.join(res[0][1].split(' ')).lower() in os.path.basename(img).lower()



        out['top5_correct'] = False

        print(os.path.basename(img).lower())

        for i in range(5):

            out['top5_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()

            print('_'.join(res[i][1].split(' ')).lower())

            

        out['top10_correct'] = False

        for i in range(10):

            out['top10_correct'] |= '_'.join(res[i][1].split(' ')).lower() in os.path.basename(img).lower()

        """

        # output json with inference results, use image basename 
        # as filename
        import json
        with open(os.path.join(out_dir, os.path.splitext(os.path.basename(img))[0]+".json"), 'w') as fp:
            json.dump(out, fp, indent=1)

# Usage: python inference.py --cfg 'path/to/cfg' --model_path 'path/to/model' --img-path 'path/to/img' --meta-path 'path/to/meta'