import torch
import torch.nn as nn
import numpy as np
import os
from args import get_parser
import pickle
from model import get_model
from torchvision import transforms
from utils.output_ing import prepare_output
from PIL import Image
from tqdm import tqdm
import time
import glob
# Set model_dir to the path containing the vocabularies and the model checkpoint
model_dir = '../data'
image_folder = '../data/demo_imgs'
output_file = "../data/predicted_ingr.pkl"
# The code will run on GPU if one is available and use_gpu is True; otherwise it runs on CPU
use_gpu = False
device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
map_loc = None if torch.cuda.is_available() and use_gpu else 'cpu'
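# Note: map_loc is passed to torch.load as map_location below, so a checkpoint saved on GPU
# can still be deserialized onto the CPU when use_gpu is False.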
# The code below was used to save the vocab files so that they can be loaded without the Vocabulary class
# ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_ingrs.pkl'), 'rb'))
# ingrs_vocab = [min(w, key=len) if not isinstance(w, str) else w for w in ingrs_vocab.idx2word.values()]
# vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_toks.pkl'), 'rb')).idx2word
# pickle.dump(ingrs_vocab, open('../demo/ingr_vocab.pkl', 'wb'))
# pickle.dump(vocab, open('../demo/instr_vocab.pkl', 'wb'))
ingrs_vocab = pickle.load(open(os.path.join(model_dir, 'ingr_vocab.pkl'), 'rb'))
vocab = pickle.load(open(os.path.join(model_dir, 'instr_vocab.pkl'), 'rb'))
ingr_vocab_size = len(ingrs_vocab)
instrs_vocab_size = len(vocab)
output_dim = instrs_vocab_size
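# ingr_vocab_size and instrs_vocab_size are passed to get_model below; presumably they size
# the ingredient and instruction decoder vocabularies inside the model.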
print(instrs_vocab_size, ingr_vocab_size)
t = time.time()
args = get_parser()
args.maxseqlen = 15
args.ingrs_only = True
model = get_model(args, ingr_vocab_size, instrs_vocab_size)
# Load the trained model parameters
model_path = os.path.join(model_dir, 'modelbest.ckpt')
model.load_state_dict(torch.load(model_path, map_location=map_loc))
model.to(device)
model.eval()
model.ingrs_only = True
model.recipe_only = False
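# Only ingredient prediction is consumed in this script (see outputs['ingr_ids'] below);
# the instruction/recipe decoder output is not used here.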
print('loaded model')
print("Elapsed time:", time.time() - t)
transf_list_batch = []
transf_list_batch.append(transforms.ToTensor())
transf_list_batch.append(transforms.Normalize((0.485, 0.456, 0.406),
                                               (0.229, 0.224, 0.225)))
to_input_transf = transforms.Compose(transf_list_batch)
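# (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) are the standard ImageNet channel mean/std,
# i.e. the normalization that ImageNet-pretrained torchvision encoders expect.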
greedy = True
beam = -1
temperature = 1.0
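# Decoding settings: greedy=True picks the argmax token at each step, and beam=-1 is the
# convention used here for disabling beam search; with greedy decoding the temperature
# value has no practical effect on the selected tokens.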
# import requests
# from io import BytesIO
# import random
# from collections import Counter
# use_urls = False  # set to True to load images from demo_urls instead of those in test_imgs folder
# show_anyways = False  # if True, it will show the recipe even if it's not valid
# image_folder = os.path.join(data_dir, 'demo_imgs')
# if not use_urls:
#     demo_imgs = os.listdir(image_folder)
#     random.shuffle(demo_imgs)
# demo_urls = ['https://food.fnr.sndimg.com/content/dam/images/food/fullset/2013/12/9/0/FNK_Cheesecake_s4x3.jpg.rend.hgtvcom.826.620.suffix/1387411272847.jpeg',
#              'https://www.196flavors.com/wp-content/uploads/2014/10/california-roll-3-FP.jpg']
files_path = glob.glob(f"{image_folder}/*/*/*.jpg")
print(f"total data: {len(files_path)}")
res = []
for idx, img_file in tqdm(enumerate(files_path), total=len(files_path)):
    # if use_urls:
    #     response = requests.get(img_file)
    #     image = Image.open(BytesIO(response.content))
    # else:
    image = Image.open(img_file).convert('RGB')
    transf_list = []
    transf_list.append(transforms.Resize(256))
    transf_list.append(transforms.CenterCrop(224))
    transform = transforms.Compose(transf_list)
    image_transf = transform(image)
    image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)
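    # image_tensor has shape (1, 3, 224, 224): a single RGB image resized to 256 on the
    # shorter side, center-cropped to 224x224, converted to a tensor and normalized.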
    # plt.imshow(image_transf)
    # plt.axis('off')
    # plt.show()
    # plt.close()
    with torch.no_grad():
        outputs = model.sample(image_tensor, greedy=greedy,
                               temperature=temperature, beam=beam, true_ingrs=None)
    ingr_ids = outputs['ingr_ids'].cpu().numpy()
    print(ingr_ids)
    outs = prepare_output(ingr_ids[0], ingrs_vocab)
    # print(ingrs_vocab.idx2word)
    print(outs)
    # print('Pic ' + str(idx+1) + ':')
    # print('\nIngredients:')
    # print(', '.join(outs['ingrs']))
    # print('='*20)
    res.append({
        "id": img_file,
        "ingredients": outs['ingrs']
    })
with open(output_file, "wb") as fp:  # Pickling
    pickle.dump(res, fp)
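# The saved predictions can be read back later with something like:
#     with open(output_file, "rb") as fp:
#         predictions = pickle.load(fp)  # list of {"id": image path, "ingredients": [...]}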