Spaces:
Configuration error
Configuration error
File size: 4,350 Bytes
2a3a041 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import torch
import torch.nn as nn
import numpy as np
import os
from args import get_parser
import pickle
from model import get_model
from torchvision import transforms
from utils.output_ing import prepare_output
from PIL import Image
from tqdm import tqdm
import time
import glob
# Set ``model_dir`` to the directory that holds the vocabularies and the
# model checkpoint; ``image_folder`` is scanned for input images and
# ``output_file`` receives the pickled predictions.
model_dir = '../data'
image_folder = '../data/demo_imgs'
output_file = "../data/predicted_ingr.pkl"

# The GPU is used only when CUDA is available AND ``use_gpu`` is True;
# in every other case everything runs on the CPU.
use_gpu = False
cuda_enabled = torch.cuda.is_available() and use_gpu
if cuda_enabled:
    device = torch.device('cuda')
    map_loc = None  # no explicit remapping when the checkpoint is loaded
else:
    device = torch.device('cpu')
    map_loc = 'cpu'  # checkpoint tensors are mapped onto the CPU
# The commented-out code below was used once to export the vocabularies as
# plain pickles, so that they can be loaded without the Vocabulary class.
# It is kept for reference only (``data_dir`` is not defined in this script).
#ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_ingrs.pkl'), 'rb'))
#ingrs_vocab = [min(w, key=len) if not isinstance(w, str) else w for w in ingrs_vocab.idx2word.values()]
#vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_toks.pkl'), 'rb')).idx2word
#pickle.dump(ingrs_vocab, open('../demo/ingr_vocab.pkl', 'wb'))
#pickle.dump(vocab, open('../demo/instr_vocab.pkl', 'wb'))


def _load_pickle(path):
    """Return the object pickled in *path*, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code, so *path* must only ever
    point at a trusted file.
    """
    with open(path, 'rb') as fp:
        return pickle.load(fp)


# Ingredient and instruction vocabularies, both read from ``model_dir``.
ingrs_vocab = _load_pickle(os.path.join(model_dir, 'ingr_vocab.pkl'))
vocab = _load_pickle(os.path.join(model_dir, 'instr_vocab.pkl'))

# Vocabulary sizes are needed to build the model.
ingr_vocab_size = len(ingrs_vocab)
instrs_vocab_size = len(vocab)
output_dim = instrs_vocab_size
print (instrs_vocab_size, ingr_vocab_size)
# Start of the timed section: building the model and loading its weights.
t = time.time()

# Parse the project's arguments, then override the two settings below
# before the model is built.
args = get_parser()
args.maxseqlen = 15
args.ingrs_only = True
model = get_model(args, ingr_vocab_size, instrs_vocab_size)

# Load the trained model parameters from the checkpoint in ``model_dir``.
model_path = os.path.join(model_dir, 'modelbest.ckpt')
state_dict = torch.load(model_path, map_location=map_loc)
model.load_state_dict(state_dict)

# Move the model to the selected device and switch it to evaluation mode.
model.to(device)
model.eval()

# Only ingredients are requested from the model, not full recipes.
model.ingrs_only = True
model.recipe_only = False

print ('loaded model')
print ("Elapsed time:", time.time() - t)
# Per-image tensor pipeline: convert the PIL image to a tensor, then
# normalise each channel with the mean / std triples below (these are the
# widely used ImageNet channel statistics).
transf_list_batch = [
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
]
to_input_transf = transforms.Compose(transf_list_batch)

# Sampling settings that are later forwarded to ``model.sample``.
# NOTE(review): beam = -1 presumably disables beam search -- confirm in the model code.
greedy, beam, temperature = True, -1, 1.0
# import requests
# from io import BytesIO
# import random
# from collections import Counter
# use_urls = False # set to True to load images from demo_urls instead of those in the demo_imgs folder
# show_anyways = False # if True, it will show the recipe even if it's not valid
# image_folder = os.path.join(model_dir, 'demo_imgs')
# if not use_urls:
# demo_imgs = os.listdir(image_folder)
# random.shuffle(demo_imgs)
# demo_urls = ['https://food.fnr.sndimg.com/content/dam/images/food/fullset/2013/12/9/0/FNK_Cheesecake_s4x3.jpg.rend.hgtvcom.826.620.suffix/1387411272847.jpeg',
# 'https://www.196flavors.com/wp-content/uploads/2014/10/california-roll-3-FP.jpg']
# Collect every ``.jpg`` located exactly two directory levels below
# ``image_folder``.  The list is sorted because glob returns files in an
# OS-dependent order, which would make the result order irreproducible.
files_path = sorted(glob.glob(f"{image_folder}/*/*/*.jpg"))
print(f"total data: {len(files_path)}")

# Resize + centre crop do not depend on the individual image, so the
# pipeline is built once instead of on every iteration.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

# One {"id": <image path>, "ingredients": <prepared output>} entry per image.
res = []

# tqdm wraps the list itself (not an enumerate object) so that it knows the
# total number of images and can show real progress.
for img_file in tqdm(files_path):
    # convert() returns an independent, fully loaded image, so the
    # underlying file can be closed right away.
    with Image.open(img_file) as raw_image:
        image = raw_image.convert('RGB')

    image_transf = transform(image)
    # unsqueeze(0) adds the leading batch dimension for a single image.
    image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.sample(image_tensor, greedy=greedy,
                               temperature=temperature, beam=beam,
                               true_ingrs=None)

    ingr_ids = outputs['ingr_ids'].cpu().numpy()
    print(ingr_ids)

    # Map the predicted ids of the (single) batch element through the
    # ingredient vocabulary.
    outs = prepare_output(ingr_ids[0], ingrs_vocab)
    print(outs)

    res.append({
        "id": img_file,
        "ingredients": outs['ingrs'],
    })

# Persist all predictions once, after every image has been processed.
with open(output_file, "wb") as fp:
    pickle.dump(res, fp)
|