# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn as nn
import random
import numpy as np
from src.modules.encoder import EncoderCNN, EncoderLabels
from src.modules.transformer_decoder import DecoderTransformer
from src.modules.multihead_attention import MultiheadAttention
from src.utils.metrics import softIoU, MaskedCrossEntropyCriterion
import pickle
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def label2onehot(labels, pad_value):
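    """Convert a batch of ingredient id sequences into multi-hot vectors.

    Given integer labels of shape (batch, seq_len), returns a tensor of shape
    (batch, pad_value) with a 1 at every ingredient id that appears in the
    sequence. The pad position is dropped and the eos position (index 0) is
    always set to 0.
    """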
    # input labels to one hot vector
    inp_ = torch.unsqueeze(labels, 2)
    one_hot = torch.FloatTensor(labels.size(0), labels.size(1), pad_value + 1).zero_().to(device)
    one_hot.scatter_(2, inp_, 1)
    one_hot, _ = one_hot.max(dim=1)
    # remove pad position
    one_hot = one_hot[:, :-1]
    # eos position is always 0
    one_hot[:, 0] = 0
    return one_hot


def mask_from_eos(ids, eos_value, mult_before=True):
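    """Build a byte mask that zeroes out all positions after the first eos token.

    The first position is always kept to avoid division by zero when a
    prediction starts with eos. With mult_before=True the eos position itself
    is kept; with mult_before=False it is masked out as well.
    """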
    mask = torch.ones(ids.size()).to(device).byte()
    mask_aux = torch.ones(ids.size(0)).to(device).byte()
    # find eos in ingredient prediction
    for idx in range(ids.size(1)):
        # force mask to have 1s in the first position to avoid division by 0 when predictions start with eos
        if idx == 0:
            continue
        if mult_before:
            mask[:, idx] = mask[:, idx] * mask_aux
            mask_aux = mask_aux * (ids[:, idx] != eos_value)
        else:
            mask_aux = mask_aux * (ids[:, idx] != eos_value)
            mask[:, idx] = mask[:, idx] * mask_aux
    return mask


def get_model(args, ingr_vocab_size, instrs_vocab_size):
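    """Instantiate the full inverse cooking model from command-line args.

    Builds the ingredient label encoder, the image CNN encoder, the
    transformer decoders for instructions and ingredients, and the training
    criteria, and wraps them all in an InverseCookingModel.
    """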
    # build ingredients embedding
    encoder_ingrs = EncoderLabels(args.embed_size, ingr_vocab_size,
                                  args.dropout_encoder, scale_grad=False).to(device)
    # build image model
    encoder_image = EncoderCNN(args.embed_size, args.dropout_encoder, args.image_model)

    decoder = DecoderTransformer(args.embed_size, instrs_vocab_size,
                                 dropout=args.dropout_decoder_r, seq_length=args.maxseqlen,
                                 num_instrs=args.maxnuminstrs,
                                 attention_nheads=args.n_att, num_layers=args.transf_layers,
                                 normalize_before=True,
                                 normalize_inputs=False,
                                 last_ln=False,
                                 scale_embed_grad=False)

    ingr_decoder = DecoderTransformer(args.embed_size, ingr_vocab_size, dropout=args.dropout_decoder_i,
                                      seq_length=args.maxnumlabels,
                                      num_instrs=1, attention_nheads=args.n_att_ingrs,
                                      pos_embeddings=False,
                                      num_layers=args.transf_layers_ingrs,
                                      learned=False,
                                      normalize_before=True,
                                      normalize_inputs=True,
                                      last_ln=True,
                                      scale_embed_grad=False)

    # recipe loss
    criterion = MaskedCrossEntropyCriterion(ignore_index=[instrs_vocab_size-1], reduce=False)
    # ingredients loss
    label_loss = nn.BCELoss(reduce=False)
    eos_loss = nn.BCELoss(reduce=False)

    model = InverseCookingModel(encoder_ingrs, decoder, ingr_decoder, encoder_image,
                                crit=criterion, crit_ingr=label_loss, crit_eos=eos_loss,
                                pad_value=ingr_vocab_size-1,
                                ingrs_only=args.ingrs_only, recipe_only=args.recipe_only,
                                label_smoothing=args.label_smoothing_ingr)
    return model


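# A minimal usage sketch (assuming an argparse.Namespace `args` with the fields
# referenced above and vocabulary sizes taken from the dataset):
#
#     model = get_model(args, ingr_vocab_size, instrs_vocab_size)
#     model.to(device)
#     losses = model(img_inputs, captions, target_ingrs)   # training losses
#     outputs = model.sample(img_inputs, greedy=True)      # inference

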
class InverseCookingModel(nn.Module):
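    """Inverse cooking model.

    Predicts a set of ingredients from an image and generates a cooking
    recipe conditioned on the image features and the (predicted or
    ground-truth) ingredients.
    """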
    def __init__(self, ingredient_encoder, recipe_decoder, ingr_decoder, image_encoder,
                 crit=None, crit_ingr=None, crit_eos=None,
                 pad_value=0, ingrs_only=True,
                 recipe_only=False, label_smoothing=0.0):
        super(InverseCookingModel, self).__init__()
        self.ingredient_encoder = ingredient_encoder
        self.recipe_decoder = recipe_decoder
        self.image_encoder = image_encoder
        self.ingredient_decoder = ingr_decoder
        self.crit = crit
        self.crit_ingr = crit_ingr
        self.pad_value = pad_value
        self.ingrs_only = ingrs_only
        self.recipe_only = recipe_only
        self.crit_eos = crit_eos
        self.label_smoothing = label_smoothing

    def forward(self, img_inputs, captions, target_ingrs,
                sample=False, keep_cnn_gradients=False):
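        """Compute training losses.

        If sample is True, delegates to self.sample for inference. Otherwise
        returns a dict with 'ingr_loss', 'card_penalty', 'eos_loss' and 'iou'
        for ingredient prediction (unless recipe_only is set), and
        'recipe_loss' for instruction generation (unless ingrs_only is set).
        """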
        if sample:
            return self.sample(img_inputs, greedy=True)

        targets = captions[:, 1:]
        targets = targets.contiguous().view(-1)

        img_features = self.image_encoder(img_inputs, keep_cnn_gradients)

        losses = {}
        target_one_hot = label2onehot(target_ingrs, self.pad_value)
        target_one_hot_smooth = label2onehot(target_ingrs, self.pad_value)

        # ingredient prediction
        if not self.recipe_only:
            target_one_hot_smooth[target_one_hot_smooth == 1] = (1-self.label_smoothing)
            target_one_hot_smooth[target_one_hot_smooth == 0] = self.label_smoothing / target_one_hot_smooth.size(-1)

            # decode ingredients with transformer
            # autoregressive mode for ingredient decoder
            ingr_ids, ingr_logits = self.ingredient_decoder.sample(None, None, greedy=True,
                                                                   temperature=1.0, img_features=img_features,
                                                                   first_token_value=0, replacement=False)
            ingr_logits = torch.nn.functional.softmax(ingr_logits, dim=-1)

            # find idxs for eos ingredient
            # eos probability is the one assigned to the first position of the softmax
            eos = ingr_logits[:, :, 0]
            target_eos = ((target_ingrs == 0) ^ (target_ingrs == self.pad_value))
            eos_pos = (target_ingrs == 0)
            eos_head = ((target_ingrs != self.pad_value) & (target_ingrs != 0))

            # select transformer steps to pool from
            mask_perminv = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
            ingr_probs = ingr_logits * mask_perminv.float().unsqueeze(-1)
            ingr_probs, _ = torch.max(ingr_probs, dim=1)

            # ignore predicted ingredients after eos in ground truth
            ingr_ids[mask_perminv == 0] = self.pad_value

            ingr_loss = self.crit_ingr(ingr_probs, target_one_hot_smooth)
            ingr_loss = torch.mean(ingr_loss, dim=-1)
            losses['ingr_loss'] = ingr_loss

            # cardinality penalty
            losses['card_penalty'] = torch.abs((ingr_probs*target_one_hot).sum(1) - target_one_hot.sum(1)) + \
                                     torch.abs((ingr_probs*(1-target_one_hot)).sum(1))

            eos_loss = self.crit_eos(eos, target_eos.float())

            mult = 1/2
            # eos loss is only computed for timesteps <= t_eos and equally penalizes 0s and 1s
            losses['eos_loss'] = mult*(eos_loss * eos_pos.float()).sum(1) / (eos_pos.float().sum(1) + 1e-6) + \
                                 mult*(eos_loss * eos_head.float()).sum(1) / (eos_head.float().sum(1) + 1e-6)

            # iou
            pred_one_hot = label2onehot(ingr_ids, self.pad_value)
            # iou sample during training is computed using the true eos position
            losses['iou'] = softIoU(pred_one_hot, target_one_hot)

        if self.ingrs_only:
            return losses

        # encode ingredients
        target_ingr_feats = self.ingredient_encoder(target_ingrs)
        target_ingr_mask = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
        target_ingr_mask = target_ingr_mask.float().unsqueeze(1)

        outputs, ids = self.recipe_decoder(target_ingr_feats, target_ingr_mask, captions, img_features)

        outputs = outputs[:, :-1, :].contiguous()
        outputs = outputs.view(outputs.size(0) * outputs.size(1), -1)

        loss = self.crit(outputs, targets)
        losses['recipe_loss'] = loss

        return losses

    def sample(self, img_inputs, greedy=True, temperature=1.0, beam=-1, true_ingrs=None):
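        """Predict ingredients and generate a recipe for a batch of images.

        Returns a dict with 'ingr_ids'/'ingr_probs' (unless recipe_only is
        set) and 'recipe_ids'/'recipe_probs' (unless ingrs_only is set). If
        true_ingrs is given, the recipe is conditioned on the ground-truth
        ingredients instead of the predicted ones.
        """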
        outputs = dict()

        img_features = self.image_encoder(img_inputs)

        if not self.recipe_only:
            ingr_ids, ingr_probs = self.ingredient_decoder.sample(None, None, greedy=True, temperature=temperature,
                                                                  beam=-1,
                                                                  img_features=img_features, first_token_value=0,
                                                                  replacement=False)

            # mask ingredients after finding eos
            sample_mask = mask_from_eos(ingr_ids, eos_value=0, mult_before=False)
            ingr_ids[sample_mask == 0] = self.pad_value

            outputs['ingr_ids'] = ingr_ids
            outputs['ingr_probs'] = ingr_probs.data

            mask = sample_mask
            input_mask = mask.float().unsqueeze(1)
            input_feats = self.ingredient_encoder(ingr_ids)

        if self.ingrs_only:
            return outputs

        # option during sampling to use the real ingredients and not the predicted ones to infer the recipe
        if true_ingrs is not None:
            input_mask = mask_from_eos(true_ingrs, eos_value=0, mult_before=False)
            true_ingrs[input_mask == 0] = self.pad_value
            input_feats = self.ingredient_encoder(true_ingrs)
            input_mask = input_mask.unsqueeze(1)

        ids, probs = self.recipe_decoder.sample(input_feats, input_mask, greedy, temperature, beam, img_features, 0,
                                                last_token_value=1)

        outputs['recipe_probs'] = probs.data
        outputs['recipe_ids'] = ids

        return outputs