"""Gradio demo that estimates head pose from a face image with SixDRepNet.

Loads a RepVGG-A0 based SixDRepNet checkpoint on the CPU, predicts a rotation
for the uploaded image, and reports yaw, pitch, a coarse viewing direction,
and the model's inference time.
"""

import os
import time

import gradio as gr
import numpy as np
import torch
from torchvision import transforms

import utils
from model import SixDRepNet

# Preprocessing applied to every input image: resize, centre-crop to 224x224,
# convert to a tensor, then normalise each channel (the mean/std values are
# the ones commonly used for ImageNet-pretrained backbones).
transformations = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

model = SixDRepNet(backbone_name='RepVGG-A0',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

# NOTE: the path is relative, so it is resolved against the current working
# directory. torch.load unpickles the file -- only load checkpoints that come
# from a trusted source.
saved_state_dict = torch.load(os.path.join("weights_ALFW_A0.pth"),
                              map_location='cpu')

# A checkpoint may either wrap the weights under 'model_state_dict' or be the
# bare state dict itself; accept both layouts.
if 'model_state_dict' in saved_state_dict:
    model.load_state_dict(saved_state_dict['model_state_dict'])
else:
    model.load_state_dict(saved_state_dict)

# Test the Model
model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).

# Angle threshold (same unit as the reported yaw/pitch) beyond which the face
# is labelled as looking up/down/left/right.
th = 15


def predict(img):
    """Estimate the head pose of the face shown in ``img``.

    Args:
        img: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A multi-line string with the predicted yaw and pitch, a coarse
        direction label (empty when both angles are within ``th``), and the
        duration of the model forward pass in milliseconds.
    """
    # Gradio passes None when no image was provided; answer with a message
    # instead of failing on ``None.convert``.
    if img is None:
        return "No image provided."

    img = img.convert('RGB')
    batch = transformations(img).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        # perf_counter is monotonic and high-resolution, which makes it the
        # appropriate clock for measuring a short elapsed interval.
        start = time.perf_counter()
        R_pred = model(batch)
        end = time.perf_counter()
        elapsed_ms = (end - start) * 1000

        # Scale by 180/pi to report degrees (assumes the helper returns
        # radians -- confirm against utils).
        euler = utils.compute_euler_angles_from_rotation_matrices(
            R_pred, use_gpu=False) * 180 / np.pi

    # Column 0 is used as pitch and column 1 as yaw.
    p_pred_deg = euler[:, 0].cpu().item()
    y_pred_deg = euler[:, 1].cpu().item()

    # Build the label from the vertical part followed by the horizontal part.
    direction_str = ""
    if p_pred_deg > th:
        direction_str = "UP "
    elif p_pred_deg < -th:
        direction_str = "DOWN "
    if y_pred_deg > th:
        direction_str += "LEFT"
    elif y_pred_deg < -th:
        direction_str += "RIGHT"

    return f"Yaw: {y_pred_deg:0.1f} \n Pitch: {p_pred_deg:0.1f}\n Direction: {direction_str} \n Time: {elapsed_ms:0.2f}ms"


# Build the web UI and start serving immediately; share=True also requests a
# public share link from Gradio.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    examples=["face_left.jpg", "face_right.jpg", "face_up.jpg", "face_down.jpg"],
).launch(share=True)