isLandLZ's picture
Upload test.py
64e7562
from datetime import date
import jittor as jt
from numpy.core.fromnumeric import shape
from numpy.lib.type_check import imag
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import numpy as np
import cv2
import os
def ImageClassification(img_path, model):
    """Classify one digit image with ``model``, print the result and display the image.

    Args:
        img_path: path of the image file to classify.
        model: a loaded Jittor model. Presumably it takes a float32
            ``[1, 3, 28, 28]`` batch and returns per-class scores of shape
            ``[1, num_classes]`` — TODO confirm against ``Model``.

    Returns:
        The predicted class index as an ``int``.

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode ``img_path``.
    """
    # cv2.imread does not raise on a missing/undecodable file; it returns None,
    # which would otherwise surface as an obscure error inside cv2.resize.
    image_start = cv2.imread(img_path)
    if image_start is None:
        raise FileNotFoundError(f'cannot read image: {img_path!r}')
    # Shrink to the 28x28 network input size -> (28, 28, 3).
    image = cv2.resize(image_start, (28, 28))
    # OpenCV loads BGR; the training pipeline is assumed to feed RGB
    # (PIL-based dataset) — reorder so channel order matches training.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image / 255.0  # scale pixel values from [0, 255] to [0, 1]
    image = image.transpose(2, 0, 1)  # HWC -> CHW
    image = jt.float32(image)  # convert to a Jittor Var
    image_end = image.unsqueeze(dim=0)  # add batch dim -> [1, C, H, W]
    # Pure inference: no gradient graph is needed.
    with jt.no_grad():
        outputs = model(image_end)
    prediction = np.argmax(outputs.data, axis=1)
    # Print before showing the window so the result is visible while the
    # (blocking) image window is open, not only after it is dismissed.
    print('图片识别结果:'+str(prediction[0]))
    cv2.imshow('MNISt', image_start)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return int(prediction[0])
def main():
    """Load the trained MNIST model and classify a single local image.

    The weights are read from ``model/mnist_model.pkl`` next to this script.
    The image path may be passed as the first command-line argument; it
    defaults to ``'0.jpg'``. The default is resolved against the current
    working directory, as before.
    """
    # Local import keeps this edit self-contained.
    import sys

    pwd_path = os.path.abspath(os.path.dirname(__file__))
    save_model_path = os.path.join(pwd_path, 'model', 'mnist_model.pkl')
    # Build the network and restore the trained weights.
    model = Model()
    model.load_parameters(jt.load(save_model_path))
    # Switch to inference mode (matters for layers such as dropout or
    # batch-norm, if Model contains any).
    model.eval()
    # Allow choosing the image on the command line; keep the old default.
    img_path = sys.argv[1] if len(sys.argv) > 1 else '0.jpg'
    ImageClassification(img_path, model)
if __name__ == "__main__":
    # Only run the demo when executed as a script, not when imported.
    main()