Syed Abdul Gaffar Shakhadri committed on
Commit
024304f
·
unverified ·
1 Parent(s): dbbc0af

added inference script

Browse files
Files changed (1) hide show
  1. inference.py +125 -0
inference.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ from PIL import Image
4
+ from config import get_inference_config
5
+ from models import build_model
6
+ from torch.autograd import Variable
7
+ from torchvision.transforms import transforms
8
+ import numpy as np
9
+ import argparse
10
+
11
+ try:
12
+ from apex import amp
13
+ except ImportError:
14
+ amp = None
15
+
16
# Per-channel mean / standard deviation handed to transforms.Normalize
# when preparing an input image (three values, one per RGB channel).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
18
+
19
+
20
class Namespace:
    """Minimal attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Store every keyword under its own name on the instance.
        for key, value in kwargs.items():
            setattr(self, key, value)
23
+
24
+
25
def model_config(config_path):
    """Return the inference configuration for the config file at ``config_path``.

    The path is wrapped in a ``Namespace`` that exposes it as ``cfg`` and
    passed to ``get_inference_config``; whatever that helper returns is
    handed back unchanged.
    """
    return get_inference_config(Namespace(cfg=config_path))
29
+
30
+
31
def read_class_names(file_path):
    """Read class names from a whitespace-separated listing file.

    For every non-blank line the second whitespace-separated field is taken
    and its first four characters are dropped (presumably a numeric prefix
    such as ``"001."`` -- confirm against the actual listing file).

    Args:
        file_path: Path of the text file to read.

    Returns:
        A tuple of class names, in file order.

    Raises:
        IndexError: If a non-blank line has fewer than two fields.
    """
    # The context manager guarantees the handle is closed, even on error.
    with open(file_path, 'r') as file:
        rows = (raw.split() for raw in file)
        # Blank / whitespace-only lines (e.g. a trailing newline) are skipped
        # instead of crashing on the missing second field.
        return tuple(fields[1][4:] for fields in rows if fields)
43
+
44
+
45
class GenerateEmbedding:
    """Encode one randomly chosen line of a text file with ``bert-base-uncased``."""

    def __init__(self, text_file):
        """Remember the text file and load the pretrained tokenizer and model.

        Args:
            text_file: Path of the text file whose lines are candidate inputs.
        """
        self.text_file = text_file

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.model = AutoModel.from_pretrained("bert-base-uncased")

    def generate(self):
        """Pick a random line of the text file and run it through the model.

        Returns:
            A 3-tuple ``(None, None, embedding_words)``. The first two slots
            are always ``None``; ``embedding_words`` is the first element of
            the model output for the chosen line, tokenized with padding and
            truncation to 32 tokens (for BERT this should be the per-token
            hidden states -- confirm against the transformers documentation).

        Raises:
            ValueError: If the text file contains no lines.
        """
        # Decode explicitly as UTF-8 so the U+FFFD clean-up below behaves the
        # same on every platform; undecodable bytes become U+FFFD rather than
        # aborting the read.
        with open(self.text_file, 'r', encoding='utf-8', errors='replace') as f_text:
            # A pair of replacement characters is collapsed into one space.
            text_list = [line.replace('\ufffd\ufffd', ' ') for line in f_text]
        if not text_list:
            raise ValueError(f"no text lines found in {self.text_file!r}")

        select_index = np.random.randint(len(text_list))
        inputs = self.tokenizer(text_list[select_index], return_tensors="pt", padding="max_length",
                                truncation=True, max_length=32)
        # Inference only: do not build an autograd graph for the forward pass.
        with torch.no_grad():
            outputs = self.model(**inputs)
        embedding_words = outputs[0]
        return None, None, embedding_words
69
+
70
+
71
class Inference:
    """Load a model checkpoint and predict a class name for an image plus its text file."""

    # Class listing used when the caller does not supply ``class_names_path``.
    DEFAULT_CLASS_NAMES_PATH = r"D:\dataset\CUB_200_2011\CUB_200_2011\classes_custom.txt"

    def __init__(self, config_path, model_path, class_names_path=None):
        """Build the model, load its weights and prepare the image transform.

        Args:
            config_path: Path of the config file passed to ``model_config``.
            model_path: Path of the checkpoint; its ``'model'`` entry is
                loaded as the state dict.
            class_names_path: Optional path of the class listing file read by
                ``read_class_names``. Defaults to ``DEFAULT_CLASS_NAMES_PATH``.
        """
        self.config_path = config_path
        self.model_path = model_path
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if class_names_path is None:
            class_names_path = self.DEFAULT_CLASS_NAMES_PATH
        self.classes = read_class_names(class_names_path)

        self.config = model_config(self.config_path)
        self.model = build_model(self.config)
        # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
        self.checkpoint = torch.load(self.model_path, map_location='cpu')
        # NOTE(review): strict=False silently ignores missing/unexpected keys;
        # confirm this leniency is intended.
        self.model.load_state_dict(self.checkpoint['model'], strict=False)
        self.model.eval()
        # Move to the selected device instead of calling .cuda() unconditionally,
        # so the CPU fallback chosen above actually works.
        self.model.to(self.device)

        self.transform_img = transforms.Compose([
            transforms.Resize((224, 224), interpolation=Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        ])

    def infer(self, img_path, meta_data_path):
        """Predict the class name for one image and its accompanying text file.

        Args:
            img_path: Path of the image; it is converted to RGB and resized
                to 224x224 by the stored transform.
            meta_data_path: Path of the text file given to ``GenerateEmbedding``.

        Returns:
            The entry of ``self.classes`` at the index of the largest model output.
        """
        _, _, meta = GenerateEmbedding(meta_data_path).generate()
        meta = meta.to(self.device)

        # Close the image file promptly once the RGB copy has been made.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        # Add the leading batch dimension and move to the model's device.
        img = self.transform_img(img).unsqueeze(0).to(self.device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            out = self.model(img, meta)

        _, pred = torch.max(out, 1)
        predict = self.classes[pred.item()]
        return predict
106
+
107
+
108
def parse_option(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``cfg``, ``model_path``, ``img_path``
        and ``meta_path`` attributes.
    """
    parser = argparse.ArgumentParser('MetaFG Inference script', add_help=False)
    parser.add_argument('--cfg', type=str,
                        default='D:/pycharmprojects/MetaFormer/configs/MetaFG_meta_bert_1_224.yaml',
                        metavar="FILE", help='path to config file')
    # Raw string: the Windows path contains backslashes that are not valid
    # escape sequences; the resulting value is unchanged.
    parser.add_argument('--model-path',
                        default=r'D:\pycharmprojects\MetaFormer\output\MetaFG_meta_1\cub_200\ckpt_epoch_92.pth',
                        type=str, help="path to model data")
    parser.add_argument('--img-path',
                        default=r"D:\dataset\CUB_200_2011\CUB_200_2011\images\012.Yellow_headed_Blackbird\Yellow_Headed_Blackbird_0003_8337.jpg",
                        type=str, help='path to image')
    parser.add_argument('--meta-path',
                        default=r"D:\dataset\CUB_200_2011\text_c10\012.Yellow_headed_Blackbird\Yellow_Headed_Blackbird_0003_8337.txt",
                        type=str, help='path to meta data')
    return parser.parse_args(argv)
117
+
118
+
119
if __name__ == '__main__':
    # Script entry point: parse the options, build the predictor once, then
    # classify the requested image / text pair and print the class name.
    args = parse_option()
    predictor = Inference(config_path=args.cfg, model_path=args.model_path)
    result = predictor.infer(img_path=args.img_path, meta_data_path=args.meta_path)
    print("Predicted: ", result)
124
+
125
+ # Usage: python inference.py --cfg 'path/to/cfg' --model-path 'path/to/model' --img-path 'path/to/img' --meta-path 'path/to/meta'