import re
import glob
import pickle
import os

import torch
import numpy as np

from utils.audio import load_spectrograms
from utils.compute_args import compute_args
from utils.tokenize import sent_to_ix, pad_feature
from model_LA import Model_LA
import gradio as gr

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the model: pick the most recent "best" checkpoint and rebuild from its stored args.
ckpts_path = 'ckpt'
model_name = "Model_LA_e"
ckpts = sorted(glob.glob(os.path.join(ckpts_path, model_name, 'best*')), reverse=True)

# Load the checkpoint once and reuse it for both the training args and the weights.
ckpt = torch.load(ckpts[0], map_location=device)
args = compute_args(ckpt['args'])
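# compute_args is assumed to derive the dependent settings (e.g. the *_seq_len
# fields used during preprocessing below) from the saved training args.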
pretrained_emb = np.load("train_glove.npy")
token_to_ix = pickle.load(open("token_to_ix.pkl", "rb"))

net = Model_LA(args, len(token_to_ix), pretrained_emb).to(device)
net.load_state_dict(ckpt['state_dict'])
net.eval()  # inference only: disables dropout and batch-norm updates
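
# Quick sanity check: report what was restored. The GloVe matrix is assumed to
# hold one embedding row per entry of token_to_ix.
print(f"Restored {ckpts[0]} | vocab size {len(token_to_ix)} | emb shape {pretrained_emb.shape}")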
def inference(video_path, text):
    # --- Text: lower-case, strip punctuation, and turn hyphens/slashes into spaces.
    def clean(w):
        return re.sub(
            r"([.,'!?\"()*#:;])",
            '',
            w.lower()
        ).replace('-', ' ').replace('/', ' ')

    s = [clean(w) for w in text.split() if clean(w) != '']
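    # e.g. clean("Good,") -> "good" and clean("right?!") -> "right",
    # so "Good, right?!" tokenizes to ['good', 'right'].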
    # --- Audio: mel spectrogram computed from the clip's audio track (magnitude is unused).
    _, mel, _ = load_spectrograms(video_path)

    # Pad or truncate each modality to the sequence lengths the model was trained with.
    l_max_len = args.lang_seq_len
    a_max_len = args.audio_seq_len
    v_max_len = args.video_seq_len
    L = sent_to_ix(s, token_to_ix, max_token=l_max_len)
    A = pad_feature(mel, a_max_len)
    # No visual features are extracted here: the mel spectrogram is reused as a
    # stand-in for the video modality, padded to the video sequence length.
    V = pad_feature(mel, v_max_len)
    # Log how each modality was reshaped.
    print(f"Processed text shape from {len(s)} to {L.shape}")
    print(f"Processed audio shape from {mel.shape} to {A.shape}")
    print(f"Processed video shape from {mel.shape} to {V.shape}")

    # Add a batch dimension and move everything to the model's device.
    x = torch.from_numpy(np.expand_dims(L, axis=0)).to(device)
    y = torch.from_numpy(np.expand_dims(A, axis=0)).float().to(device)
    z = torch.from_numpy(np.expand_dims(V, axis=0)).float().to(device)

    with torch.no_grad():
        pred = net(x, y, z).cpu().numpy()[0]

    # Numerically stable softmax: subtracting the max leaves the result unchanged.
    pred = np.exp(pred - pred.max())
    pred /= pred.sum()
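    # e.g. for logits [2.0, 1.0, 0.1]: exp([0.0, -1.0, -1.9]) normalized gives
    # [0.659, 0.242, 0.099], identical to a plain softmax but overflow-safe.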
    # Map class indices to the six CMU-MOSEI emotion labels.
    labels = ['happy', 'sad', 'angry', 'fear', 'disgust', 'surprise']
    return {labels[i]: float(pred[i]) for i in range(len(labels))}
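
# The returned dict maps each emotion to a probability (values sum to 1), which is
# the dict-of-confidences format Gradio's Label output component accepts.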
title = "Emotion Recognition"
description = ""
examples = [
    ['examples/03bSnISJMiM_1.mp4', "IT WAS REALLY GOOD "],
    ['examples/03bSnISJMiM_5.mp4', "AND THEY SHOULDVE I GUESS "],
]
gr.Interface(inference,
             inputs=[gr.inputs.Video(type="mp4", source="upload"), "text"],
             outputs=["label"],
             title=title,
             description=description,
             examples=examples
             ).launch(debug=True)