vardaan123 committed on
Commit
bc38547
·
1 Parent(s): 225150e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import argparse
4
+ import numpy as np
5
+ import random
6
+ import pandas as pd
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision
11
+ import sys
12
+ import json
13
+ from collections import defaultdict
14
+ import math
15
+
16
+ from model import DistMult
17
+
18
+ from tqdm import tqdm
19
+ from utils import collate_list, detach_and_clone, move_to
20
+ from PIL import Image
21
+ from torchvision import transforms
22
+
23
+ _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
24
+ _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
25
+
26
def evaluate(model, id2entity, target_list, args):
    """Classify the species shown in ``args.img_path`` and print its name.

    Args:
        model: trained DistMult model exposing ``forward_ce``.
        id2entity: dict mapping integer entity id -> original entity key (str).
        target_list: tensor of candidate entity ids (from
            ``generate_target_list``); the argmax over model scores indexes
            into this tensor.
        args: parsed CLI namespace; uses ``img_path``, ``device``, ``data_dir``.
    """
    model.eval()
    torch.set_grad_enabled(False)

    # Map from overall category id to human-readable species name.
    # NOTE(review): path now derived from --data-dir (was hard-coded to the
    # default 'data/iwildcam_v2.0/'), consistent with the rest of the script;
    # context manager ensures the handle is closed.
    with open(os.path.join(args.data_dir, 'overall_id_to_name.json')) as f:
        overall_id_to_name = json.load(f)

    img = Image.open(args.img_path).convert('RGB')

    transform_steps = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN,
                             _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD),
    ])
    h = transform_steps(img)
    # Relation id 3 — presumably the "image depicts species" relation used at
    # training time; TODO confirm against the training pipeline.
    r = torch.tensor([3])

    h = move_to(h, args.device).unsqueeze(0)  # add batch dimension
    r = move_to(r, args.device).unsqueeze(0)

    outputs = model.forward_ce(h, r, triple_type=('image', 'id'))

    y_pred = detach_and_clone(outputs.cpu())
    y_pred = y_pred.argmax(-1)  # index of the best-scoring candidate

    pred_label = target_list[y_pred].item()
    species_label = overall_id_to_name[str(id2entity[pred_label])]
    print('species label = {}'.format(species_label))

    return
51
+
52
+ def _get_id(dict, key):
53
+ id = dict.get(key, None)
54
+ if id is None:
55
+ id = len(dict)
56
+ dict[key] = id
57
+ return id
58
+
59
def generate_target_list(data, entity2id):
    """Collect the distinct entity ids of every tail ('t') entity appearing
    in an (image, id) triple, preserving first-seen order.

    Args:
        data: dataframe with columns 'datatype_h', 'datatype_t' and 't'.
        entity2id: dict mapping entity key (str) -> integer id.

    Returns:
        Long tensor of shape (num_categories, 1).
    """
    sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']]
    sub = list(sub['t'])
    categories = []
    seen = set()  # O(1) membership test instead of re-scanning the list per item
    for item in tqdm(sub):
        # 't' values arrive as floats/float-strings like '123.0'; normalize
        # to the integer-string form used as entity2id keys.
        ent_id = entity2id[str(int(float(item)))]
        if ent_id not in seen:
            seen.add(ent_id)
            categories.append(ent_id)
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
69
+
70
+
71
+
72
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='data/iwildcam_v2.0/')
    parser.add_argument('--img-path', type=str, required=True, help='path to species image to be classified')
    parser.add_argument('--seed', type=int, default=813765)
    parser.add_argument('--ckpt-path', type=str, default=None, help='path to ckpt for restarting expt')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--batch_size', type=int, default=16)

    # Model hyper-parameters — must match the checkpoint being restored.
    parser.add_argument('--embedding-dim', type=int, default=512)
    parser.add_argument('--location_input_dim', type=int, default=2)
    parser.add_argument('--time_input_dim', type=int, default=1)
    parser.add_argument('--mlp_location_numlayer', type=int, default=3)
    parser.add_argument('--mlp_time_numlayer', type=int, default=3)

    parser.add_argument('--img-embed-model', choices=['resnet18', 'resnet50'], default='resnet50')
    parser.add_argument('--use-data-subset', action='store_true')
    parser.add_argument('--subset-size', type=int, default=10)

    args = parser.parse_args()

    print('args = {}'.format(args))
    args.device = torch.device('cuda') if not args.no_cuda and torch.cuda.is_available() else torch.device('cpu')

    # Set random seed for reproducibility across torch / numpy / stdlib random.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    datacsv = pd.read_csv(os.path.join(args.data_dir, 'dataset_subtree.csv'), low_memory=False)

    entity_id_file = os.path.join(args.data_dir, 'entity2id_subtree.json')

    # Build (or reload) the entity -> integer-id mapping. Ids are assigned in
    # CSV row order, so the cached JSON must be regenerated if the CSV changes.
    if not os.path.exists(entity_id_file):
        entity2id = {}  # each of triple types have their own entity2id

        for i in tqdm(range(datacsv.shape[0])):
            # Column 1 holds datatype_h, column 0 the head entity value.
            if datacsv.iloc[i, 1] == "id":
                _get_id(entity2id, str(int(float(datacsv.iloc[i, 0]))))

            # Column -2 holds datatype_t, column -3 the tail entity value.
            if datacsv.iloc[i, -2] == "id":
                _get_id(entity2id, str(int(float(datacsv.iloc[i, -3]))))
        # Context manager guarantees the cache file is flushed and closed
        # (the original json.dump(..., open(...)) leaked the handle).
        with open(entity_id_file, 'w') as f:
            json.dump(entity2id, f)
    else:
        with open(entity_id_file, 'r') as f:
            entity2id = json.load(f)

    id2entity = {v: k for k, v in entity2id.items()}

    num_ent_id = len(entity2id)

    target_list = generate_target_list(datacsv, entity2id)

    model = DistMult(args, num_ent_id, target_list, args.device)
    model.to(args.device)

    # Restore from checkpoint, if one was given.
    if args.ckpt_path:
        ckpt = torch.load(args.ckpt_path, map_location=args.device)
        # strict=False: checkpoint may carry keys this model does not define.
        model.load_state_dict(ckpt['model'], strict=False)
        print('ckpt loaded...')

    evaluate(model, id2entity, target_list, args)