Files changed (1)
  1. use/use.py +118 -0
use/use.py ADDED
@@ -0,0 +1,118 @@
+ # animeGender 0.8 use
+ # model usage file
+ # DOF Studio 230801
+
+ import cv2
+ import torch
+ import numpy as np
+ from torchvision import transforms
+
+ num_cls = 2
+ classes = ['female', 'male']
+
+ #############################
+ # graphic lib
+ def cmpgraph_224x224_ret(imgpath: str):
+     # read an image from disk, then reuse the in-memory preprocessing path
+     img = cv2.imread(imgpath, 1)  # flag 1 = cv2.IMREAD_COLOR, 3-channel BGR
+     return cmpgraph_224x224_dret(img)
+
+ def cmpgraph_224x224_dret(img: np.ndarray):
+     # scale the shorter side to 224 pixels, then take the top-left 224x224 crop
+     height, width = img.shape[:2]
+     if height > width:
+         hnew = int(np.round(224 / width * height))
+         wnew = 224
+         img2 = cv2.resize(img, (wnew, hnew), interpolation=cv2.INTER_LANCZOS4)
+         img2 = img2[0:224, 0:224]
+     elif width > height:
+         wnew = int(np.round(224 / height * width))
+         hnew = 224
+         img2 = cv2.resize(img, (wnew, hnew), interpolation=cv2.INTER_LANCZOS4)
+         img2 = img2[0:224, 0:224]
+     elif width == 224 and height == 224:
+         img2 = img
+     else:
+         img2 = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LANCZOS4)
+     # drop the alpha channel if present; cv2.COLOR_BGRA2BGR requires 4-channel input
+     if img2.shape[2] == 4:
+         img2 = cv2.cvtColor(img2, cv2.COLOR_BGRA2BGR)
+     return img2
+
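+ # Example (illustrative comment, not part of the original logic): a 448x300
+ # BGR input has height > width, so it is resized to 335x224 (HxW) and then
+ # cropped top-left to 224x224:
+ #   >>> cmpgraph_224x224_dret(np.zeros((448, 300, 3), np.uint8)).shape
+ #   (224, 224, 3)
+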
+ #############################
+ # use it
+ def loadmodel(model_path: str, is_cuda: bool = True):
+     # expects a full pickled model saved with torch.save(model, path),
+     # not a bare state_dict
+     model = torch.load(model_path)
+     if is_cuda:
+         model.to(torch.device('cuda'))
+     else:
+         model.to(torch.device('cpu'))
+     return model
+
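+ # Usage sketch (the path below is a placeholder, not a file shipped with
+ # this commit): on a machine without a GPU, pass is_cuda=False:
+ #   model = loadmodel("animeGender.pth", is_cuda=False)
+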
+ # for those who use "image_path"
+ def predict_class(img_path: str, model, print_: bool = False):
+     # read from disk, then share the in-memory prediction path below
+     img = cv2.imread(img_path, 1)
+     return predict_img_class(img, model, print_)
+
+ # for those who use direct image data
+ def predict_img_class(img: np.ndarray, model, print_: bool = False):
+     img = cmpgraph_224x224_dret(img)
+     transform = transforms.Compose([
+         # transforms.Resize(224),
+         # transforms.CenterCrop(224),
+         transforms.ToTensor()
+     ])
+     # send the input to the same device as the model instead of assuming CUDA
+     device = next(model.parameters()).device
+     img = transform(img).to(device)
+     img = torch.unsqueeze(img, dim=0)
+     model.eval()
+     with torch.no_grad():
+         out = model(img)
+     out = torch.nn.functional.softmax(out, dim=1)
+     conf = torch.max(out).item()
+     idx = torch.max(out, 1)[1].item()
+     cls = classes[idx]
+     if print_:
+         print('This is ' + cls + ' with a confidence of ' + str(np.round(conf, 3)))
+     return cls, conf
+
+ if __name__ == '__main__':
+
+     # TWO STEPS TO USE THIS MODEL
+     # @ DOF Studio @
+
+     # load a model from your disk
+     model = loadmodel("your_model_path")
+
+     # infer on an image and get the prediction
+     cls, confidence = predict_class("your_image_path", model, print_=True)
+
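
A minimal end-to-end sketch for the in-memory path, assuming the checkpoint was saved with torch.save(model, path) and that use/ is importable as a package; the file names below are placeholders, not files shipped with this commit:

    import cv2
    from use.use import loadmodel, predict_img_class

    model = loadmodel("animeGender.pth", is_cuda=False)  # placeholder checkpoint path
    frame = cv2.imread("some_capture.png", 1)            # any BGR image, e.g. a video frame
    cls, conf = predict_img_class(frame, model, print_=True)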
+