In [None]:
from pathlib import Path 
import os
import sys
sys.path.append(str(Path(os.path.abspath('')).parent))

import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Checkpoint location: <parent of the notebook's working directory>/models/<file>.
agent_path = Path(os.path.abspath('')).parent / 'models' / 'genrl_stickman_500k_2.pt'
print("Model path", agent_path)

# The checkpoint is a fully pickled agent object (later cells read `agent.cfg` and
# `agent.wm`), not a plain state_dict. PyTorch >= 2.6 defaults `weights_only=True`,
# which refuses to unpickle arbitrary objects, so opt out explicitly.
# SECURITY: unpickling executes arbitrary code -- only load checkpoints you trust.
agent = torch.load(agent_path, weights_only=False)

In [None]:
from tools.genrl_utils import ViCLIPGlobalInstance, DOMAIN2PREDICATES
# Name of the ViCLIP variant to use; falls back to 'viclip' when the agent config
# has no `viclip_model` attribute.
model_name = getattr(agent.cfg, 'viclip_model', 'viclip')

# Get ViCLIP
# Step 1: (re)create the holder only when it is missing or was built for another model,
# so re-running this cell does not rebuild it needlessly.
if 'viclip_global_instance' not in locals() or model_name != viclip_global_instance._model:
    viclip_global_instance = ViCLIPGlobalInstance(model_name)

# Step 2: instantiate whenever the holder is not ready yet. Checked independently of
# step 1 so that an interrupted `instantiate()` is retried on the next run.
if not viclip_global_instance._instantiated:
    print("Instantiating")
    viclip_global_instance.instantiate()

# Step 3: always (re)bind the handles used by later cells, so `clip` / `tokenizer`
# exist even when the holder was already present in the kernel.
clip = viclip_global_instance.viclip
tokenizer = viclip_global_instance.viclip_tokenizer

In [None]:
# --- Rendering options -------------------------------------------------------
SAVE = True          # render the decoded frames and write them out as a GIF
DENOISE = True       # forwarded as `denoise=` to connector.video_imagine
REVERSE = False      # pick frames by negative index (-t) instead of t when rendering
# NOTE(review): REPEAT_TIME is not referenced by the visible code. The original note
# suggests a value of 1 corresponds to the standard n_frames-long rollout -- confirm.
REPEAT_TIME = 2
TEXT_OVERLAY = True  # draw each prompt on top of its panel

# Domain name = first '_'-separated token of the configured task; it is used as a
# directory name when the GIF is saved. (Previously this held the whole token list.)
domain = agent.cfg.task.split('_')[0]

# One text prompt per generated video / subplot panel.
labels_list = ['high kick', 'stand up straight', 'doing splits']

# Text-to-video: encode each prompt, imagine a latent rollout from it with the world
# model's connector, decode the rollout to images and (optionally) save a GIF.
# One no_grad scope covers the whole cell; nothing here needs gradients.
with torch.no_grad():
    wm = world_model = agent.wm
    connector = agent.wm.connector
    decoder = world_model.heads['decoder']
    n_frames = connector.n_frames

    # Get text(video) embed: one feature per prompt, stacked along a new leading axis
    # and moved to the same device as the ViCLIP model.
    text_feat = [clip.get_txt_feat(text,) for text in labels_list]
    text_feat = torch.stack(text_feat, dim=0).to(clip.device)

    video_embed = text_feat

    B = video_embed.shape[0]  # number of prompts == number of panels
    T = 1

    # Get initial state
    init = connector.initial(B, init_embed=video_embed)

    # Get actions: tile the embedding n_frames times along dim 1 first.
    video_embed = video_embed.repeat(1, n_frames, 1)
    action = wm.connector.get_action(video_embed)

    # Imagine
    prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=DENOISE)
    # Decode; the +0.5 presumably undoes a [-0.5, 0.5] centering of the observations
    # (values are clipped to [0, 1] before display) -- TODO confirm.
    prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5

    # Plotting video: near-square grid with R rows and enough columns for B panels.
    R = int(np.sqrt(B))
    C = min((B + (R-1)) // R, B)

    # squeeze=False always yields a 2-D array of Axes, so no special-casing of the
    # single-panel / single-row layouts is needed before flattening.
    fig, axes = plt.subplots(R, C, figsize=(3.5 * C, 4 * R), squeeze=False)
    fig.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
    # NOTE(review): this overrides the figsize above with a fixed 4x4 inch canvas,
    # regardless of the grid shape -- confirm this is intended for multi-panel output.
    fig.set_size_inches(4,4)

    axes = [a for row in axes for a in row]

    file_path = 'temp_text2video.gif'

    if SAVE:
        # Hide ticks/frames once, including grid cells that receive no video.
        for ax in axes:
            ax.set_axis_off()

        ims = []
        # Frame 0 is skipped, as in the original loop.
        for t in range(1, prior_recon.shape[1]):
            frame_idx = -t if REVERSE else t
            artists = []
            for b in range(prior_recon.shape[0]):
                ax = axes[b]
                # permute(1, 2, 0) moves the leading axis last for imshow
                # (presumably CHW -> HWC); clip to the displayable [0, 1] range.
                img = np.clip(prior_recon[b, frame_idx].cpu().permute(1, 2, 0).numpy(), 0, 1)
                artists.append(ax.imshow(img))
                if TEXT_OVERLAY:
                    # The label must be part of this frame's artist list: ArtistAnimation
                    # only shows/hides artists it is given, so an unregistered text would
                    # stay visible in every frame and pile up on the axes.
                    artists.append(ax.text(0, 5, labels_list[b], color='white'))
            ims.append(artists)

        # Save GIFs
        anim = animation.ArtistAnimation(fig, ims, interval=700, blit=True, repeat_delay=700)
        writer = animation.PillowWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800)
        domain = agent.cfg.task.split('_')[0]
        out_dir = f'videos/{domain}/text2video'
        os.makedirs(out_dir, exist_ok=True)
        # File name: the prompts joined with '_' and with spaces replaced by '_'.
        file_path = f'{out_dir}/{"_".join(labels_list).replace(" ","_")}.gif'
        print("GIF path: ", Path(os.path.abspath('')) / file_path)
        anim.save(file_path, writer=writer)