File size: 6,342 Bytes
ebd4e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from . import losses
from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward
from ..utils.clipscore import CLIPScore
import numpy as np

class LossWrapper(torch.nn.Module):
    """Bundle a captioning model with the criteria used to train it.

    ``forward`` selects one of three training regimes from its flags:

    * ``struc_flag`` -- a weighted mix of the language-model criterion and
      ``losses.StructureLosses``;
    * neither flag -- the language-model criterion only;
    * ``sc_flag`` -- ``losses.RewardCriterion`` driven by the reward from
      ``get_self_critical_reward`` and/or
      ``get_self_critical_clipscore_reward``.
    """

    def __init__(self, model, opt):
        """Create the criteria and, if requested, the frozen CLIPScore model.

        Args:
            model: The captioning model. It is called both in its default
                mode and with ``mode='sample'``, and its ``vocab`` attribute
                is read when CLIPScore rewards are used.
            opt: Option namespace. This constructor reads
                ``label_smoothing`` and, optionally, ``use_clipscore``,
                ``use_grammar``, ``joint_out``, ``clipscore_mode`` and
                ``clip_load_path``. It is also forwarded to
                ``losses.StructureLosses``.
        """
        super(LossWrapper, self).__init__()
        self.opt = opt
        self.model = model
        if opt.label_smoothing > 0:
            self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing)
        else:
            self.crit = losses.LanguageModelCriterion()
        self.rl_crit = losses.RewardCriterion()
        self.struc_crit = losses.StructureLosses(opt)

        self.clipscore_model = None
        # Read with a default, like the other optional flags below, so option
        # objects that do not define `use_clipscore` keep working.
        if getattr(opt, 'use_clipscore', False):
            use_grammar = getattr(opt, 'use_grammar', False)
            joint_out = getattr(opt, 'joint_out', False)
            self.clipscore_model = CLIPScore(
                mode=opt.clipscore_mode,
                use_grammar=use_grammar,
                joint_out=joint_out,
                )
            # The scorer only produces rewards; it is never optimized.
            for param in self.clipscore_model.parameters():
                param.requires_grad = False

            if use_grammar:
                # NOTE(security): torch.load unpickles the file, so
                # `clip_load_path` must point to a trusted checkpoint.
                state_dict = torch.load(opt.clip_load_path, map_location='cpu')
                self.clipscore_model.load_state_dict(state_dict['state_dict'])

    def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
                sc_flag, struc_flag, clip_vis_feats=None):
        """Compute the training loss for one batch.

        Args:
            fc_feats, att_feats, att_masks: Visual inputs, passed to the
                model unchanged.
            labels: Target token tensor; sliced along its last axis.
            masks: Mask tensor sliced the same way as ``labels``.
            gts: Indexable collection of references; reordered with
                ``gt_indices`` before it is scored.
            gt_indices: Object with ``tolist()`` that yields indices into
                ``gts``.
            sc_flag: Use the self-critical loss (only when ``struc_flag`` is
                false).
            struc_flag: Use the structure loss; takes precedence over
                ``sc_flag``.
            clip_vis_feats: Optional features forwarded to
                ``get_self_critical_clipscore_reward``.

        Returns:
            dict: Always contains ``'loss'``. The structure path adds
            ``'lm_loss'``, ``'struc_loss'`` and ``'reward'``. The
            self-critical path adds ``'reward'`` plus whichever of
            ``'CIDEr'``, ``'CLIP-S'``, ``'RefCLIP-S'`` and
            ``'grammar_reward'`` apply.

        Raises:
            ValueError: If ``opt.use_multi_rewards`` is set without
                ``opt.use_clipscore``.
        """
        out = {}
        if struc_flag:
            loss = self._structure_loss(
                fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, out)
        elif not sc_flag:
            loss = self._language_model_loss(fc_feats, att_feats, labels, masks, att_masks)
        else:
            loss = self._self_critical_loss(
                fc_feats, att_feats, att_masks, gts, gt_indices, clip_vis_feats, out)
        out['loss'] = loss
        return out

    def _language_model_loss(self, fc_feats, att_feats, labels, masks, att_masks):
        """Return ``self.crit`` applied to the model output for ``labels``."""
        # The model receives every position but the last; the criterion
        # compares against every position but the first.
        prediction = self.model(fc_feats, att_feats, labels[..., :-1], att_masks)
        return self.crit(prediction, labels[..., 1:], masks[..., 1:])

    def _structure_loss(self, fc_feats, att_feats, labels, masks, att_masks,
                        gts, gt_indices, out):
        """Return the weighted LM/structure loss and record its parts in ``out``."""
        opt = self.opt
        weight = opt.structure_loss_weight

        # Skip whichever term has zero weight so its forward pass is not paid.
        if weight < 1:
            lm_loss = self._language_model_loss(fc_feats, att_feats, labels, masks, att_masks)
        else:
            lm_loss = torch.tensor(0).type_as(fc_feats)

        if weight > 0:
            sample_opt = {
                'sample_method': opt.train_sample_method,
                'beam_size': opt.train_beam_size,
                'output_logsoftmax': opt.struc_use_logsoftmax
                                     or opt.structure_loss_type == 'softmax_margin'
                                     or 'margin' not in opt.structure_loss_type,
                'sample_n': opt.train_sample_n,
            }
            gen_result, sample_logprobs = self.model(
                fc_feats, att_feats, att_masks, opt=sample_opt, mode='sample')
            selected_gts = [gts[idx] for idx in gt_indices.tolist()]
            struc_loss = self.struc_crit(sample_logprobs, gen_result, selected_gts)
        else:
            struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
                          'reward': torch.tensor(0).type_as(fc_feats)}

        out['lm_loss'] = lm_loss
        out['struc_loss'] = struc_loss['loss']
        out['reward'] = struc_loss['reward']
        return (1 - weight) * lm_loss + weight * struc_loss['loss']

    def _self_critical_loss(self, fc_feats, att_feats, att_masks, gts, gt_indices,
                            clip_vis_feats, out):
        """Return the reward-criterion loss and record reward metrics in ``out``."""
        opt = self.opt

        # Baseline decode: eval mode, no gradients. The `finally` guarantees
        # the model returns to train mode even if decoding raises, so a
        # caller that recovers from the error does not keep training a model
        # stuck in eval mode.
        self.model.eval()
        try:
            with torch.no_grad():
                greedy_res, _ = self.model(
                    fc_feats, att_feats, att_masks,
                    mode='sample',
                    opt={'sample_method': opt.sc_sample_method,
                         'beam_size': opt.sc_beam_size})
        finally:
            self.model.train()

        # Sampled decode in train mode; its log-probabilities carry gradients.
        gen_result, sample_logprobs = self.model(
            fc_feats, att_feats, att_masks,
            opt={'sample_method': opt.train_sample_method,
                 'beam_size': opt.train_beam_size,
                 'sample_n': opt.train_sample_n},
            mode='sample')
        selected_gts = [gts[idx] for idx in gt_indices.tolist()]

        reward = self._self_critical_reward(
            greedy_res, selected_gts, gen_result, clip_vis_feats, out)
        if isinstance(reward, np.ndarray):
            reward = torch.from_numpy(reward)
        # Match the dtype and device of the log-probabilities.
        reward = reward.to(sample_logprobs)

        loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
        out['reward'] = reward[:, 0].mean()
        return loss

    def _self_critical_reward(self, greedy_res, gts, gen_result, clip_vis_feats, out):
        """Return the reward selected by the options, logging metrics into ``out``."""
        opt = self.opt
        use_clipscore = getattr(opt, 'use_clipscore', False)

        if getattr(opt, 'use_multi_rewards', False):
            # Explicit check rather than `assert`, which `python -O` removes.
            if not use_clipscore:
                raise ValueError('use_multi_rewards requires use_clipscore to be enabled')
            clip_reward, grammar_rewards = self._clipscore_reward(
                greedy_res, gts, gen_result, clip_vis_feats, out)

            if getattr(opt, 'use_grammar', False):
                out['grammar_reward'] = grammar_rewards.mean()
                return clip_reward + grammar_rewards

            assert grammar_rewards is None
            cider_reward = self._cider_reward(greedy_res, gts, gen_result, out)
            if isinstance(cider_reward, np.ndarray):
                # Convert and move next to the CLIPScore reward before adding.
                cider_reward = torch.from_numpy(cider_reward).to(clip_reward.device)
            return clip_reward + cider_reward

        if use_clipscore:
            clip_reward, _ = self._clipscore_reward(
                greedy_res, gts, gen_result, clip_vis_feats, out)
            return clip_reward
        return self._cider_reward(greedy_res, gts, gen_result, out)

    def _clipscore_reward(self, greedy_res, gts, gen_result, clip_vis_feats, out):
        """Return ``(normalized_reward, grammar_rewards)`` and log the CLIPScore mean."""
        normalized, unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward(
            greedy_res, gts, gen_result, self.opt, self.clipscore_model,
            clip_vis_feats, self.model.vocab)
        mode = self.opt.clipscore_mode
        if mode == 'clip_s':
            out['CLIP-S'] = unnormalized_mean
        elif mode == 'refclip_s':
            out['RefCLIP-S'] = unnormalized_mean
        return normalized, grammar_rewards

    def _cider_reward(self, greedy_res, gts, gen_result, out):
        """Return the normalized reward from ``get_self_critical_reward`` and log its mean."""
        normalized, unnormalized_mean = get_self_critical_reward(
            greedy_res, gts, gen_result, self.opt)
        out['CIDEr'] = unnormalized_mean
        return normalized