auto_avsr / pipelines /pipeline.py
mpc001's picture
Upload 125 files
09481f3
raw
history blame
No virus
3.02 kB
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import os
import torch
import pickle
from configparser import ConfigParser
from pipelines.model import AVSR
from pipelines.data.data_module import AVSRDataLoader
class InferencePipeline(torch.nn.Module):
    """Speech-recognition inference over a single media file.

    The pipeline reads an INI configuration and builds three things:
    an ``AVSRDataLoader`` for preprocessing, an ``AVSR`` model for decoding,
    and, optionally, a landmarks detector. The detector is used for
    video-based modalities when no pre-computed landmarks file is supplied.
    """

    def __init__(self, config_filename, detector="retinaface", face_track=False, device="cuda:0"):
        """Build the pipeline from an INI config file.

        Args:
            config_filename: Path to the INI file. Keys read:
                ``[input] modality, v_fps``;
                ``[model] v_fps, model_path, model_conf, rnnlm, rnnlm_conf``;
                ``[decode] penalty, ctc_weight, lm_weight, beam_size``.
            detector: Landmarks backend name; ``"mediapipe"`` and
                ``"retinaface"`` are recognized here. Also forwarded to
                ``AVSRDataLoader``.
            face_track: When true and the modality is ``"video"`` or
                ``"audiovisual"``, build a landmarks detector so landmarks can
                be computed on the fly in ``process_landmarks``.
            device: Device string forwarded to ``AVSR`` and to the retinaface
                landmarks detector.

        Raises:
            AssertionError: If ``config_filename`` is not an existing file.
            configparser.NoSectionError / configparser.NoOptionError: If a
                required section or key is missing from the config.
        """
        super().__init__()
        # NOTE: assert-based validation is skipped under `python -O`; kept for
        # compatibility with callers that expect AssertionError.
        assert os.path.isfile(config_filename), f"config_filename: {config_filename} does not exist."
        config = ConfigParser()
        config.read(config_filename)

        # modality configuration
        modality = config.get("input", "modality")
        self.modality = modality

        # data configuration: the input/model frame-rate ratio is handed to the
        # data loader as its speed_rate
        input_v_fps = config.getfloat("input", "v_fps")
        model_v_fps = config.getfloat("model", "v_fps")

        # model configuration
        model_path = config.get("model", "model_path")
        model_conf = config.get("model", "model_conf")

        # language model / decoding configuration
        rnnlm = config.get("model", "rnnlm")
        rnnlm_conf = config.get("model", "rnnlm_conf")
        penalty = config.getfloat("decode", "penalty")
        ctc_weight = config.getfloat("decode", "ctc_weight")
        lm_weight = config.getfloat("decode", "lm_weight")
        beam_size = config.getint("decode", "beam_size")

        self.dataloader = AVSRDataLoader(modality, speed_rate=input_v_fps / model_v_fps, detector=detector)
        self.model = AVSR(modality, model_path, model_conf, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size, device)

        # Always define the attribute, so an unrecognized detector name yields a
        # clear error in process_landmarks instead of an AttributeError.
        self.landmarks_detector = None
        if face_track and self.modality in ["video", "audiovisual"]:
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.detector import LandmarksDetector
                self.landmarks_detector = LandmarksDetector()
            elif detector == "retinaface":
                from pipelines.detectors.retinaface.detector import LandmarksDetector
                # Honor the caller's device rather than hard-coding "cuda:0".
                self.landmarks_detector = LandmarksDetector(device=device)

    def process_landmarks(self, data_filename, landmarks_filename):
        """Return facial landmarks for ``data_filename``.

        Args:
            data_filename: Path to the media file. It is passed to the landmarks
                detector when no landmarks file is provided.
            landmarks_filename: Path (``str`` or ``os.PathLike``) to a pickled
                landmarks file, or ``None`` to run the landmarks detector.

        Returns:
            ``None`` unless the modality is ``"video"`` or ``"audiovisual"``.
            Otherwise, the unpickled landmarks object or the detector's output.

        Raises:
            RuntimeError: If landmarks are required, no landmarks file was
                given, and no landmarks detector was built (``face_track`` was
                false or ``detector`` was not recognized).
        """
        if self.modality not in ["video", "audiovisual"]:
            return None
        if isinstance(landmarks_filename, (str, os.PathLike)):
            # SECURITY: pickle.load can execute arbitrary code; only load
            # landmark files from trusted sources.
            with open(landmarks_filename, "rb") as landmarks_file:
                return pickle.load(landmarks_file)
        if self.landmarks_detector is None:
            raise RuntimeError(
                "No landmarks file was provided and no landmarks detector is available; "
                "pass landmarks_filename or construct the pipeline with face_track=True "
                "and detector='mediapipe' or 'retinaface'."
            )
        return self.landmarks_detector(data_filename)

    def forward(self, data_filename, landmarks_filename=None):
        """Run preprocessing and decoding on ``data_filename``.

        Args:
            data_filename: Path to the input media file.
            landmarks_filename: Optional path to a pickled landmarks file; see
                ``process_landmarks``.

        Returns:
            Whatever ``self.model.infer`` returns; presumably the decoded
            transcript text (TODO confirm against ``AVSR.infer``).

        Raises:
            AssertionError: If ``data_filename`` is not an existing file.
        """
        assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist."
        landmarks = self.process_landmarks(data_filename, landmarks_filename)
        data = self.dataloader.load_data(data_filename, landmarks)
        transcript = self.model.infer(data)
        return transcript