photo2cartoon / p2c /models /face_features.py
hylee's picture
init
eb7d2bb
raw
history blame
1.06 kB
import torch
import torch.nn.functional as F
from .mobilefacenet import MobileFaceNet
class FaceFeatures(object):
def __init__(self, weights_path, device):
self.device = device
self.model = MobileFaceNet(512).to(device)
self.model.load_state_dict(torch.load(weights_path))
self.model.eval()
def infer(self, batch_tensor):
# crop face
h, w = batch_tensor.shape[2:]
top = int(h / 2.1 * (0.8 - 0.33))
bottom = int(h - (h / 2.1 * 0.3))
size = bottom - top
left = int(w / 2 - size / 2)
right = left + size
batch_tensor = batch_tensor[:, :, top: bottom, left: right]
batch_tensor = F.interpolate(batch_tensor, size=[112, 112], mode='bilinear', align_corners=True)
features = self.model(batch_tensor)
return features
def cosine_distance(self, batch_tensor1, batch_tensor2):
feature1 = self.infer(batch_tensor1)
feature2 = self.infer(batch_tensor2)
return 1 - torch.cosine_similarity(feature1, feature2)