{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent of the working directory on sys.path so that the\n",
    "# `tools` package imported in the next cell can be resolved.\n",
    "sys.path.append(str(Path(os.path.abspath('')).parent))\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "\n",
    "agent_path = Path(os.path.abspath('')).parent / 'models' / 'genrl_stickman_500k_2.pt'\n",
    "print(\"Model path\", agent_path)\n",
    "\n",
    "# NOTE(review): torch.load unpickles arbitrary Python objects, so only\n",
    "# point agent_path at checkpoints that come from a trusted source.\n",
    "agent = torch.load(agent_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tools.genrl_utils import ViCLIPGlobalInstance, DOMAIN2PREDICATES\n",
    "\n",
    "# Model name stored in the agent config; falls back to 'viclip' when absent.\n",
    "model_name = getattr(agent.cfg, 'viclip_model', 'viclip')\n",
    "\n",
    "# Get ViCLIP. The instance is rebuilt only when this cell has not run yet\n",
    "# in the current kernel or when the requested model name has changed, so\n",
    "# re-running the cell does not instantiate the model again.\n",
    "if 'viclip_global_instance' not in locals() or model_name != viclip_global_instance._model:\n",
    "    viclip_global_instance = ViCLIPGlobalInstance(model_name)\n",
    "    if not viclip_global_instance._instantiated:\n",
    "        print(\"Instantiating\")\n",
    "        viclip_global_instance.instantiate()\n",
    "    clip = viclip_global_instance.viclip\n",
    "    tokenizer = viclip_global_instance.viclip_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Configuration ---\n",
    "SAVE = True          # build the animation and write it to disk as a GIF\n",
    "DENOISE = True       # forwarded as `denoise=` to connector.video_imagine\n",
    "REVERSE = False      # index the imagined frames from the end instead of the start\n",
    "REPEAT_TIME = 2      # NOTE(review): not referenced anywhere in this notebook\n",
    "TEXT_OVERLAY = True  # draw each prompt on top of its panel\n",
    "\n",
    "# One panel is generated per prompt.\n",
    "labels_list = ['high kick', 'stand up straight', 'doing splits']\n",
    "\n",
    "# First '_'-separated token of the task name; used to build the output directory.\n",
    "domain = agent.cfg.task.split('_')[0]\n",
    "\n",
    "# --- Text -> imagined video (inference only, no gradients needed) ---\n",
    "with torch.no_grad():\n",
    "    wm = world_model = agent.wm\n",
    "    connector = wm.connector\n",
    "    decoder = world_model.heads['decoder']\n",
    "    n_frames = connector.n_frames\n",
    "\n",
    "    # Get text(video) embed: one text feature per prompt, stacked along dim 0\n",
    "    text_feat = torch.stack([clip.get_txt_feat(text) for text in labels_list], dim=0).to(clip.device)\n",
    "    video_embed = text_feat\n",
    "\n",
    "    B = video_embed.shape[0]\n",
    "\n",
    "    # Get initial state\n",
    "    # NOTE(review): `init` and `action` are not used further down; the calls are\n",
    "    # kept because their side effects (if any) are not visible from this notebook.\n",
    "    init = connector.initial(B, init_embed=video_embed)\n",
    "\n",
    "    # Get actions (the embedding is tiled n_frames times along dim 1 first)\n",
    "    video_embed = video_embed.repeat(1, n_frames, 1)\n",
    "    action = connector.get_action(video_embed)\n",
    "\n",
    "    # Imagine\n",
    "    prior = connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=DENOISE)\n",
    "    # Decode\n",
    "    prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5\n",
    "\n",
    "    # Clamp to [0, 1], move channels last and copy to host memory once,\n",
    "    # instead of once per (panel, frame) inside the plotting loop.\n",
    "    # Indexed below as frames[b, t] -> (H, W, C) image.\n",
    "    frames = prior_recon.clamp(0, 1).permute(0, 1, 3, 4, 2).cpu().numpy()\n",
    "\n",
    "# --- Plotting video ---\n",
    "# Near-square grid with at least one panel per prompt.\n",
    "R = int(np.sqrt(B))\n",
    "C = min((B + (R - 1)) // R, B)\n",
    "\n",
    "# squeeze=False always returns a 2-D array of axes, so no special-casing\n",
    "# of the single-row / single-panel layouts is needed before flattening.\n",
    "fig, axes = plt.subplots(R, C, figsize=(3.5 * C, 4 * R), squeeze=False)\n",
    "fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)\n",
    "fig.set_size_inches(4, 4)  # NOTE(review): overrides the figsize passed to plt.subplots above\n",
    "axes = axes.ravel()\n",
    "for ax in axes:\n",
    "    # Hide ticks and frames on every panel, including unused grid cells.\n",
    "    ax.set_axis_off()\n",
    "\n",
    "file_path = 'temp_text2video.gif'\n",
    "\n",
    "if SAVE:\n",
    "    ims = []\n",
    "    # Frame 0 is skipped; with REVERSE the same steps are read from the end.\n",
    "    for t in range(1, frames.shape[1]):\n",
    "        frame_idx = -t if REVERSE else t\n",
    "        toadd = []\n",
    "        for b in range(frames.shape[0]):\n",
    "            ax = axes[b]\n",
    "            toadd.append(ax.imshow(frames[b, frame_idx]))\n",
    "            if TEXT_OVERLAY:\n",
    "                # The text must be part of the frame's artist list; otherwise\n",
    "                # ArtistAnimation does not manage it and every frame's label\n",
    "                # stays visible on top of the others.\n",
    "                toadd.append(ax.text(0, 5, labels_list[b], color='white'))\n",
    "        ims.append(toadd)\n",
    "\n",
    "    # Save GIFs\n",
    "    anim = animation.ArtistAnimation(fig, ims, interval=700, blit=True, repeat_delay=700)\n",
    "    writer = animation.PillowWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800)\n",
    "    out_dir = f'videos/{domain}/text2video'\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    gif_name = \"_\".join(labels_list).replace(\" \", \"_\")\n",
    "    file_path = f'{out_dir}/{gif_name}.gif'\n",
    "    print(\"GIF path: \", Path(os.path.abspath('')) / file_path)\n",
    "    anim.save(file_path, writer=writer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}