# Copied from a hosted repository file view; author: bienom,
# commit d2d0403 ("hotfix"), file size 2.05 kB.
import gradio as gr
from model import SixDRepNet
import os
import numpy as np
import torch
from torchvision import transforms
import utils
import time
# Preprocessing pipeline applied to every input image before the model sees it.
transformations = transforms.Compose([
    # An int size resizes the shorter side to 224 px, keeping aspect ratio.
    transforms.Resize(224),
    # Cut a 224x224 square from the center of the resized image.
    transforms.CenterCrop(224),
    # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
    transforms.ToTensor(),
    # Per-channel normalization with the standard ImageNet mean/std values.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Build the head-pose network on a RepVGG-A0 backbone. All weights come from
# the checkpoint loaded below, so no backbone file or pretrained weights are
# requested here.
# NOTE(review): deploy=True presumably selects RepVGG's re-parameterized
# (single-branch) inference structure; it must match how the checkpoint was
# saved — confirm against model.py.
model = SixDRepNet(backbone_name='RepVGG-A0',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

# Load onto the CPU so the app does not require a GPU. The path is resolved
# relative to the current working directory.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
saved_state_dict = torch.load("weights_ALFW_A0.pth", map_location='cpu')
# Some checkpoints wrap the weights under 'model_state_dict'; otherwise the
# loaded object is treated as the state dict itself.
model.load_state_dict(saved_state_dict.get('model_state_dict', saved_state_dict))

# Inference mode: BatchNorm uses its running mean/var instead of batch stats.
model.eval()
# Angle magnitude (degrees) beyond which pitch/yaw is reported as a direction.
th = 15


def _direction_label(pitch_deg, yaw_deg, threshold):
    """Return a coarse head-direction label ("UP LEFT", "DOWN", "" ...)."""
    parts = []
    if pitch_deg > threshold:
        parts.append("UP")
    elif pitch_deg < -threshold:
        parts.append("DOWN")
    if yaw_deg > threshold:
        parts.append("LEFT")
    elif yaw_deg < -threshold:
        parts.append("RIGHT")
    # Joining avoids a dangling space when only the vertical part is present.
    return " ".join(parts)


def predict(img):
    """Estimate head pose for a face image and describe its direction.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without an image.

    Returns:
        A multi-line string with yaw and pitch in degrees, a coarse direction
        label (empty when both angles are within ``th`` degrees of zero), and
        the model forward-pass time in milliseconds. When ``img`` is ``None``,
        a short prompt asking for an image is returned instead.
    """
    if img is None:
        # Gradio passes None for an empty image input; img.convert would crash.
        return "Please upload an image."
    # Force 3 channels (e.g. RGBA/grayscale uploads), preprocess, add batch dim.
    img = transformations(img.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        # perf_counter is monotonic and high-resolution, suited to intervals.
        # Only the forward pass is timed (not preprocessing or conversion).
        start = time.perf_counter()
        R_pred = model(img)
        timemilis = (time.perf_counter() - start) * 1000
        # Radians -> degrees.
        euler = utils.compute_euler_angles_from_rotation_matrices(
            R_pred, use_gpu=False) * 180 / np.pi
    # NOTE(review): column 0 is taken as pitch and column 1 as yaw (per the
    # variable names) — confirm against utils.compute_euler_angles_from_rotation_matrices.
    p_pred_deg = euler[:, 0].cpu().item()
    y_pred_deg = euler[:, 1].cpu().item()
    direction_str = _direction_label(p_pred_deg, y_pred_deg, th)
    return (f"Yaw: {y_pred_deg:0.1f} \n Pitch: {p_pred_deg:0.1f}\n "
            f"Direction: {direction_str} \n Time: {timemilis:0.2f}ms")
# Web UI: upload a face image (as PIL) and get the pose summary as text.
# The example images are resolved relative to the working directory.
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs=gr.Textbox(),
                    examples=["face_left.jpg", "face_right.jpg",
                              "face_up.jpg", "face_down.jpg"])

# Launch only when run as a script, so importing this module (e.g. for reload
# tooling or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)