# encoding: utf-8
import os

import cv2
import editdistance
import numpy as np
import torch
from torch.utils.data import Dataset

# Local transforms module; provides HorizontalFlip and ColorNormalize used below.
from cvtransforms import *


class MyDataset(Dataset):
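    """Frame-sequence dataset pairing mouth-crop videos with text transcripts.

    Each sample is a directory of JPEG frames plus a word-level ``.align``
    transcript; the SIL/SP markers and directory layout suggest a GRID-style
    corpus. Videos and transcripts are zero-padded to fixed lengths so they
    can be batched for CTC training.
    """
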
    # Vocabulary: space plus uppercase A-Z. Label ids are shifted by `start`
    # (1 in `_load_anno`), leaving index 0 free, e.g. for the CTC blank.
    letters = [" "] + [chr(c) for c in range(ord("A"), ord("Z") + 1)]

    def __init__(self, video_path, anno_path, file_list, vid_pad, txt_pad, phase):
        self.anno_path = anno_path
        self.vid_pad = vid_pad
        self.txt_pad = txt_pad
        self.phase = phase
        with open(file_list, "r") as f:
            self.videos = [os.path.join(video_path, line.strip()) for line in f]
        self.data = []
        for vid in self.videos:
            # Expected layout: .../<speaker>/<...>/<...>/<video_name>, so the
            # speaker id is the fourth path component from the end. Splitting
            # on os.path.sep (after normpath) keeps this portable.
            items = os.path.normpath(vid).split(os.path.sep)
            self.data.append((vid, items[-4], items[-1]))

    def __getitem__(self, idx):
        (vid, spk, name) = self.data[idx]
        vid = self._load_vid(vid)
        anno = self._load_anno(
            os.path.join(self.anno_path, spk, "align", name + ".align")
        )
        if self.phase == "train":
            vid = HorizontalFlip(vid)
        vid = ColorNormalize(vid)
        vid_len = vid.shape[0]
        anno_len = anno.shape[0]
        vid = self._padding(vid, self.vid_pad)
        anno = self._padding(anno, self.txt_pad)
        return {
            # (T, H, W, C) -> (C, T, H, W) for 3D convolutions.
            "vid": torch.FloatTensor(vid.transpose(3, 0, 1, 2)),
            "txt": torch.LongTensor(anno),
            "txt_len": anno_len,
            "vid_len": vid_len,
        }

def __len__(self):
return len(self.data)

    def _load_vid(self, p):
        # Frames are JPEGs named by integer index; sort numerically so that
        # "10.jpg" comes after "9.jpg".
        files = [f for f in os.listdir(p) if f.endswith(".jpg")]
        files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
        array = [cv2.imread(os.path.join(p, file)) for file in files]
        # Drop frames that failed to decode, then resize everything to 128x64.
        array = [im for im in array if im is not None]
        array = [
            cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) for im in array
        ]
        return np.stack(array, axis=0).astype(np.float32)  # (T, 64, 128, 3)

    def _load_anno(self, name):
        with open(name, "r") as f:
            lines = [line.strip().split(" ") for line in f]
        # Each align line is "<start> <end> <word>"; keep the word and drop
        # the silence (SIL) and short-pause (SP) markers.
        txt = [line[2] for line in lines]
        txt = [t for t in txt if t.upper() not in ("SIL", "SP")]
        return MyDataset.txt2arr(" ".join(txt).upper(), 1)

    def _padding(self, array, length):
        # Zero-pad along the first axis up to `length` entries, keeping the
        # shape of the individual elements (frames or label ids).
        array = list(array)
        size = array[0].shape
        for _ in range(length - len(array)):
            array.append(np.zeros(size))
        return np.stack(array, axis=0)

    @staticmethod
    def txt2arr(txt, start):
        # Map each character to its vocabulary index, shifted by `start`.
        return np.array([MyDataset.letters.index(c) + start for c in txt])

    @staticmethod
    def arr2txt(arr, start):
        # Inverse of txt2arr: ids below `start` (e.g. the CTC blank) are skipped.
        txt = [MyDataset.letters[n - start] for n in arr if n >= start]
        return "".join(txt).strip()
    @staticmethod
    def ctc_arr2txt(arr, start):
        # Greedy CTC collapse: drop repeated labels and ids below `start`,
        # and never emit two consecutive spaces.
        pre = -1
        txt = []
        for n in arr:
            if pre != n and n >= start:
                ch = MyDataset.letters[n - start]
                if not (txt and txt[-1] == " " and ch == " "):
                    txt.append(ch)
            pre = n
        return "".join(txt).strip()
    @staticmethod
    def wer(predict, truth):
        # Per-pair WER: word-level edit distance divided by reference word count.
        word_pairs = [(p[0].split(" "), p[1].split(" ")) for p in zip(predict, truth)]
        return [1.0 * editdistance.eval(p[0], p[1]) / len(p[1]) for p in word_pairs]

    @staticmethod
    def cer(predict, truth):
        # Per-pair CER: character-level edit distance divided by reference length.
        return [
            1.0 * editdistance.eval(p[0], p[1]) / len(p[1]) for p in zip(predict, truth)
        ]
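

# Minimal usage sketch (illustrative only): the paths, list file, and padding
# lengths below are hypothetical placeholders; the batch shapes assume the
# GRID-style layout described in __init__.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MyDataset(
        video_path="data/lip",  # hypothetical frame root
        anno_path="data/align",  # hypothetical alignment root
        file_list="data/train_list.txt",  # one relative video path per line
        vid_pad=75,
        txt_pad=200,
        phase="train",
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    batch = next(iter(loader))
    print(batch["vid"].shape)  # torch.Size([2, 3, 75, 64, 128]): (B, C, T, H, W)
    print(batch["txt"].shape)  # torch.Size([2, 200])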