diff --git "a/inference.ipynb" "b/inference.ipynb" new file mode 100644--- /dev/null +++ "b/inference.ipynb" @@ -0,0 +1 @@ +{"cells":[{"cell_type":"markdown","metadata":{"id":"Qg8z-btej7oA"},"source":["# Emotion Recognition from Video"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"zC2xyVhEkAXu"},"outputs":[],"source":["# !pip install torch==1.2.0 torchvision==0.4.0 numpy==1.18.1 #if necessary\n","!wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz\n","!pip install -q en_vectors_web_lg-2.1.0.tar.gz"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":17589,"status":"ok","timestamp":1646140738835,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"},"user_tz":-60},"id":"F6WrnZIRlyor","outputId":"4d922169-b27e-4ba8-c84d-02a39da67785"},"outputs":[{"name":"stdout","output_type":"stream","text":["--2022-03-01 13:18:41-- https://docs.google.com/uc?export=download&confirm=t&id=1GyXRWhtf0_sJQacy5wT8vHoynwHkMo79\n","Resolving docs.google.com (docs.google.com)... 142.250.73.206, 2607:f8b0:4004:829::200e\n","Connecting to docs.google.com (docs.google.com)|142.250.73.206|:443... connected.\n","HTTP request sent, awaiting response... 303 See Other\n","Location: https://doc-0k-a8-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4cpnuak4dlt7oefub8j53usv9o7c1o6o/1646140650000/04146720491471605701/*/1GyXRWhtf0_sJQacy5wT8vHoynwHkMo79?e=download [following]\n","Warning: wildcards not supported in HTTP.\n","--2022-03-01 13:18:41-- https://doc-0k-a8-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4cpnuak4dlt7oefub8j53usv9o7c1o6o/1646140650000/04146720491471605701/*/1GyXRWhtf0_sJQacy5wT8vHoynwHkMo79?e=download\n","Resolving doc-0k-a8-docs.googleusercontent.com (doc-0k-a8-docs.googleusercontent.com)... 172.217.15.65, 2607:f8b0:4004:810::2001\n","Connecting to doc-0k-a8-docs.googleusercontent.com (doc-0k-a8-docs.googleusercontent.com)|172.217.15.65|:443... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 727119633 (693M) [application/zip]\n","Saving to: ‘emotion_LA.zip’\n","\n","emotion_LA.zip 100%[===================>] 693.43M 130MB/s in 6.3s \n","\n","2022-03-01 13:18:48 (110 MB/s) - ‘emotion_LA.zip’ saved [727119633/727119633]\n","\n","Archive: emotion_LA.zip\n"," creating: MOSEI_UMONS/ckpt/Model_LA_e/\n"," inflating: MOSEI_UMONS/ckpt/Model_LA_e/best81.21325494388027_1117766.pkl \n"," inflating: MOSEI_UMONS/ckpt/Model_LA_e/best81.0974523427757_2623576.pkl \n","drive emotion_LA.zip MOSEI_UMONS sample_data\n"]}],"source":["# checkpoints\n","!for file in \"1GyXRWhtf0_sJQacy5wT8vHoynwHkMo79\"; do wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='${file} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=\"${file} -O emotion_LA.zip && rm -rf /tmp/cookies.txt; done\n","!mkdir MOSEI_UMONS/ckpt/\n","!unzip emotion_LA.zip -d MOSEI_UMONS/ckpt\n","!ls\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":263,"status":"ok","timestamp":1646141678150,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"},"user_tz":-60},"id":"KhLIkfFSj9Gw","outputId":"a73ad20c-37c4-4894-a12b-cb18bd1eba46"},"outputs":[{"name":"stdout","output_type":"stream","text":["/content/drive/MyDrive/projects/mosei_umons\n"]}],"source":["%cd /content/drive/MyDrive/projects/mosei_umons\n","\n","%reload_ext autoreload\n","%autoreload 2"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"_C5QS8DMj7oG"},"outputs":[],"source":["import re\n","import glob\n","import pickle\n","import os\n","import torch\n","import numpy as np\n","from utils.audio import load_spectrograms\n","from utils.compute_args import compute_args\n","from utils.tokenize import tokenize, create_dict, sent_to_ix, cmumosei_2, cmumosei_7, pad_feature\n","from model_LA import Model_LA\n","\n","import torch\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","working_dir = \".\""]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2259,"status":"ok","timestamp":1646142961242,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"},"user_tz":-60},"id":"M9vfPOnvj7oH","outputId":"839be299-f6a0-40ac-e727-96373c27d562"},"outputs":[{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":37}],"source":["# load model\n","\n","ckpts_path = os.path.join(working_dir, 'ckpt')\n","model_name = \"Model_LA_e\"\n","# Listing sorted checkpoints\n","ckpts = sorted(glob.glob(os.path.join(ckpts_path, model_name,'best*')), reverse=True)\n","\n","# Load original args\n","args = torch.load(ckpts[0], map_location=torch.device(device))['args']\n","args = compute_args(args)\n","pretrained_emb = np.load(\"train_glove.npy\")\n","token_to_ix = pickle.load(open(\"token_to_ix.pkl\", \"rb\")) \n","state_dict = torch.load(ckpts[0], map_location=torch.device(device))['state_dict']\n","\n","net = Model_LA(args, len(token_to_ix), pretrained_emb).to(device)\n","net.load_state_dict(state_dict)"]},{"cell_type":"code","execution_count":41,"metadata":{"executionInfo":{"elapsed":186,"status":"ok","timestamp":1646143061606,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"},"user_tz":-60},"id":"1yFB5Yxbj7oF"},"outputs":[],"source":["video_path = os.path.join(working_dir, 'data/video/03bSnISJMiM_1.mp4')\n","transcript_path = os.path.join(working_dir, 'data/transcripts/03bSnISJMiM_1.txt')\n","transcript = None"]},{"cell_type":"markdown","source":["### Record video"],"metadata":{"id":"svzHt7xPudTi"}},{"cell_type":"code","source":["from IPython.display import display, Javascript,HTML\n","from google.colab.output import eval_js\n","from base64 import b64decode\n","\n","def record_video(filename):\n"," js=Javascript(\"\"\"\n"," async function recordVideo() {\n"," const options = { mimeType: \"video/webm; codecs=vp9\" };\n"," const div = document.createElement('div');\n"," const capture = document.createElement('button');\n"," const stopCapture = document.createElement(\"button\");\n"," \n"," capture.textContent = \"Start Recording\";\n"," capture.style.background = \"orange\";\n"," capture.style.color = \"white\";\n","\n"," stopCapture.textContent = \"Stop Recording\";\n"," stopCapture.style.background = \"red\";\n"," stopCapture.style.color = \"white\";\n"," div.appendChild(capture);\n","\n"," const video = document.createElement('video');\n"," const recordingVid = document.createElement(\"video\");\n"," video.style.display = 'block';\n","\n"," const stream = await navigator.mediaDevices.getUserMedia({audio:true, video: true});\n"," \n"," let recorder = new MediaRecorder(stream, options);\n"," document.body.appendChild(div);\n"," div.appendChild(video);\n","\n"," video.srcObject = stream;\n"," video.muted = true;\n","\n"," await video.play();\n","\n"," google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);\n","\n"," await new Promise((resolve) => {\n"," capture.onclick = resolve;\n"," });\n"," recorder.start();\n"," capture.replaceWith(stopCapture);\n","\n"," await new Promise((resolve) => stopCapture.onclick = resolve);\n"," recorder.stop();\n"," let recData = await new Promise((resolve) => recorder.ondataavailable = resolve);\n"," let arrBuff = await recData.data.arrayBuffer();\n"," \n"," // stop the stream and remove the video element\n"," stream.getVideoTracks()[0].stop();\n"," div.remove();\n","\n"," let binaryString = \"\";\n"," let bytes = new Uint8Array(arrBuff);\n"," bytes.forEach((byte) => {\n"," binaryString += String.fromCharCode(byte);\n"," })\n"," return btoa(binaryString);\n"," }\n"," \"\"\")\n"," try:\n"," display(js)\n"," data=eval_js('recordVideo({})')\n"," binary=b64decode(data)\n"," with open(filename,\"wb\") as video_file:\n"," video_file.write(binary)\n"," print(f\"Finished recording video at:{filename}\")\n"," except Exception as err:\n"," print(str(err))"],"metadata":{"id":"eICqoEV_uib-","executionInfo":{"status":"ok","timestamp":1646142546020,"user_tz":-60,"elapsed":230,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"}}},"execution_count":18,"outputs":[]},{"cell_type":"code","source":["record_video(\"test.mp4\")\n","transcript = \"oh my god that is amazing!\" # what you said in video"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"YRy9RvRiukfM","executionInfo":{"status":"ok","timestamp":1646142572058,"user_tz":-60,"elapsed":16082,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"}},"outputId":"deea9512-0503-4366-b62e-3d58b16279db"},"execution_count":19,"outputs":[{"output_type":"display_data","data":{"application/javascript":["\n"," async function recordVideo() {\n"," const options = { mimeType: \"video/webm; codecs=vp9\" };\n"," const div = document.createElement('div');\n"," const capture = document.createElement('button');\n"," const stopCapture = document.createElement(\"button\");\n"," \n"," capture.textContent = \"Start Recording\";\n"," capture.style.background = \"orange\";\n"," capture.style.color = \"white\";\n","\n"," stopCapture.textContent = \"Stop Recording\";\n"," stopCapture.style.background = \"red\";\n"," stopCapture.style.color = \"white\";\n"," div.appendChild(capture);\n","\n"," const video = document.createElement('video');\n"," const recordingVid = document.createElement(\"video\");\n"," video.style.display = 'block';\n","\n"," const stream = await navigator.mediaDevices.getUserMedia({audio:true, video: true});\n"," \n"," let recorder = new MediaRecorder(stream, options);\n"," document.body.appendChild(div);\n"," div.appendChild(video);\n","\n"," video.srcObject = stream;\n"," video.muted = true;\n","\n"," await video.play();\n","\n"," google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);\n","\n"," await new Promise((resolve) => {\n"," capture.onclick = resolve;\n"," });\n"," recorder.start();\n"," capture.replaceWith(stopCapture);\n","\n"," await new Promise((resolve) => stopCapture.onclick = resolve);\n"," recorder.stop();\n"," let recData = await new Promise((resolve) => recorder.ondataavailable = resolve);\n"," let arrBuff = await recData.data.arrayBuffer();\n"," \n"," // stop the stream and remove the video element\n"," stream.getVideoTracks()[0].stop();\n"," div.remove();\n","\n"," let binaryString = \"\";\n"," let bytes = new Uint8Array(arrBuff);\n"," bytes.forEach((byte) => {\n"," binaryString += String.fromCharCode(byte);\n"," })\n"," return btoa(binaryString);\n"," }\n"," "],"text/plain":[""]},"metadata":{}},{"output_type":"stream","name":"stdout","text":["Finished recording video at:test.mp4\n"]}]},{"cell_type":"markdown","source":["### Preview video"],"metadata":{"id":"NCMUOOnbufmF"}},{"cell_type":"code","execution_count":20,"metadata":{"id":"qg0Rs_BesNI2","colab":{"base_uri":"https://localhost:8080/","height":0},"executionInfo":{"status":"ok","timestamp":1646142583874,"user_tz":-60,"elapsed":649,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"}},"outputId":"caa97f38-da0f-406c-b17d-21245f8cedd5"},"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","\n"],"text/plain":[""]},"metadata":{},"execution_count":20}],"source":["from IPython.display import HTML\n","from base64 import b64encode\n","mp4 = open(video_path,'rb').read()\n","data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n","HTML(\"\"\"\n","\n","\"\"\" % data_url)"]},{"cell_type":"code","execution_count":38,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":519,"status":"ok","timestamp":1646142961758,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"},"user_tz":-60},"id":"_zQJiszoj7oH","outputId":"d74ed60d-1374-401f-fa26-f2b22c2145bd"},"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/librosa/core/audio.py:165: UserWarning: PySoundFile failed. Trying audioread instead.\n"," warnings.warn(\"PySoundFile failed. Trying audioread instead.\")\n"]},{"output_type":"stream","name":"stdout","text":["Processed text shape: (60,)\n","Processed audio shape: (60, 80)\n","Processed video shape: (60, 80)\n"]}],"source":["# Data preprocessing\n","# text\n","def clean(w):\n"," return re.sub(\n"," r\"([.,'!?\\\"()*#:;])\",\n"," '',\n"," w.lower()\n"," ).replace('-', ' ').replace('/', ' ')\n","\n","text = open(transcript_path, 'r').read() if transcript is None else transcript\n","s = [clean(w) for w in text.split() if clean(w) != '']\n","\n","# Sound\n","_, mel, mag = load_spectrograms(video_path)\n","\n","l_max_len = args.lang_seq_len\n","a_max_len = args.audio_seq_len\n","v_max_len = args.video_seq_len\n","L = sent_to_ix(s, token_to_ix, max_token=l_max_len)\n","A = pad_feature(mel, a_max_len)\n","V = pad_feature(mel, v_max_len)\n","# print shapes\n","print(\"Processed text shape: \", L.shape)\n","print(\"Processed audio shape: \", A.shape)\n","print(\"Processed video shape: \", V.shape)"]},{"cell_type":"markdown","source":["# Prediction"],"metadata":{"id":"KN9jUESJxqLQ"}},{"cell_type":"code","execution_count":39,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":339,"status":"ok","timestamp":1646142962094,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"},"user_tz":-60},"id":"M4A1UDkLj7oJ","outputId":"bf48b141-8cae-4248-834c-f14d11701f3b"},"outputs":[{"output_type":"stream","name":"stdout","text":["[[ 1.0882487 -2.0272958 -2.84339 -2.2552867 -3.6238117 -4.0526347]]\n"]}],"source":["net.train(False)\n","x = np.expand_dims(L,axis=0)\n","y = np.expand_dims(A,axis=0)\n","z = np.expand_dims(V,axis=0)\n","x, y, z = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device), torch.from_numpy(z).float().to(device)\n","pred = net(x, y, z).cpu().data.numpy()\n","print(pred)"]},{"cell_type":"code","source":["label_to_ix = ['happy', 'sad', 'angry', 'fear', 'disgust', 'surprise']\n","result_dict = dict(zip(label_to_ix, pred[0]>0))\n","result_dict"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"r5ttE63dsk4-","executionInfo":{"status":"ok","timestamp":1646142962094,"user_tz":-60,"elapsed":30,"user":{"displayName":"Nouamane Tazi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gg753z6h9fmTPmGyKajJFbNQG48KIqPziiTsxl4Tw=s64","userId":"11345629174419407363"}},"outputId":"28c9dab2-2afc-4bb4-d4d1-8cf266911b28"},"execution_count":40,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'angry': False,\n"," 'disgust': False,\n"," 'fear': False,\n"," 'happy': True,\n"," 'sad': False,\n"," 'surprise': False}"]},"metadata":{},"execution_count":40}]}],"metadata":{"colab":{"collapsed_sections":["svzHt7xPudTi"],"name":"inference.ipynb","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"orig_nbformat":4},"nbformat":4,"nbformat_minor":0} \ No newline at end of file