import os |
from skimage import io,img_as_float32 |
import cv2 |
import torch |
import numpy as np |
import subprocess |
import pandas |
from models.audio2pose import audio2poseLSTM |
from scipy.io import wavfile |
import python_speech_features |
import pyworld |
import config |
import json |
from scipy.interpolate import interp1d |
def inter_pitch(y, y_flag):
    """Fill the unflagged frames of a per-frame track by linear interpolation.

    Frames whose flag is 0 are overwritten in place:

    * a gap between two flagged frames is linearly interpolated between the
      values of those two frames;
    * a leading gap takes the value of the first flagged frame;
    * a trailing gap takes the value of the last flagged frame.

    Flagged frames are never modified. If no frame is flagged at all there is
    nothing to interpolate from and ``y`` is returned unchanged.

    Args:
        y: 1-D NumPy array of per-frame values; modified in place.
        y_flag: per-frame flags covering at least ``y.shape[0]`` frames.
            A value of 0 marks a frame to fill, anything else a frame to keep.

    Returns:
        The same array object ``y`` that was passed in.
    """
    frame_num = y.shape[0]
    keep = np.asarray(y_flag)[:frame_num] != 0

    # Nothing to fill (all flagged / empty input) or nothing to fill from.
    if keep.all() or not keep.any():
        return y

    frames = np.arange(frame_num)
    fill = ~keep
    # np.interp clamps to the first/last sample outside the flagged range,
    # which gives exactly the leading/trailing hold behaviour described above.
    y[fill] = np.interp(frames[fill], frames[keep], y[keep])
    return y
def load_ckpt(checkpoint_path, generator=None, kp_detector=None, ph2kp=None):
    """Restore the weights of the given modules from a checkpoint file.

    The checkpoint is expected to be a mapping with the keys ``'ph2kp'``,
    ``'generator'`` and ``'kp_detector'``; only the entries whose matching
    module argument is not ``None`` are read.

    Args:
        checkpoint_path: path (or file-like object) accepted by ``torch.load``.
        generator: module receiving ``checkpoint['generator']``, or ``None``.
        kp_detector: module receiving ``checkpoint['kp_detector']``, or ``None``.
        ph2kp: module receiving ``checkpoint['ph2kp']``, or ``None``.

    Raises:
        KeyError: if a requested entry is missing from the checkpoint.
    """
    # Load onto the CPU so a checkpoint saved from a GPU can be read on any
    # host; load_state_dict copies the values onto each module's own device.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    targets = (
        ("ph2kp", ph2kp),
        ("generator", generator),
        ("kp_detector", kp_detector),
    )
    for key, module in targets:
        if module is not None:
            module.load_state_dict(checkpoint[key])
def get_img_pose(img_path):
    """Estimate the head pose of an image with the external pose extractor.

    Runs the command stored in ``processor`` (expected to be defined at module
    level; it is not set inside this function) on ``img_path`` with the
    ``-pose`` flag, writing its output into ``samples/tmp_dir``, then reads the
    first row of the CSV named after the image.

    Args:
        img_path: path of the input image.

    Returns:
        ``np.float32`` array of length 6 holding, in order, the columns
        ``pose_Rx, pose_Ry, pose_Rz, pose_Tx, pose_Ty, pose_Tz``.

    Raises:
        FileNotFoundError: if the extractor did not produce the expected CSV.
    """
    tmp_dir = "samples/tmp_dir"
    os.makedirs(tmp_dir, exist_ok=True)
    subprocess.call([processor, "-f", img_path, "-out_dir", tmp_dir, "-pose"])

    # Strip the real extension rather than a fixed four characters, so that
    # names such as "face.jpeg" map to "face.csv" instead of "face..csv".
    stem = os.path.splitext(os.path.basename(img_path))[0]
    csv_file = os.path.join(tmp_dir, stem + ".csv")
    pos_data = pandas.read_csv(csv_file)

    i = 0  # a single image yields a single row
    columns = ("pose_Rx", "pose_Ry", "pose_Rz", "pose_Tx", "pose_Ty", "pose_Tz")
    pose = np.array([pos_data[name][i] for name in columns], dtype=np.float32)
    return pose
def read_img(path):
    """Load an image file as a ``(1, 3, 256, 256)`` float32 tensor.

    The image is read with ``skimage.io.imread``, reduced to its first three
    channels (dropping e.g. an alpha channel), resized to 256x256, converted
    with ``img_as_float32`` and rearranged from HWC to CHW with a leading
    batch dimension of 1.

    Args:
        path: path of the image to read.

    Returns:
        ``torch.Tensor`` of shape ``(1, 3, 256, 256)``.
    """
    img = io.imread(path)
    if img.ndim == 2:
        # Greyscale files come back without a channel axis; replicate the
        # single plane so the channel slice below does not raise.
        img = np.stack([img] * 3, axis=-1)
    img = img[:, :, :3]
    img = cv2.resize(img, (256, 256))
    img = np.array(img_as_float32(img))
    img = img.transpose((2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0)
    return img
def parse_phoneme_file(phoneme_path, use_index=True):
    """Expand a word/phone alignment file into one phone label per frame.

    The file is JSON: a list of words, each with a ``"phones"`` list whose
    items carry a label ``"ph"`` and an end position ``"ed"``. The number of
    frames is ``int(last_ed / 100 * 25)`` and frame ``i`` is assigned the first
    phone (in file order) whose ``"ed"`` is greater than ``i * 4``.
    NOTE(review): this looks like ``ed`` in 10 ms units mapped to 25 fps --
    confirm against the alignment tool that produces the file.

    Words with an empty ``"phones"`` list are skipped, except that the last
    word must contain at least one phone (it defines the frame count).

    Args:
        phoneme_path: path of the alignment JSON file.
        use_index: when true, labels are mapped to integers through
            ``phindex.json``, read from the current working directory. When
            false, the raw labels are returned and that file is not opened.

    Returns:
        ``{"phone_list": [...]}`` with one entry per frame.

    Raises:
        KeyError: if ``use_index`` is true and a label is not in
            ``phindex.json``.
    """
    with open(phoneme_path, 'r') as f:
        result_text = json.load(f)

    frame_num = int(result_text[-1]['phones'][-1]['ed'] / 100 * 25)

    # Flatten words -> phones; word boundaries play no role in the mapping.
    phones = [phone for word in result_text for phone in word["phones"]]
    phone_count = len(phones)

    phone_list = []
    cursor = 0
    for i in range(frame_num):
        # Advance to the first phone that has not ended before this frame.
        while cursor < phone_count and i * 4 >= phones[cursor]["ed"]:
            cursor += 1
        if cursor < phone_count:
            phone_list.append(phones[cursor]["ph"])
        else:
            # Ran past every phone: keep repeating the final label.
            phone_list.append(phones[-1]["ph"])

    if use_index:
        with open("phindex.json") as f:
            ph2index = json.load(f)
        phone_list = [ph2index[p] for p in phone_list]

    saves = {"phone_list": phone_list}
    return saves
def get_audio_feature_from_audio(audio_path):
    """Compute per-frame acoustic features for a WAV file.

    The waveform is reduced to one channel, mean-centred and peak-normalised,
    then four feature groups are computed and concatenated along axis 1, in
    this order:

    1. ``python_speech_features.mfcc``
    2. ``python_speech_features.logfbank``
    3. the ``pyworld.harvest`` pitch track (``frame_period=10``) with its
       zero-valued frames filled in by ``inter_pitch``
    4. a flag column that is 1 where the raw pitch was non-zero, else 0

    All groups are truncated to the shortest one before concatenation.

    Args:
        audio_path: path of the WAV file to read.

    Returns:
        2-D NumPy array with one row per frame.
    """
    sample_rate, audio = wavfile.read(audio_path)

    if len(audio.shape) == 2:
        # Two-channel input: keep channel 1 when channel 0 contains any
        # non-positive sample, otherwise keep channel 0.
        # NOTE(review): this test is true for almost any real signal, so
        # channel 1 is picked nearly always -- confirm this is intended.
        if np.min(audio[:, 0]) <= 0:
            audio = audio[:, 1]
        else:
            audio = audio[:, 0]

    audio = audio - np.mean(audio)
    # pyworld requires a contiguous float64 buffer; float32 WAV data would
    # otherwise stay float32 after the subtraction above.
    audio = np.ascontiguousarray(audio, dtype=np.float64)

    peak = np.max(np.abs(audio))
    if peak > 0:
        # Skip normalisation for an all-zero signal instead of producing NaNs.
        audio = audio / peak

    a = python_speech_features.mfcc(audio, sample_rate)
    b = python_speech_features.logfbank(audio, sample_rate)
    c, _ = pyworld.harvest(audio, sample_rate, frame_period=10)

    c_flag = (c == 0.0) ^ 1  # 1 where a pitch value was detected
    c = inter_pitch(c, c_flag)
    c = np.expand_dims(c, axis=1)
    c_flag = np.expand_dims(c_flag, axis=1)

    frame_num = np.min([a.shape[0], b.shape[0], c.shape[0]])
    cat = np.concatenate(
        [a[:frame_num], b[:frame_num], c[:frame_num], c_flag[:frame_num]],
        axis=1,
    )
    return cat
def get_pose_from_audio(img, audio, audio2pose):
    """Predict a head-pose sequence from audio features with audio2poseLSTM.

    The audio features are cut into consecutive, non-overlapping groups of 4
    rows (any remainder is dropped), batched and passed to the network together
    with ``img``. The network output is mapped linearly from ``[-1, 1]`` to the
    per-component range ``[minv, maxv]``.

    Requires CUDA: both the network and the audio tensor are moved to the GPU.
    ``img`` is passed through untouched.
    NOTE(review): ``img`` presumably must already be a CUDA tensor -- confirm
    with the caller.

    Args:
        img: reference image input forwarded to the network as ``x["img"]``.
        audio: per-frame audio features supporting ``len`` and row slicing.
        audio2pose: path of the checkpoint holding the ``"generator"`` weights.

    Returns:
        NumPy array of poses for the first (only) batch item, with 6 values
        per row.
        NOTE(review): the six components presumably follow the
        ``Rx, Ry, Rz, Tx, Ty, Tz`` order used by ``get_img_pose`` -- confirm.
    """
    num_frame = len(audio) // 4

    # Range used to de-normalise the network output.
    minv = np.array([-0.6, -0.6, -0.6, -128.0, -128.0, 128.0], dtype=np.float32)
    maxv = np.array([0.6, 0.6, 0.6, 128.0, 128.0, 384.0], dtype=np.float32)

    generator = audio2poseLSTM().cuda()
    # Load on the CPU first; load_state_dict copies onto the model's device.
    ckpt_para = torch.load(audio2pose, map_location="cpu")
    generator.load_state_dict(ckpt_para["generator"])
    generator.eval()

    audio_seq = [audio[i * 4:i * 4 + 4] for i in range(num_frame)]
    audio = torch.from_numpy(np.array(audio_seq, dtype=np.float32)).unsqueeze(0).cuda()

    x = {"img": img, "audio": audio}

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        poses = generator(x)

    poses = poses.detach().cpu().numpy()[0]
    poses = (poses + 1) / 2 * (maxv - minv) + minv
    return poses