LIPREAD / inference.py
wissemkarous's picture
init
b072050 verified
raw
history blame
8.55 kB
import argparse
import os
from model import LipCoordNet
from dataset import MyDataset
import torch
import cv2
import face_alignment
import numpy as np
import dlib
import glob
def get_position(size, padding=0.25):
x = [
0.000213256,
0.0752622,
0.18113,
0.29077,
0.393397,
0.586856,
0.689483,
0.799124,
0.904991,
0.98004,
0.490127,
0.490127,
0.490127,
0.490127,
0.36688,
0.426036,
0.490127,
0.554217,
0.613373,
0.121737,
0.187122,
0.265825,
0.334606,
0.260918,
0.182743,
0.645647,
0.714428,
0.793132,
0.858516,
0.79751,
0.719335,
0.254149,
0.340985,
0.428858,
0.490127,
0.551395,
0.639268,
0.726104,
0.642159,
0.556721,
0.490127,
0.423532,
0.338094,
0.290379,
0.428096,
0.490127,
0.552157,
0.689874,
0.553364,
0.490127,
0.42689,
]
y = [
0.106454,
0.038915,
0.0187482,
0.0344891,
0.0773906,
0.0773906,
0.0344891,
0.0187482,
0.038915,
0.106454,
0.203352,
0.307009,
0.409805,
0.515625,
0.587326,
0.609345,
0.628106,
0.609345,
0.587326,
0.216423,
0.178758,
0.179852,
0.231733,
0.245099,
0.244077,
0.231733,
0.179852,
0.178758,
0.216423,
0.244077,
0.245099,
0.780233,
0.745405,
0.727388,
0.742578,
0.727388,
0.745405,
0.780233,
0.864805,
0.902192,
0.909281,
0.902192,
0.864805,
0.784792,
0.778746,
0.785343,
0.778746,
0.784792,
0.824182,
0.831803,
0.824182,
]
x, y = np.array(x), np.array(y)
x = (x + padding) / (2 * padding + 1)
y = (y + padding) / (2 * padding + 1)
x = x * size
y = y * size
return np.array(list(zip(x, y)))
def transformation_from_points(points1, points2):
points1 = points1.astype(np.float64)
points2 = points2.astype(np.float64)
c1 = np.mean(points1, axis=0)
c2 = np.mean(points2, axis=0)
points1 -= c1
points2 -= c2
s1 = np.std(points1)
s2 = np.std(points2)
points1 /= s1
points2 /= s2
U, S, Vt = np.linalg.svd(points1.T * points2)
R = (U * Vt).T
return np.vstack(
[
np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)),
np.matrix([0.0, 0.0, 1.0]),
]
)
def load_video(file, device: str):
# create the samples directory if it doesn't exist
if not os.path.exists("samples"):
os.makedirs("samples")
p = os.path.join("samples")
output = os.path.join("samples", "%04d.jpg")
cmd = "ffmpeg -hide_banner -loglevel error -i {} -qscale:v 2 -r 25 {}".format(
file, output
)
os.system(cmd)
files = os.listdir(p)
files = sorted(files, key=lambda x: int(os.path.splitext(x)[0]))
array = [cv2.imread(os.path.join(p, file)) for file in files]
array = list(filter(lambda im: not im is None, array))
fa = face_alignment.FaceAlignment(
face_alignment.LandmarksType._2D, flip_input=False, device=device
)
points = [fa.get_landmarks(I) for I in array]
front256 = get_position(256)
video = []
for point, scene in zip(points, array):
if point is not None:
shape = np.array(point[0])
shape = shape[17:]
M = transformation_from_points(np.matrix(shape), np.matrix(front256))
img = cv2.warpAffine(scene, M[:2], (256, 256))
(x, y) = front256[-20:].mean(0).astype(np.int32)
w = 160 // 2
img = img[y - w // 2 : y + w // 2, x - w : x + w, ...]
img = cv2.resize(img, (128, 64))
video.append(img)
video = np.stack(video, axis=0).astype(np.float32)
video = torch.FloatTensor(video.transpose(3, 0, 1, 2)) / 255.0
return video
def extract_lip_coordinates(detector, predictor, img_path):
image = cv2.imread(img_path)
image = cv2.resize(image, (600, 500))
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
rects = detector(gray)
retries = 3
while retries > 0:
try:
assert len(rects) == 1
break
except AssertionError as e:
retries -= 1
for rect in rects:
# apply the shape predictor to the face ROI
shape = predictor(gray, rect)
x = []
y = []
for n in range(48, 68):
x.append(shape.part(n).x)
y.append(shape.part(n).y)
return [x, y]
def generate_lip_coordinates(frame_images_directory, detector, predictor):
frames = glob.glob(frame_images_directory + "/*.jpg")
frames.sort()
img = cv2.imread(frames[0])
height, width, layers = img.shape
coords = []
for frame in frames:
x_coords, y_coords = extract_lip_coordinates(detector, predictor, frame)
normalized_coords = []
for x, y in zip(x_coords, y_coords):
normalized_x = x / width
normalized_y = y / height
normalized_coords.append((normalized_x, normalized_y))
coords.append(normalized_coords)
coords_array = np.array(coords, dtype=np.float32)
coords_array = torch.from_numpy(coords_array)
return coords_array
def ctc_decode(y):
y = y.argmax(-1)
t = y.size(0)
result = []
for i in range(t + 1):
result.append(MyDataset.ctc_arr2txt(y[:i], start=1))
return result
def output_video(p, txt, output_path):
files = os.listdir(p)
files = sorted(files, key=lambda x: int(os.path.splitext(x)[0]))
font = cv2.FONT_HERSHEY_SIMPLEX
for file, line in zip(files, txt):
img = cv2.imread(os.path.join(p, file))
h, w, _ = img.shape
img = cv2.putText(
img, line, (w // 8, 11 * h // 12), font, 1.2, (0, 0, 0), 3, cv2.LINE_AA
)
img = cv2.putText(
img,
line,
(w // 8, 11 * h // 12),
font,
1.2,
(255, 255, 255),
0,
cv2.LINE_AA,
)
h = h // 2
w = w // 2
img = cv2.resize(img, (w, h))
cv2.imwrite(os.path.join(p, file), img)
# create the output_videos directory if it doesn't exist
if not os.path.exists(output_path):
os.makedirs(output_path)
output = os.path.join(output_path, "output.mp4")
cmd = "ffmpeg -hide_banner -loglevel error -y -i {}/%04d.jpg -r 25 {}".format(
p, output
)
os.system(cmd)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--weights",
type=str,
default="pretrain/LipCoordNet_coords_loss_0.025581153109669685_wer_0.01746208431890914_cer_0.006488426950253695.pt",
help="path to the weights file",
)
parser.add_argument(
"--input_video",
type=str,
help="path to the input video frames",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="device to run the model on",
)
parser.add_argument(
"--output_path",
type=str,
default="output_videos",
help="directory to save the output video",
)
args = parser.parse_args()
# validate if device is valid
if args.device not in ("cuda", "cpu"):
raise ValueError("Invalid device, must be either cuda or cpu")
device = args.device
# load model
model = LipCoordNet()
model.load_state_dict(torch.load(args.weights))
model = model.to(device)
model.eval()
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(
"lip_coordinate_extraction/shape_predictor_68_face_landmarks_GTX.dat"
)
# load video
video = load_video(args.input_video, device)
# generate lip coordinates
coords = generate_lip_coordinates("samples", detector, predictor)
pred = model(video[None, ...].to(device), coords[None, ...].to(device))
output = ctc_decode(pred[0])
print(output[-1])
output_video("samples", output, args.output_path)
if __name__ == "__main__":
main()