import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from pathlib import Path

MODELS_ROOT_PATH = Path(__file__).parent.parent / 'models'
INTERNVIDEO_PATH = Path(__file__).parent.parent / 'third_party' / 'InternVideo'

DOMAIN2PREDICATES = {
    'walker': ['taking a walk', 'standing up vertically on both feet', 'single-leg balancing', 'standing upside down',
               'high kick', 'walking', 'stepping forward', 'running fast', 'standing on one bended knee',
               'lying down on the back with one raised leg', 'sitting on the knees', 'dog yoga pose',
               'lying down horizontally'],
    'stickman': ['taking a walk', 'standing up vertically', 'one leg balancing', 'high kick', 'walking',
                 'running fast', 'praying', 'lying down with one raised leg', 'dog yoga pose',
                 'lying down horizontally', 'punching', 'raised hands'],
    'cheetah': ['jumping', 'crawling', 'running', 'flipping', 'standing up', 'hopping', 'lying down', 'falling',
                'standing on the knees'],
    'quadruped': ['jumping', 'crawling', 'walking', 'standing up', 'hopping', 'lying down', 'falling',
                  'standing on the knees'],
    'finger': ['spin', 'touch', 'rotate', 'horizontal', 'vertical', 'not moving', 'is not touching',
               'staying far away', 'staying still'],
    'pendulum': ['horizontal', 'vertical', 'left', 'right', 'swingup', 'balance'],
    'hopper': ['jumping', 'crawling', 'walking', 'standing up', 'hopping', 'lying down', 'falling',
               'standing on the knees'],
    'reacher': ['horizontal', 'vertical', 'ball on the left', 'ball on the right', 'touch the ball with the elbow',
                'touch the ball with the tip', 'arm reaches the sphere', 'rotating', 'bending', 'keeping straight',
                'not moving', 'is not touching'],
    'jaco': ['horizontal', 'vertical', 'left', 'right', 'spin', 'touch', 'rotate', 'bend', 'straight',
             'is not touching'],
    'kitchen': ['touch', 'pick up', 'lift', 'grasp', 'hold', 'pull', 'open', 'close', 'push', 'sweep', 'slide',
                'switch light on', 'open the microwave', 'move the kettle', 'turn on the burner'],
}
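
# These per-domain predicate lists are used for visualization: report_text2video
# (below) encodes each predicate with the video-text model and decodes one
# imagined rollout per prompt.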

TASK2PROMPT = {
    'quadruped_run': 'spider running fast',
    'quadruped_walk': 'spider walking fast',
    'quadruped_stand': 'spider standing',
    'quadruped_jump': 'spider jumping',

    'quadruped_two_legs': 'on two legs',
    'quadruped_lie_down': 'lying down',

    'cheetah_run': 'running like a quadruped',

    'cheetah_flipping': 'quadruped rotating flips',
    'cheetah_standing': 'standing like a human',
    'cheetah_lying_down': 'lying down',

    'stickman_walk': 'robot walk fast clean',
    'stickman_run': 'robot run fast clean',
    'stickman_stand': 'standing',
    'stickman_urlb_flip': 'doing flips',

    'stickman_flip': 'doing flips',
    'stickman_flipping': 'doing flips',
    'stickman_backflip': 'doing backflips',
    'stickman_one_foot': 'stand on one foot',
    'stickman_high_kick': 'stand up and kick',
    'stickman_lying_down': 'lying down horizontally',
    'stickman_legs_up': 'lying down with feet up',
    'stickman_sit_knees': 'praying',
    'stickman_lunge_pose': 'lunge pose',
    'stickman_headstand': 'headstand',
    'stickman_boxing': 'punch',
    'stickman_hands_up': 'standing with the hands up',

    'walker_walk': 'walk fast clean',
    'walker_run': 'run fast clean',
    'walker_stand': 'standing up straight',
    'walker_urlb_flip': 'doing backflips',

    'walker_flip': 'doing flips',
    'walker_flipping': 'doing backflips',
    'walker_backflip': 'doing backflips',
    'walker_one_foot': 'stand on one foot',
    'walker_high_kick': 'stand up and kick',
    'walker_lying_down': 'lying down horizontally',
    'walker_arabesque': 'arabesque position',
    'walker_legs_up': 'lying down with feet up',
    'walker_sit_knees': 'praying',
    'walker_lunge_pose': 'lunge pose',
    'walker_headstand': 'headstand',

    'kitchen_microwave': 'opening the microwave fully open',
    'kitchen_light': 'activate the light',
    'kitchen_burner': 'the burner becomes red',
    'kitchen_slide': 'slide cabinet above the knobs',

    'kitchen_kettle': 'pushing up the kettle',

    'jaco_reach_top_left': 'robot grasp the red cube',
    'jaco_reach_bottom_left': 'robot grasp the red cube',
    'jaco_reach_top_right': 'robot grasp the red cube',
    'jaco_reach_bottom_right': 'robot grasp the red cube',
}
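
# Each task prompt is encoded once with the video-text model and cached; the
# resulting text embedding conditions the imagined target sequence inside
# video_text_reward (below).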

class ViCLIPGlobalInstance:
    """Lazily instantiated global wrapper around a video-text model (InternVideo2)."""

    def __init__(self, model='internvideo2'):
        self._instantiated = False
        self._model = model

    def instantiate(self, device='cuda'):
        from torchvision.transforms import transforms as vision_transf
        import sys

        self._instantiated = True

        if self._model == 'internvideo2':
            sys.path.insert(0, str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality/demo/'))
            sys.path.insert(0, str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality'))
            from small_config import Config, eval_dict_leaf
            from small_utils import setup_internvideo2
            config = Config.from_file(INTERNVIDEO_PATH / 'InternVideo2/multi_modality/demo/internvideo2_stage2_config.py')
            config = eval_dict_leaf(config)
            config.model.vision_encoder.num_frames = 8
            config.num_frames = 8
            config.num_frames_test = 8
            # >> can be configured in case the bert model doesn't load
            # config.model.text_encoder.pretrained = str(MODELS_ROOT_PATH / 'bert-large-uncased')
            config.model.text_encoder.config = str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality') + "/" + config.model.text_encoder.config
            model_pth = str(MODELS_ROOT_PATH / 'InternVideo2-stage2_1b-224p-f4.pt')
            config.pretrained_path = model_pth
            config['model']['vision_encoder']['pretrained'] = model_pth
            intern_model, tokenizer = setup_internvideo2(config)
            self.viclip_tokenizer = tokenizer
            self.viclip = intern_model
            self.viclip.device = device
            self.viclip.to(self.viclip.device)
            self.viclip.eval()
            self.viclip.n_frames = 8
            self.viclip.preprocess_transf = vision_transf.Compose([
                vision_transf.Resize(size=(224, 224), interpolation=vision_transf.InterpolationMode.BILINEAR),
                vision_transf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
            sys.path.pop(0)
            sys.path.pop(0)
        else:
            raise NotImplementedError(f"Model {self._model} not implemented")

        # Probe the model with a dummy clip to record the embedding dimension
        vid_feat = self.viclip.get_vid_features(torch.zeros(1, self.viclip.n_frames, 3, 224, 224, device=self.viclip.device))
        self.viclip_emb_dim = vid_feat.shape[1]
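
# Example usage (sketch; assumes the InternVideo2 checkpoint and configs above
# are available on disk):
#   vi = ViCLIPGlobalInstance()
#   vi.instantiate(device='cuda')
#   frames = torch.zeros(1, vi.viclip.n_frames, 3, 224, 224, device=vi.viclip.device)
#   vid_feat = vi.viclip.get_vid_features(frames)          # (1, vi.viclip_emb_dim)
#   txt_feat = vi.viclip.get_txt_feat('a robot walking')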


def report_text2video(agent, data):
    report = {}

    domain = agent.cfg.task.split('_')[0]
    labels_list = DOMAIN2PREDICATES[domain]

    wm = world_model = agent.wm
    decoder = world_model.heads['decoder']  # B, T, C, H, W
    connector = agent.wm.connector
    n_frames = connector.n_frames

    if hasattr(world_model, 'viclip_model'):
        clip = world_model.viclip_model
    else:
        # Get ViCLIP, instantiating the global singleton on first use
        viclip_global_instance = globals()['viclip_global_instance']
        if not viclip_global_instance._instantiated:
            viclip_global_instance.instantiate()
        clip = viclip_global_instance.viclip

    # Get the text embeddings (used in place of video embeddings)
    text_feat = []
    for text in labels_list:
        with torch.no_grad():
            text_feat.append(clip.get_txt_feat(text))
    text_feat = torch.stack(text_feat, dim=0)
    # Make sure the device is right
    video_embed = text_feat.to(agent.device)
    B = video_embed.shape[0]

    # Repeat the embedding for each imagined frame
    video_embed = video_embed.repeat(1, n_frames, 1)
    # Imagine a rollout per predicate and decode it to pixels
    prior = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=False, reset_every_n_frames=False, denoise=True)
    prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5
    report['text_to_video'] = prior_recon
    return report
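
# Example (sketch): report_text2video(agent, data) returns
# {'text_to_video': frames} with one decoded imagined rollout per predicate in
# DOMAIN2PREDICATES for the agent's domain.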

def max_cosine_similarity(u, v, dim=-1):
    max_norm = torch.max(torch.norm(u, dim=dim), torch.norm(v, dim=dim)).unsqueeze(-1)
    return torch.sum((u / max_norm) * (v / max_norm), dim=dim)
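
# For example, with u = (1, 0) and v = (0.5, 0): max_norm = 1, so the score is
# 0.5 instead of the plain cosine similarity of 1.0, i.e. unlike
# F.cosine_similarity this variant also penalizes a norm mismatch.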

def neg_mse_fn(a, b, dim=-1, scale=True):
    dist = - torch.norm(a - b, dim=dim)
    if scale:
        dist = dist / np.sqrt(a.shape[-1]).item()
    return dist
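
# The sqrt(d) scaling keeps 'neg_mse' roughly comparable across feature sizes:
# for coordinates that differ by O(1), the norm ||a - b|| itself grows like sqrt(d).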

def compute_reward(agent, agent_seq, target_seq, score_fn='cosine'):
    if score_fn in ['cosine', 'max_cosine', 'neg_mse', 'exp_neg_mse']:
        distance_fn = dict(cosine=F.cosine_similarity, max_cosine=max_cosine_similarity, neg_mse=neg_mse_fn, exp_neg_mse=neg_mse_fn)[score_fn]
        target_stoch = agent.wm.connector.get_stoch(target_seq)
        agent_stoch = agent.wm.rssm.get_stoch(agent_seq)
        # Compare the two latents in the decoder's input space
        conv_target = agent.wm.heads['decoder']._conv_in[0](target_stoch)
        conv_agent = agent.wm.heads['decoder']._conv_in[0](agent_stoch)
        reward = distance_fn(conv_target, conv_agent, dim=-1)
        if score_fn == 'exp_neg_mse':
            reward = torch.exp(reward)
    elif score_fn == 'neg_kl':
        agent_dist = agent.wm.rssm.get_dist(agent_seq)
        target_dist = agent.wm.connector.get_dist(target_seq)
        reward = -torch.distributions.kl_divergence(agent_dist, target_dist)
        # Normalize by the latent's scale: G groups * log(K classes) for
        # discrete latents, or the mean's dimensionality for Gaussian ones
        if 'logit' in target_seq:
            reward = reward / (np.log(target_seq['logit'].shape[-1]) * target_seq['logit'].shape[-2])
        else:
            reward = reward / target_seq['mean'].shape[-1]
    elif score_fn == 'max_like':
        agent_dist = agent.wm.rssm.get_dist(agent_seq)
        target_sample = target_seq['stoch']
        reward = agent_dist.log_prob(target_sample)
    elif score_fn == 'combo':
        return compute_reward(agent, agent_seq, target_seq, 'cosine') + compute_reward(agent, agent_seq, target_seq, 'neg_kl')
    else:
        raise NotImplementedError(f"{score_fn} reward not implemented")
    return reward
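
# A rough sense of the 'neg_kl' normalization (sketch): against a uniform
# categorical target, the KL of a latent with G groups of K classes is at most
# G * log(K), so dividing by log(K) * G keeps the reward roughly in [-1, 0];
# dividing by the mean's dimensionality plays the same role for Gaussian latents.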

def video_text_reward(agent, seq, score_fn='cosine', 
                      sample_for_target=False, weighted_align=False, align_initial=False, align_sequence=False, 
                      task_prompt='', skip_first_target=False, **kwargs):
    wm = world_model = agent.wm
    connector = agent.wm.connector
    n_frames = connector.n_frames

    T, B = seq['deter'].shape[:2]
    imagined_steps = T

    if not hasattr(agent, 'unconditional_target'):
        if hasattr(world_model, 'viclip_model'):
            clip = world_model.viclip_model
        else:
            # Get ViCLIP
            viclip_global_instance = globals()['viclip_global_instance']
            if not viclip_global_instance._instantiated:
                viclip_global_instance.instantiate()
            clip = viclip_global_instance.viclip

        if task_prompt != '':
            task = [task_prompt]
        else:
            task = [TASK2PROMPT[agent.cfg.task]]

        # Get the text embedding (used in place of a video embedding)
        with torch.no_grad():
            text_feat = clip.get_txt_feat(task[0])
        # Make sure the device is right
        video_embed = text_feat.to(agent.device)

        # Unconditionally generate the target sequence from the prompt embedding
        if skip_first_target:
            video_embed = video_embed.reshape(1, 1, -1).repeat(B, imagined_steps + 1, 1)
            unconditional_stats = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=sample_for_target, reset_every_n_frames=False, denoise=True)
            # Drop the first imagined step and move time to the front: (B, T, ...) -> (T, B, ...)
            unconditional_stats = {k: v[:, 1:].permute([1, 0] + list(range(2, len(v.shape)))) for k, v in unconditional_stats.items()}
        else:
            video_embed = video_embed.reshape(1, 1, -1).repeat(B, imagined_steps, 1)
            unconditional_stats = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=sample_for_target, reset_every_n_frames=False, denoise=True)
            # Move time to the front: (B, T, ...) -> (T, B, ...)
            unconditional_stats = {k: v.permute([1, 0] + list(range(2, len(v.shape)))) for k, v in unconditional_stats.items()}
        agent.unconditional_target = unconditional_stats
    else:
        unconditional_stats = agent.unconditional_target
        
    agent_seq = seq
    target_seq = unconditional_stats
    if align_initial:
        assert not align_sequence, 'Cannot align initial and sequence at the same time'
        # Score the agent's sequence against the first target state only
        init_seq = {k: v[0] for k, v in target_seq.items()}
        init_score = compute_reward(agent, agent_seq, init_seq, score_fn=score_fn)
        if weighted_align:
            # Exponentially discount later matches
            w = 0.99 * torch.ones_like(init_score, device=init_score.device)
            w = torch.cumprod(w, dim=1)
            init_score = w * init_score
        best_indexes_one_hot = F.one_hot(torch.argmax(init_score, dim=0), num_classes=target_seq['stoch'].shape[0])
        # From the best-matching step on, advance through the target sequence
        ts_idx = torch.clip(torch.cumsum(torch.cumsum(best_indexes_one_hot, dim=1), dim=1) - 1, min=0).T
        new_target_seq = {}
        for k, v in target_seq.items():
            if len(v.shape) == 4:
                new_ts = ts_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, v.shape[-2], v.shape[-1])
            else:
                new_ts = ts_idx.unsqueeze(-1).repeat(1, 1, v.shape[-1])
            new_target_seq[k] = torch.gather(v, 0, new_ts)  # out[i][j][k] = input[index[i][j][k]][j][k]
        return compute_reward(agent, agent_seq, new_target_seq, score_fn=score_fn).unsqueeze(-1)
    elif align_sequence:
        align_score = []
        # Slice every tensor in the dict along the leading (time) dimension
        get_prev_a_b = lambda d, a, b: {k: v[a:b] for k, v in d.items()}
        shorter_target_seq = get_prev_a_b(unconditional_stats, 0, n_frames)
        for t in range(T - n_frames):
            cur_agent_seq = get_prev_a_b(seq, t, t + n_frames)
            score = compute_reward(agent, cur_agent_seq, shorter_target_seq, score_fn=score_fn).mean(dim=0)  # 0 is the time dimension
            align_score.append(score)
        align_score = torch.stack(align_score, dim=0)
        if weighted_align:
            # Exponentially discount later matches
            w = 0.99 * torch.ones_like(align_score, device=align_score.device)
            w = torch.cumprod(w, dim=1)
            align_score = w * align_score
        best_indexes_one_hot = F.one_hot(torch.argmax(align_score, dim=0), num_classes=target_seq['stoch'].shape[0])
        # From the best-matching window on, advance through the target sequence
        ts_idx = torch.clip(torch.cumsum(torch.cumsum(best_indexes_one_hot, dim=1), dim=1) - 1, min=0).T
        new_target_seq = {}
        for k, v in target_seq.items():
            if len(v.shape) == 4:
                new_ts = ts_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, v.shape[-2], v.shape[-1])
            else:
                new_ts = ts_idx.unsqueeze(-1).repeat(1, 1, v.shape[-1])
            new_target_seq[k] = torch.gather(v, 0, new_ts)  # out[i][j][k] = input[index[i][j][k]][j][k]
        return compute_reward(agent, agent_seq, new_target_seq, score_fn=score_fn).unsqueeze(-1)
    else:
        reward = compute_reward(agent, agent_seq, target_seq, score_fn=score_fn)

    return reward.unsqueeze(-1)
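
# Sketch of a typical call (illustrative; `agent` and `seq` follow the
# Dreamer-style conventions used above, with time-major tensors):
#   reward = video_text_reward(agent, seq, score_fn='neg_kl',
#                              task_prompt='a robot walking',
#                              align_sequence=True)   # -> (T, B, 1)
# The encoded prompt and its imagined target are cached as
# agent.unconditional_target, so the text encoder runs only once per task.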

# Module-level singleton; instantiated lazily on first use
viclip_global_instance = ViCLIPGlobalInstance()