|
from datetime import date |
|
import jittor as jt |
|
from numpy.core.fromnumeric import shape |
|
from numpy.lib.type_check import imag |
|
from model import Model |
|
from jittor.dataset.mnist import MNIST |
|
import jittor.transform as trans |
|
import numpy as np |
|
import cv2 |
|
import os |
|
|
|
|
|
def ImageClassification(img_path, model, show=True):
    """Classify a single digit image with ``model`` and report the result.

    The image is read with OpenCV (default ``IMREAD_COLOR``, i.e. 3 channels),
    resized to 28x28, scaled to [0, 1], reordered from HWC to CHW, and given a
    leading batch axis before being passed to ``model``. The predicted class is
    the argmax over axis 1 of the model output.

    Args:
        img_path: Path of the image file to classify.
        model: Callable Jittor model. Presumably it accepts a float32 tensor of
            shape (1, 3, 28, 28) — TODO confirm against ``model.Model``.
        show: If True (the default, matching the original behavior), display the
            source image in an OpenCV window and block until a key is pressed.

    Returns:
        The predicted class index as an ``int``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    image_start = cv2.imread(img_path)
    if image_start is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error surfaces later as an opaque cv2.resize assertion.
        raise FileNotFoundError(f'Cannot read image: {img_path}')

    image = cv2.resize(image_start, (28, 28))
    image = image / 255.0             # scale pixel values to [0, 1]
    image = image.transpose(2, 0, 1)  # HWC -> CHW
    image = jt.float32(image)
    image_end = image.unsqueeze(dim=0)  # add the batch axis

    outputs = model(image_end)
    prediction = int(np.argmax(outputs.data, axis=1)[0])

    # Report before blocking on the GUI window so the result is visible immediately.
    print('图片识别结果:' + str(prediction))

    if show:
        cv2.imshow('MNISt', image_start)
        cv2.waitKey(0)
        # Release the window instead of leaving it open until process exit.
        cv2.destroyWindow('MNISt')

    return prediction
|
|
|
|
|
def main(img_path='0.jpg'):
    """Load the trained MNIST model and classify a single image.

    The weights are loaded from ``model/mnist_model.pkl`` located next to this
    script, so loading works regardless of the current working directory.

    Args:
        img_path: Image to classify. Defaults to ``'0.jpg'``. A relative path is
            resolved against the current working directory (as ``cv2.imread`` does).
    """
    pwd_path = os.path.abspath(os.path.dirname(__file__))
    # Pass path components separately so the separator is platform-correct.
    save_model_path = os.path.join(pwd_path, 'model', 'mnist_model.pkl')

    model = Model()
    model.load_parameters(jt.load(save_model_path))
    # Switch to inference mode so layers such as dropout/batch-norm (if the
    # model has any) do not behave as in training.
    model.eval()

    ImageClassification(img_path, model)
|
|
|
|
|
if __name__ == '__main__':
    # Run the single-image demo only when executed as a script, not on import.
    main()
|
|