#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import os
import pickle
from configparser import ConfigParser

import torch

from pipelines.model import AVSR
from pipelines.data.data_module import AVSRDataLoader


class InferencePipeline(torch.nn.Module):
    """End-to-end inference wrapper: landmarks -> data loading -> ``AVSR.infer``.

    The pipeline is configured from an INI file read with ``ConfigParser``.
    The following options are required:

    * ``[input]``: ``modality``, ``v_fps``
    * ``[model]``: ``v_fps``, ``model_path``, ``model_conf``, ``rnnlm``,
      ``rnnlm_conf``
    * ``[decode]``: ``penalty``, ``ctc_weight``, ``lm_weight``, ``beam_size``

    Attributes:
        modality: Value of ``[input] modality`` from the config file.
        dataloader: ``AVSRDataLoader`` used to load the input file.
        model: ``AVSR`` instance used for decoding.
        landmarks_detector: Landmarks detector instance, or ``None`` when
            face tracking is disabled, the modality is not visual, or
            ``detector`` is not one of the supported names.
    """

    def __init__(self, config_filename, detector="retinaface", face_track=False, device="cuda:0"):
        """Build the data loader, the model and (optionally) a face tracker.

        Args:
            config_filename: Path to the INI configuration file.
            detector: Landmarks detector name; ``"mediapipe"`` and
                ``"retinaface"`` are handled here. The value is also
                forwarded to ``AVSRDataLoader``.
            face_track: When true and the modality is ``"video"`` or
                ``"audiovisual"``, instantiate a landmarks detector.
            device: Device passed to ``AVSR`` and to the retinaface
                landmarks detector.

        Raises:
            AssertionError: If ``config_filename`` is not an existing file.
        """
        super().__init__()
        # NOTE: kept as an assert (not an explicit raise) so the exception
        # type seen by existing callers does not change; be aware that it is
        # skipped when Python runs with -O.
        assert os.path.isfile(config_filename), f"config_filename: {config_filename} does not exist."

        config = ConfigParser()
        config.read(config_filename)

        # modality configuration
        modality = config.get("input", "modality")
        self.modality = modality

        # data configuration
        input_v_fps = config.getfloat("input", "v_fps")
        model_v_fps = config.getfloat("model", "v_fps")

        # model configuration
        model_path = config.get("model", "model_path")
        model_conf = config.get("model", "model_conf")

        # language model configuration
        rnnlm = config.get("model", "rnnlm")
        rnnlm_conf = config.get("model", "rnnlm_conf")

        # decoding configuration
        penalty = config.getfloat("decode", "penalty")
        ctc_weight = config.getfloat("decode", "ctc_weight")
        lm_weight = config.getfloat("decode", "lm_weight")
        beam_size = config.getint("decode", "beam_size")

        self.dataloader = AVSRDataLoader(
            modality,
            speed_rate=input_v_fps / model_v_fps,
            detector=detector,
        )
        self.model = AVSR(
            modality,
            model_path,
            model_conf,
            rnnlm,
            rnnlm_conf,
            penalty,
            ctc_weight,
            lm_weight,
            beam_size,
            device,
        )

        # Resolve the detector into a local first so the attribute is always
        # assigned exactly once, even for an unrecognised detector name.
        landmarks_detector = None
        if face_track and self.modality in ["video", "audiovisual"]:
            # Detector backends are imported lazily so that only the selected
            # one has to be importable.
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.detector import LandmarksDetector
                landmarks_detector = LandmarksDetector()
            elif detector == "retinaface":
                from pipelines.detectors.retinaface.detector import LandmarksDetector
                # Honour the caller's device instead of a hard-coded "cuda:0".
                landmarks_detector = LandmarksDetector(device=device)
        self.landmarks_detector = landmarks_detector

    def process_landmarks(self, data_filename, landmarks_filename):
        """Return landmarks for ``data_filename``, or ``None`` if not needed.

        Args:
            data_filename: Path to the input media file; passed to the
                landmarks detector when no landmarks file is supplied.
            landmarks_filename: Path (``str``) to a pickled landmarks file.
                Any non-``str`` value (e.g. ``None``) means "run the
                detector instead".

        Returns:
            ``None`` when the modality is not ``"video"``/``"audiovisual"``;
            otherwise the unpickled object or the detector's output.

        Raises:
            RuntimeError: If landmarks must be detected but no landmarks
                detector was created at construction time.
        """
        if self.modality not in ["video", "audiovisual"]:
            return None

        if isinstance(landmarks_filename, str):
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load landmarks files that come from a trusted source.
            with open(landmarks_filename, "rb") as landmarks_file:
                return pickle.load(landmarks_file)

        if self.landmarks_detector is None:
            raise RuntimeError(
                "No landmarks file was given and no landmarks detector is available; "
                "pass landmarks_filename or build the pipeline with face_track=True "
                "and a supported detector."
            )
        return self.landmarks_detector(data_filename)

    def forward(self, data_filename, landmarks_filename=None):
        """Run inference on one file and return the model's transcript.

        Args:
            data_filename: Path to the input media file.
            landmarks_filename: Optional path to a pickled landmarks file;
                see ``process_landmarks``.

        Returns:
            Whatever ``self.model.infer`` returns for the loaded data.

        Raises:
            AssertionError: If ``data_filename`` is not an existing file.
        """
        assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist."
        landmarks = self.process_landmarks(data_filename, landmarks_filename)
        data = self.dataloader.load_data(data_filename, landmarks)
        transcript = self.model.infer(data)
        return transcript