# SignLanguage / align.py
# Author: ZiyuG — last change: "Update align.py" (commit 6f37bf4)
import numpy as np
from scipy.spatial.distance import cdist
from fastdtw import fastdtw
import json
import cv2
def read_video_frames(video_path):
    """Decode every frame of a video file into memory.

    Args:
        video_path: Path (or URL) handed to ``cv2.VideoCapture``.

    Returns:
        List of the frames returned by ``cap.read()``, in decode order.
        The list is empty if the capture cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture handle even if decoding raises midway.
        cap.release()
    return frames
def extract_keypoints(sequence):
    """Select the keypoints used for alignment from every frame.

    For each frame, take the first detected instance and keep keypoint
    indices 5-12 and 91-132 (8 + 42 = 50 points). The original author
    describes these as the upper-body and hand points.

    Args:
        sequence: Iterable of per-frame dicts shaped like
            ``{'instances': [{'keypoints': [...]}, ...]}``.

    Returns:
        List with one ``np.ndarray`` per frame holding the 50 selected points.
    """
    def _select(frame_record):
        points = frame_record['instances'][0]['keypoints']
        # List concatenation (not element-wise addition): points come from JSON.
        return np.array(points[5:13] + points[91:133])

    return [_select(frame_record) for frame_record in sequence]
# 计算两帧之间的距离(这里使用欧氏距离)
def calculate_distance_matrix(seq1, seq2):
    """Build the pairwise frame-distance matrix between two keypoint sequences.

    Entry ``[i][j]`` is the mean, over keypoints, of the Euclidean distance
    between corresponding keypoints of ``seq1[i]`` and ``seq2[j]``
    (norm taken along axis 1, so each frame is expected to be 2-D, e.g.
    ``(num_keypoints, coords)``).

    Returns:
        ``np.ndarray`` of shape ``(len(seq1), len(seq2))`` when both inputs
        are non-empty.
    """
    rows = [
        [np.linalg.norm(frame_a - frame_b, axis=1).mean() for frame_b in seq2]
        for frame_a in seq1
    ]
    return np.array(rows)
# 计算两个手语序列的最佳对齐路径
def align_sequences(seq1, seq2):
    """Compute the best DTW alignment between two pose sequences.

    Args:
        seq1, seq2: Lists of per-frame pose dicts (as loaded from the JSON
            files); the alignment keypoints are selected by ``extract_keypoints``.

    Returns:
        ``(distance, path)`` as returned by ``fastdtw``. Callers treat ``path``
        as a sequence of ``(index_in_seq1, index_in_seq2)`` pairs.
    """
    keypoints_seq1 = extract_keypoints(seq1)
    keypoints_seq2 = extract_keypoints(seq2)
    # The previous version also built a full len(seq1) x len(seq2) distance
    # matrix here and then discarded it; that O(n*m) work has been removed.
    # NOTE: the per-frame cost is np.linalg.norm of the whole per-frame
    # difference array (the Frobenius norm for 2-D input). It is not the mean
    # per-keypoint distance that calculate_distance_matrix computes.
    distance, path = fastdtw(keypoints_seq1, keypoints_seq2, dist=lambda x, y: np.linalg.norm(x - y))
    return distance, path
def filter_sequence_by_alignment(sequence, alignment_path, index):
    """Re-sample a sequence so it follows one side of a DTW alignment path.

    Args:
        sequence: The input sequence (sequence1 or sequence2).
        alignment_path: Iterable of index pairs produced by the alignment.
        index: Which side of each pair to use (0 for sequence1, 1 for sequence2).

    Returns:
        A new list holding ``sequence[pair[index]]`` for every pair, in path
        order. Elements repeat wherever DTW maps several frames onto one.
    """
    return [sequence[pair[index]] for pair in alignment_path]
def _scale_sequence_keypoints(sequence, scale_x, scale_y):
    """Scale, in place, the alignment keypoints of every frame in ``sequence``.

    Only indices 5-12 and 91-132 of the first instance's keypoints are
    touched; each is rewritten as ``[x * scale_x, y * scale_y]``. All other
    keypoints are left unchanged.
    """
    for frame in sequence:
        keypoints = frame["instances"][0]["keypoints"]
        selected = keypoints[5:13] + keypoints[91:133]
        adjusted = [[point[0] * scale_x, point[1] * scale_y] for point in selected]
        # Write back through slice assignment on the same list object, so the
        # loaded JSON structure is updated in place.
        keypoints[5:13] = adjusted[:8]
        keypoints[91:133] = adjusted[8:]


def scale_keypoints(standard, user, seq1_frames, seq2_frames):
    """Rescale stored keypoints to the common output frame size.

    The common size is the larger width and the larger height of the two
    videos (the same size ``create_aligned_videos`` resizes frames to). Each
    sequence's x/y coordinates are multiplied by that size divided by its own
    video's size. Both ``<standard>.json`` and ``<user>.json`` are then
    rewritten in place, even when no scaling was needed.

    Args:
        standard: Path prefix of the standard clip (``.json`` is appended).
        user: Path prefix of the user clip (``.json`` is appended).
        seq1_frames: Frames of the standard video; only ``[0].shape`` is read.
        seq2_frames: Frames of the user video; only ``[0].shape`` is read.
    """
    height1, width1, _ = seq1_frames[0].shape
    height2, width2, _ = seq2_frames[0].shape
    # Use context managers so the handles are closed (and, for writes, flushed)
    # deterministically instead of whenever the GC collects them.
    with open(standard + ".json", 'r', encoding='utf-8') as f:
        sequence1 = json.load(f)
    with open(user + ".json", 'r', encoding='utf-8') as f:
        sequence2 = json.load(f)

    unified_width = int(max(width1, width2))
    unified_height = int(max(height1, height2))
    # Per-axis float scale factors for the standard and user videos.
    scale_x_standard = unified_width / width1
    scale_y_standard = unified_height / height1
    scale_x_user = unified_width / width2
    scale_y_user = unified_height / height2

    if scale_x_standard != 1.0 or scale_y_standard != 1.0:
        _scale_sequence_keypoints(sequence1, scale_x_standard, scale_y_standard)
    if scale_x_user != 1.0 or scale_y_user != 1.0:
        _scale_sequence_keypoints(sequence2, scale_x_user, scale_y_user)

    with open(standard + ".json", 'w', encoding='utf-8') as f:
        json.dump(sequence1, f, indent=4)
    with open(user + ".json", 'w', encoding='utf-8') as f:
        json.dump(sequence2, f, indent=4)
# 根据对齐路径提取帧并创建新视频
def create_aligned_videos(seq1_frames, seq2_frames, alignment_path, output_combined_path, output_seq1_path, output_seq2_path, fps=30):
    """Write two frame-aligned videos by walking a DTW alignment path.

    For each ``(idx1, idx2)`` in ``alignment_path``, ``seq1_frames[idx1]`` and
    ``seq2_frames[idx2]`` are resized to a common size and written to their
    respective outputs. The common size is the larger width and the larger
    height of the two inputs. ``cv2.resize`` is given an explicit destination
    size, so the aspect ratio is NOT preserved (frames are stretched).

    Args:
        seq1_frames, seq2_frames: Decoded frames; ``[0].shape`` must unpack
            to ``(height, width, channels)``.
        alignment_path: Iterable of ``(index_in_seq1, index_in_seq2)`` pairs.
        output_combined_path: Accepted for interface compatibility; the
            side-by-side combined video is currently not written.
        output_seq1_path, output_seq2_path: Output video paths.
        fps: Frame rate of both output videos.
    """
    height1, width1, _ = seq1_frames[0].shape
    height2, width2, _ = seq2_frames[0].shape
    unified_width = int(max(width1, width2))
    unified_height = int(max(height1, height2))
    frame_size = (unified_width, unified_height)

    # NOTE(review): 'XVID' is an AVI-oriented FourCC, but callers pass .mp4
    # paths. Depending on the OpenCV backend this may fall back to another
    # codec or fail to open the writer. Confirm on the target platform.
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    seq1_out = cv2.VideoWriter(output_seq1_path, fourcc, fps, frame_size)
    seq2_out = cv2.VideoWriter(output_seq2_path, fourcc, fps, frame_size)
    try:
        for idx1, idx2 in alignment_path:
            frame1_resized = cv2.resize(seq1_frames[idx1], frame_size, interpolation=cv2.INTER_AREA)
            frame2_resized = cv2.resize(seq2_frames[idx2], frame_size, interpolation=cv2.INTER_AREA)
            seq1_out.write(frame1_resized)
            seq2_out.write(frame2_resized)
    finally:
        # Always finalize the containers, even if a write/resize raises.
        seq1_out.release()
        seq2_out.release()
def align_filter(standard, user, tmpdir):
    """Align a user clip to a standard clip and rewrite both in place.

    Steps:
      1. Load ``<standard>.json`` / ``<user>.json`` and align them with DTW.
      2. Re-sample both keypoint sequences along the alignment path and
         overwrite the two JSON files.
      3. Read ``<standard>.mp4`` / ``<user>.mp4`` fully into memory. If their
         frame sizes differ, rescale the stored keypoints to the common size.
      4. Write aligned videos back over the original ``.mp4`` paths. This is
         safe because all frames were already read into memory.

    Args:
        standard: Path prefix of the standard clip (no extension).
        user: Path prefix of the user clip (no extension).
        tmpdir: Directory used to build the (currently unused) combined
            video path.
    """
    with open(standard + ".json", 'r', encoding='utf-8') as f:
        sequence1 = json.load(f)
    with open(user + ".json", 'r', encoding='utf-8') as f:
        sequence2 = json.load(f)

    # NOTE(review): alignment runs on raw pixel coordinates, before any
    # resolution rescaling below. If the two videos differ in size, the DTW
    # costs compare differently scaled coordinates. Confirm this is intended.
    distance, alignment_path = align_sequences(sequence1, sequence2)

    filtered_sequence1 = filter_sequence_by_alignment(sequence1, alignment_path, index=0)
    filtered_sequence2 = filter_sequence_by_alignment(sequence2, alignment_path, index=1)
    print(f"DTW 最佳对齐路径: {alignment_path}")
    print(f"DTW 最小对齐距离: {distance}")

    # Persist the aligned keypoints (closed/flushed before scale_keypoints
    # re-reads these same files).
    with open(standard + ".json", 'w', encoding='utf-8') as f:
        json.dump(filtered_sequence1, f, indent=4)
    with open(user + ".json", 'w', encoding='utf-8') as f:
        json.dump(filtered_sequence2, f, indent=4)

    seq1_frames = read_video_frames(standard + '.mp4')
    seq2_frames = read_video_frames(user + '.mp4')

    output_combined_path = tmpdir + '/aligned_combined_output.mp4'
    output_seq1_path = standard + '.mp4'
    output_seq2_path = user + '.mp4'

    height1, width1, _ = seq1_frames[0].shape
    height2, width2, _ = seq2_frames[0].shape
    # Keypoint coordinates only need rescaling when the frame sizes differ.
    if height1 != height2 or width1 != width2:
        scale_keypoints(standard, user, seq1_frames, seq2_frames)

    create_aligned_videos(seq1_frames, seq2_frames, alignment_path, output_combined_path, output_seq1_path, output_seq2_path)
    print(f"Aligned Sequence 1 video created at {output_seq1_path}")
    print(f"Aligned Sequence 2 video created at {output_seq2_path}")