# Author: nouamanetazi (HF staff)
# Last commit: 1dbb2fe "minor fix" (file size 2.87 kB)
import re
import glob
import pickle
import os
import torch
import numpy as np
from utils.audio import load_spectrograms
from utils.compute_args import compute_args
from utils.tokenize import tokenize, create_dict, sent_to_ix, cmumosei_2, cmumosei_7, pad_feature
from model_LA import Model_LA
import gradio as gr
# Run on the first CUDA device when one is available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
ckpts_path = 'ckpt'
model_name = "Model_LA_e"

# Checkpoint files matching 'best*' under ckpt/<model_name>/, sorted in reverse
# lexicographic order; ckpts[0] (the name that sorts last) is the one used.
ckpts = sorted(glob.glob(os.path.join(ckpts_path, model_name, 'best*')), reverse=True)
if not ckpts:
    # Fail with an actionable message instead of an IndexError on ckpts[0].
    raise FileNotFoundError(
        f"No checkpoint matching 'best*' found in {os.path.join(ckpts_path, model_name)}"
    )

# Load the selected checkpoint once: it holds both the training args and the
# model weights (the original code deserialized the same file twice).
checkpoint = torch.load(ckpts[0], map_location=device)

# Original training args, post-processed by compute_args.
args = compute_args(checkpoint['args'])

pretrained_emb = np.load("train_glove.npy")

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only deploy a trusted token_to_ix.pkl. The `with` block closes the handle.
with open("token_to_ix.pkl", "rb") as token_file:
    token_to_ix = pickle.load(token_file)

state_dict = checkpoint['state_dict']
net = Model_LA(args, len(token_to_ix), pretrained_emb).to(device)
net.load_state_dict(state_dict)
def inference(video_path, text):
    """Score an uploaded clip and its transcript against six emotion labels.

    Args:
        video_path: Path to the uploaded media file. It is handed to
            ``load_spectrograms`` to extract the mel spectrogram.
        text: Transcript of the clip (whitespace-separated words).

    Returns:
        dict mapping 'happy', 'sad', 'angry', 'fear', 'disgust' and
        'surprise' to a float score. The scores come from a softmax over
        the model's output vector for this single example.
    """
    # data preprocessing
    # text: lowercase, drop punctuation, split hyphen/slash compounds.
    def clean(w):
        return re.sub(
            r"([.,'!?\"()*#:;])",
            '',
            w.lower()
        ).replace('-', ' ').replace('/', ' ')

    # Clean each word once and drop the ones that end up empty.
    s = [c for c in (clean(w) for w in text.split()) if c != '']

    # Sound (the magnitude spectrogram is not used by this model).
    _, mel, _ = load_spectrograms(video_path)

    l_max_len = args.lang_seq_len
    a_max_len = args.audio_seq_len
    v_max_len = args.video_seq_len

    L = sent_to_ix(s, token_to_ix, max_token=l_max_len)
    A = pad_feature(mel, a_max_len)
    # NOTE(review): the "video" input is built from the audio mel spectrogram,
    # not from visual features. Model_LA appears to be a language+audio model,
    # so this is presumably a placeholder input; confirm that Model_LA ignores
    # its third argument.
    V = pad_feature(mel, v_max_len)

    # print shapes
    print(f"Processed text shape from {len(s)} to {L.shape}")
    print(f"Processed audio shape from {mel.shape} to {A.shape}")
    print(f"Processed video shape from {mel.shape} to {V.shape}")

    # Evaluation mode (equivalent to net.eval()).
    net.train(False)

    # Add a batch dimension of 1 to every input. Both feature tensors are cast
    # to float32; they come from the same mel array, so they get the same cast.
    x = torch.from_numpy(np.expand_dims(L, axis=0)).to(device)
    y = torch.from_numpy(np.expand_dims(A, axis=0)).float().to(device)
    z = torch.from_numpy(np.expand_dims(V, axis=0)).float().to(device)

    # No autograd graph is needed for inference; this saves memory per request.
    with torch.no_grad():
        logits = net(x, y, z).detach().cpu().numpy()[0]

    # Numerically stable softmax: subtracting the max leaves the result
    # mathematically unchanged, but keeps np.exp from overflowing to inf
    # (which would turn the scores into NaN) on large logits.
    exp_logits = np.exp(logits - np.max(logits))
    pred = exp_logits / np.sum(exp_logits)

    ix_to_label = ['happy', 'sad', 'angry', 'fear', 'disgust', 'surprise']
    result_dict = {label: float(pred[i]) for i, label in enumerate(ix_to_label)}
    return result_dict
# Page metadata for the demo UI.
title = "Emotion Recognition"
description = ""

# Bundled example rows, in input order: (clip path, transcript).
examples = [
    ['examples/03bSnISJMiM_1.mp4', "IT WAS REALLY GOOD "],
    ['examples/03bSnISJMiM_5.mp4', "AND THEY SHOULDVE I GUESS "],
]

# Input widgets in the order inference(video_path, text) takes its arguments:
# an uploaded video, then a free-text transcript.
demo_inputs = [gr.inputs.Video(type="mp4", source="upload"), "text"]

# inference returns a {label: score} dict, which the "label" output renders.
demo = gr.Interface(
    inference,
    inputs=demo_inputs,
    outputs=["label"],
    title=title,
    description=description,
    examples=examples,
)
demo.launch(debug=True)