Upload test.py
Browse files
test.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import date
|
2 |
+
import jittor as jt
|
3 |
+
from numpy.core.fromnumeric import shape
|
4 |
+
from numpy.lib.type_check import imag
|
5 |
+
from model import Model
|
6 |
+
from jittor.dataset.mnist import MNIST
|
7 |
+
import jittor.transform as trans
|
8 |
+
import numpy as np
|
9 |
+
import cv2
|
10 |
+
import os
|
11 |
+
|
12 |
+
|
13 |
+
def ImageClassification(img_path, model):
|
14 |
+
# 得到一个 HxWx3 的 array(224, 225, 3)
|
15 |
+
image_start = cv2.imread(img_path)
|
16 |
+
# 把图像缩放到 28x28 个像素(28, 28, 3)
|
17 |
+
image = cv2.resize(image_start, (28, 28))
|
18 |
+
# print(image.shape)
|
19 |
+
image = image / 255.0 # 把图像的 RGB 值从 [0, 255] 变为 [0, 1]
|
20 |
+
image = image.transpose(2, 0, 1) # 把输入格式从 HWC 改为 CHW
|
21 |
+
image = jt.float32(image) # 变为 Jittor Var
|
22 |
+
image_end = image.unsqueeze(dim=0) # 加入 batch 维度,变为 [1, C, H, W]
|
23 |
+
outputs = model(image_end)
|
24 |
+
prediction = np.argmax(outputs.data, axis=1)
|
25 |
+
# TODO 展示图片
|
26 |
+
cv2.imshow('MNISt', image_start)
|
27 |
+
cv2.waitKey(0)
|
28 |
+
print('图片识别结果:'+str(prediction[0]))
|
29 |
+
|
30 |
+
|
31 |
+
def main():
|
32 |
+
pwd_path = os.path.abspath(os.path.dirname(__file__))
|
33 |
+
save_model_path = os.path.join(pwd_path, 'model/mnist_model.pkl')
|
34 |
+
# TODO 加载模型
|
35 |
+
model = Model()
|
36 |
+
model.load_parameters(jt.load(save_model_path))
|
37 |
+
# TODO 加载本地图片
|
38 |
+
img_path = '0.jpg'
|
39 |
+
# TODO 对图片进行识别
|
40 |
+
ImageClassification(img_path, model)
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == '__main__':
|
44 |
+
main()
|