from datetime import date
import jittor as jt
from numpy.core.fromnumeric import shape
from numpy.lib.type_check import imag
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import numpy as np
import cv2
import os


def ImageClassification(img_path, model):
    """Classify one image file with ``model``, show it, and print the result.

    The image is read with OpenCV, resized to 28x28, scaled to [0, 1],
    rearranged to channel-first layout and given a leading batch axis before
    being passed to ``model``.

    Args:
        img_path: Path of the image file to classify.
        model: Callable applied to the prepared ``[1, C, H, W]`` float32
            Jittor Var. Its result must expose ``.data``; the arg-max over
            axis 1 is taken as the predicted class (presumably per-class
            scores -- confirm against the ``Model`` definition).

    Returns:
        The predicted class index as an ``int``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    # cv2.imread does not raise on failure; it returns None. Fail early with
    # a clear message instead of an opaque error inside cv2.resize.
    image_start = cv2.imread(img_path)
    if image_start is None:
        raise FileNotFoundError(
            f"Could not read image file: {img_path!r}"
        )

    # Shrink the H x W x 3 array to 28 x 28 pixels -> (28, 28, 3).
    image = cv2.resize(image_start, (28, 28))
    # Rescale pixel values from [0, 255] to [0, 1]. Note: OpenCV stores the
    # channels in BGR order, not RGB.
    image = image / 255.0
    # Convert the layout from HWC to CHW.
    image = image.transpose(2, 0, 1)
    # Wrap as a float32 Jittor Var.
    image = jt.float32(image)
    # Add the batch dimension -> [1, C, H, W].
    image_end = image.unsqueeze(dim=0)

    outputs = model(image_end)
    prediction = np.argmax(outputs.data, axis=1)

    # Show the original (un-resized) picture and block until a key is pressed.
    cv2.imshow('MNISt', image_start)
    cv2.waitKey(0)
    print('图片识别结果:'+str(prediction[0]))
    return int(prediction[0])


def main():
    """Load the saved MNIST model and classify the local image ``0.jpg``."""
    # The model file is resolved relative to this script's directory.
    pwd_path = os.path.abspath(os.path.dirname(__file__))
    save_model_path = os.path.join(pwd_path, 'model/mnist_model.pkl')

    # Load the model parameters.
    model = Model()
    model.load_parameters(jt.load(save_model_path))

    # Local image to classify; this relative path is resolved against the
    # current working directory, not the script directory.
    img_path = '0.jpg'

    # Run recognition on the image.
    ImageClassification(img_path, model)


if __name__ == '__main__':
    main()