isLandLZ commited on
Commit
64e7562
·
1 Parent(s): d2ad7a2

Upload test.py

Browse files
Files changed (1) hide show
  1. test.py +44 -0
test.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import date
2
+ import jittor as jt
3
+ from numpy.core.fromnumeric import shape
4
+ from numpy.lib.type_check import imag
5
+ from model import Model
6
+ from jittor.dataset.mnist import MNIST
7
+ import jittor.transform as trans
8
+ import numpy as np
9
+ import cv2
10
+ import os
11
+
12
+
13
+ def ImageClassification(img_path, model):
14
+ # 得到一个 HxWx3 的 array(224, 225, 3)
15
+ image_start = cv2.imread(img_path)
16
+ # 把图像缩放到 28x28 个像素(28, 28, 3)
17
+ image = cv2.resize(image_start, (28, 28))
18
+ # print(image.shape)
19
+ image = image / 255.0 # 把图像的 RGB 值从 [0, 255] 变为 [0, 1]
20
+ image = image.transpose(2, 0, 1) # 把输入格式从 HWC 改为 CHW
21
+ image = jt.float32(image) # 变为 Jittor Var
22
+ image_end = image.unsqueeze(dim=0) # 加入 batch 维度,变为 [1, C, H, W]
23
+ outputs = model(image_end)
24
+ prediction = np.argmax(outputs.data, axis=1)
25
+ # TODO 展示图片
26
+ cv2.imshow('MNISt', image_start)
27
+ cv2.waitKey(0)
28
+ print('图片识别结果:'+str(prediction[0]))
29
+
30
+
31
+ def main():
32
+ pwd_path = os.path.abspath(os.path.dirname(__file__))
33
+ save_model_path = os.path.join(pwd_path, 'model/mnist_model.pkl')
34
+ # TODO 加载模型
35
+ model = Model()
36
+ model.load_parameters(jt.load(save_model_path))
37
+ # TODO 加载本地图片
38
+ img_path = '0.jpg'
39
+ # TODO 对图片进行识别
40
+ ImageClassification(img_path, model)
41
+
42
+
43
+ if __name__ == '__main__':
44
+ main()