import re
import glob
import os
import pickle

import numpy as np
import torch
import gradio as gr

from utils.audio import load_spectrograms
from utils.compute_args import compute_args
from utils.tokenize import sent_to_ix, pad_feature
from model_LA import Model_LA

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Locate checkpoints; reverse-sorted so the best one comes first
ckpts_path = "ckpt"
model_name = "Model_LA_e"
ckpts = sorted(glob.glob(os.path.join(ckpts_path, model_name, "best*")), reverse=True)

# Load the checkpoint once; it holds both the training args and the weights
ckpt = torch.load(ckpts[0], map_location=device)
args = compute_args(ckpt["args"])

# Pretrained GloVe embeddings and the vocabulary used at training time
pretrained_emb = np.load("train_glove.npy")
with open("token_to_ix.pkl", "rb") as f:
    token_to_ix = pickle.load(f)

# Build the model and restore the trained weights
net = Model_LA(args, len(token_to_ix), pretrained_emb).to(device)
net.load_state_dict(ckpt["state_dict"])
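
# Optional sanity check (illustrative addition, not in the original app): confirm
# the restore succeeded before serving requests.
# print(f"Loaded {sum(p.numel() for p in net.parameters()):,} parameters from {ckpts[0]}")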


def inference(source_video, transcription):
    """Predict which emotions are present in a clip from its video and transcription."""
    # --- Text preprocessing: lowercase, strip punctuation, split hyphens/slashes ---
    def clean(w):
        return (
            re.sub(r"([.,'!?\"()*#:;])", "", w.lower())
            .replace("-", " ")
            .replace("/", " ")
        )

    s = [clean(w) for w in transcription.split() if clean(w) != ""]
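
    # For illustration: clean("Didn't!") -> "didnt"; hyphens and slashes become
    # spaces, so "sugar-coat" turns into "sugar coat" inside a single token.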
    # --- Audio: mel and magnitude spectrograms extracted from the clip ---
    _, mel, mag = load_spectrograms(source_video)

    # Pad / truncate each modality to the sequence lengths used in training
    l_max_len = args.lang_seq_len
    a_max_len = args.audio_seq_len
    v_max_len = args.video_seq_len
    L = sent_to_ix(s, token_to_ix, max_token=l_max_len)
    A = pad_feature(mel, a_max_len)
    # Note: the mel spectrogram also stands in for the video features here;
    # this demo ships no separate visual feature extractor.
    V = pad_feature(mel, v_max_len)

    # Log shapes before and after padding
    print(f"Processed text shape from {len(s)} to {L.shape}")
    print(f"Processed audio shape from {mel.shape} to {A.shape}")
    print(f"Processed video shape from {mel.shape} to {V.shape}")
    # --- Inference ---
    net.eval()
    # Add a batch dimension to each modality
    x = np.expand_dims(L, axis=0)
    y = np.expand_dims(A, axis=0)
    z = np.expand_dims(V, axis=0)
    x, y, z = (
        torch.from_numpy(x).to(device),
        torch.from_numpy(y).float().to(device),
        torch.from_numpy(z).float().to(device),
    )
    with torch.no_grad():
        pred = net(x, y, z).cpu().numpy()[0]

    # Multi-label output: each logit is thresholded at 0 independently
    # (rather than softmax-normalized), since several emotions can co-occur.
    label_to_ix = ["happy", "sad", "angry", "fear", "disgust", "surprise"]
    result_dict = {label_to_ix[i]: float(pred[i]) > 0 for i in range(len(label_to_ix))}
    return result_dict
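
# Example call (paths refer to the example assets bundled below; output values
# are illustrative):
# inference("examples/03bSnISJMiM_1.mp4", "IT WAS REALLY GOOD")
# -> {"happy": True, "sad": False, ...}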

title = "Emotion Recognition"
description = (
    "Predicts which emotions (happy, sad, angry, fear, disgust, surprise) "
    "are present in a video clip, given the clip and its transcription."
)
examples = [
    [
        "examples/0h-zjBukYpk_2.mp4",
        "NOW IM NOT EVEN GONNA SUGAR COAT THIS THIS MOVIE FRUSTRATED ME TO SUCH AN EXTREME EXTENT THAT I WAS LOUDLY EXCLAIMING WHY AT THE END OF THE FILM",
    ],
    ["examples/0h-zjBukYpk_19.mp4", "NOW OTHER PERFORMANCES ARE BORDERLINE OKAY"],
    ["examples/03bSnISJMiM_1.mp4", "IT WAS REALLY GOOD"],
    ["examples/03bSnISJMiM_5.mp4", "AND THEY SHOULDVE I GUESS"],
]
gr.Interface(
    inference,
    inputs=[gr.inputs.Video(type="avi", source="upload"), "text"],
    outputs=["label"],
    title=title,
    description=description,
    examples=examples,
).launch(debug=True)
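
# Note: debug=True keeps the process attached and streams logs to the console.
# Outside Hugging Face Spaces, launch(share=True) is a standard Gradio option
# for creating a temporary public link (not used here).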