import torch
import cv2
import sys
import numpy as np
import os
from PIL import Image
# from zdete import Predictor as BboxPredictor
from transformers import Wav2Vec2Model, Wav2Vec2Processor


class MyWav2Vec():
    """Thin wrapper around a pretrained Wav2Vec2 model and its processor."""

    def __init__(self, model_path, device="cuda"):
        super(MyWav2Vec, self).__init__()
        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
        self.wav2Vec = Wav2Vec2Model.from_pretrained(model_path).to(device)
        self.device = device
        print("### Wav2Vec model loaded ###")

    def forward(self, x):
        # Per-timestep hidden states of the Wav2Vec2 encoder.
        return self.wav2Vec(x).last_hidden_state

    def process(self, x):
        # Convert a raw 16 kHz waveform into model input values on self.device.
        return self.processor(x, sampling_rate=16000, return_tensors="pt").input_values.to(self.device)


class AutoFlow():
    """Face-region mask extraction based on the zdete bbox detector.

    Note: requires the `from zdete import Predictor as BboxPredictor` import
    above to be uncommented.
    """

    def __init__(self, auto_flow_dir, imh=512, imw=512):
        super(AutoFlow, self).__init__()
        model_dir = auto_flow_dir + '/third_lib/model_zoo/'
        cfg_file = model_dir + '/zdete_detector/mobilenet_v1_0.25.yaml'
        model_file = model_dir + '/zdete_detector/last_39.pt'
        self.bbox_predictor = BboxPredictor(cfg_file, model_file, imgsz=320, conf_thres=0.6, iou_thres=0.2)
        self.imh = imh
        self.imw = imw
        print("### AutoFlow bbox_predictor loaded ###")

    def frames_to_face_regions(self, frames, toPIL=True):
        # Input frames are BGR numpy arrays.
        face_region_list = []
        for img in frames:
            # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            bbox = self.bbox_predictor.predict(img)[0][0]
            xyxy = bbox[:4]
            score = bbox[4]
            xyxy = np.round(xyxy).astype('int')
            rb, re, cb, ce = xyxy[1], xyxy[3], xyxy[0], xyxy[2]
            face_mask = np.zeros((img.shape[0], img.shape[1])).astype('uint8')
            face_mask[rb:re, cb:ce] = 255
            face_mask = cv2.resize(face_mask, (self.imw, self.imh))
            if toPIL:
                face_mask = Image.fromarray(face_mask)
            face_region_list.append(face_mask)
        return face_region_list
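

# Usage sketch (illustrative, not part of the original module): with the zdete
# import above enabled and an `auto_flow_dir` containing the detector config and
# weights, AutoFlow turns decoded BGR frames into imh x imw face-region masks.
# The paths below are hypothetical.
#
#   frames = video_to_frame("demo.mp4")                      # BGR frames, see helper below
#   auto_flow = AutoFlow("/path/to/auto_flow")               # loads the bbox detector
#   face_masks = auto_flow.frames_to_face_regions(frames)    # list of PIL masks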


def xyxy2x0y0wh(bbox):
    # Convert an (x0, y0, x1, y1) box to (x0, y0, w, h).
    x0, y0, x1, y1 = bbox[:4]
    return [x0, y0, x1 - x0, y1 - y0]


def video_to_frame(video_path: str, interval=1, max_frame=None, imh=None, imw=None, is_return_sum=False, is_rgb=False):
    # Decode a video into a list of frames, keeping every `interval`-th frame,
    # optionally converting to RGB, resizing, and accumulating a running sum.
    vidcap = cv2.VideoCapture(video_path)
    success = True
    key_frames = []
    sum_frames = None
    count = 0
    while success:
        success, image = vidcap.read()
        if image is not None:
            if is_rgb:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if imh is not None and imw is not None:
                image = img_resize(image, imh=imh, imw=imw)
            if count % interval == 0:
                key_frames.append(image)
            if is_return_sum:
                if sum_frames is None:
                    sum_frames = image.copy().astype('float32')
                else:
                    sum_frames = sum_frames + image
            count += 1
        if max_frame is not None:
            if count >= max_frame:
                break
    vidcap.release()
    if is_return_sum:
        return key_frames, sum_frames
    else:
        return key_frames


def img_resize(input_img, imh=None, imw=None, max_val=512):
    # Resize to (imw, imh) if both are given; otherwise keep the aspect ratio and
    # cap the longer side at max_val. Both sides are rounded to multiples of 8.
    if imh is not None and imw is not None:
        width, height = imw, imh
    else:
        height, width = input_img.shape[0], input_img.shape[1]
        if height > width:
            ratio = width / height
            height = max_val
            width = ratio * height
        else:
            ratio = height / width
            width = max_val
            height = ratio * width
    height = int(round(height / 8) * 8)
    width = int(round(width / 8) * 8)
    input_img = cv2.resize(input_img, (width, height))
    return input_img


def assign_audio_to_frame(audio_input, frame_num):
    # Split the audio feature sequence (T x C) into frame_num contiguous chunks
    # and average each chunk, yielding one feature vector per video frame.
    audio_len = audio_input.shape[0]
    audio_per_frame = audio_len / frame_num
    audio_to_frame_list = []
    for f_i in range(frame_num):
        start_idx = int(round(f_i * audio_per_frame))
        end_idx = int(round((f_i + 1) * audio_per_frame))
        if start_idx >= audio_len:
            start_idx = int(round(start_idx - audio_per_frame))
        # print(f"frame_i:{f_i}, start_index:{start_idx}, end_index:{end_idx}, audio_length:{audio_input.shape}")
        seg_audio = audio_input[start_idx:end_idx, :]
        if isinstance(seg_audio, np.ndarray):
            seg_audio = seg_audio.mean(axis=0, keepdims=True)  # B * 20 * 768
        elif torch.is_tensor(seg_audio):
            seg_audio = seg_audio.mean(dim=0, keepdim=True)
        audio_to_frame_list.append(seg_audio)
    if isinstance(seg_audio, np.ndarray):
        audio_to_frames = np.concatenate(audio_to_frame_list, 0)
    else:
        audio_to_frames = torch.cat(audio_to_frame_list, 0)
    return audio_to_frames


def assign_audio_to_frame_new(audio_input, frame_num, pad_frame):
    # For each video frame, take a window of 2 * pad_frame + 1 audio feature
    # steps centred on the frame's position, clamped to the sequence bounds.
    audio_len = audio_input.shape[0]
    audio_to_frame_list = []
    for f_i in range(frame_num):
        mid_index = int(f_i / frame_num * audio_len)
        start_idx = mid_index - pad_frame
        end_idx = mid_index + pad_frame + 1
        if start_idx < 0:
            start_idx = 0
            end_idx = start_idx + pad_frame * 2 + 1
        if end_idx >= audio_len:
            end_idx = audio_len - 1
            start_idx = end_idx - (pad_frame * 2 + 1)
        seg_audio = audio_input[None, start_idx:end_idx, :]
        audio_to_frame_list.append(seg_audio)
    if isinstance(seg_audio, np.ndarray):
        audio_to_frames = np.concatenate(audio_to_frame_list, 0)
    else:
        audio_to_frames = torch.cat(audio_to_frame_list, 0)
    return audio_to_frames


class DotDict(dict):
    """Dictionary whose keys are also accessible as attributes (recursively)."""

    def __init__(self, *args, **kwargs):
        super(DotDict, self).__init__(*args, **kwargs)

    def __getattr__(self, key):
        value = self[key]
        if isinstance(value, dict):
            value = DotDict(value)
        return value
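

# A minimal end-to-end sketch (illustrative, not part of the original module):
# decode a video, encode an audio clip with MyWav2Vec, and align the audio
# features to the frames. The model id ("facebook/wav2vec2-base-960h"), the file
# names ("demo.mp4", "demo.wav"), and the use of librosa for audio loading are
# assumptions made for this example only; a CUDA device is assumed as well.
if __name__ == "__main__":
    import librosa

    # Decode frames as RGB, resized to 512x512.
    frames = video_to_frame("demo.mp4", interval=1, imh=512, imw=512, is_rgb=True)

    # Load mono audio at 16 kHz, the rate expected by MyWav2Vec.process.
    audio, _ = librosa.load("demo.wav", sr=16000)

    # Encode the waveform into per-timestep hidden states (T x 768 for the base model).
    wav2vec = MyWav2Vec("facebook/wav2vec2-base-960h", device="cuda")
    with torch.no_grad():
        audio_feat = wav2vec.forward(wav2vec.process(audio))[0]

    # One pooled feature vector per frame ...
    per_frame_feat = assign_audio_to_frame(audio_feat, frame_num=len(frames))
    # ... or a window of 2 * pad_frame + 1 feature steps per frame.
    windowed_feat = assign_audio_to_frame_new(audio_feat, frame_num=len(frames), pad_frame=2)

    print(len(frames), per_frame_feat.shape, windowed_feat.shape)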