johnsonhung committed
Commit 2a3a041 · 1 Parent(s): 2d2ef3c
.gitignore ADDED
@@ -0,0 +1,112 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ *.pkl
+ *.png
+ *.json
+ *.nfs*
+ *.tex
+ .idea/
+ *.bin
+ *.ckpt
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,5 @@
+ # Code of Conduct
+
+ Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
+ Please read the [full text](https://code.fb.com/codeofconduct/)
+ so that you can understand what actions will and will not be tolerated.
CONTRIBUTING.md ADDED
@@ -0,0 +1,36 @@
+ # Contributing
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Pull Requests
+ We actively welcome your pull requests.
+
+ 1. Fork the repo and create your branch from `master`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation.
+ 4. Ensure the test suite passes.
+ 5. Make sure your code lints.
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+ ## Contributor License Agreement ("CLA")
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+ disclosure of security bugs. In those cases, please go through the process
+ outlined on that page and do not file a public issue.
+
+ ## Coding Style
+ * 4 spaces for indentation rather than tabs
+ * 100 character line length
+ * PEP8 formatting
+
+ ## License
+ By contributing to this project, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
LICENSE.md ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Facebook, Inc. and its affiliates.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,119 @@
- ---
- title: Recipedia
- emoji: 💩
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 3.1.7
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## Inverse Cooking: Recipe Generation from Food Images
+
+ Code supporting the paper:
+
+ *Amaia Salvador, Michal Drozdzal, Xavier Giro-i-Nieto, Adriana Romero.
+ [Inverse Cooking: Recipe Generation from Food Images.](https://arxiv.org/abs/1812.06164)
+ CVPR 2019*
+
+ If you find this code useful in your research, please consider citing it with the
+ following BibTeX entry:
+
+ ```
+ @InProceedings{Salvador2019inversecooking,
+ author = {Salvador, Amaia and Drozdzal, Michal and Giro-i-Nieto, Xavier and Romero, Adriana},
+ title = {Inverse Cooking: Recipe Generation From Food Images},
+ booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2019}
+ }
+ ```
+
+ ### Installation
+
+ This code uses Python 3.6 and PyTorch 0.4.1 with CUDA 9.0.
+
+ - Install PyTorch:
+ ```bash
+ $ conda install pytorch=0.4.1 cuda90 -c pytorch
+ ```
+
+ - Install dependencies:
+ ```bash
+ $ pip install -r requirements.txt
+ ```
+
+ ### Pretrained model
+
+ - Download the ingredient and instruction vocabularies [here](https://dl.fbaipublicfiles.com/inversecooking/ingr_vocab.pkl) and [here](https://dl.fbaipublicfiles.com/inversecooking/instr_vocab.pkl), respectively.
+ - Download the pretrained model [here](https://dl.fbaipublicfiles.com/inversecooking/modelbest.ckpt).
+
+ ### Demo
+
+ You can use our pretrained model to get recipes for your own images.
+
+ Download the required files (listed above), place them under the ```data``` directory, and try our demo notebook ```src/demo.ipynb```.
+
+ Note: The demo will run on a GPU if one is available; otherwise it will use the CPU.
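+
+ For reference, the demo resolves the device and loads the downloaded files roughly as follows (a minimal sketch, assuming the three files above sit under ```data/``` and the code is run from ```./src``` as in ```src/demo.ipynb```):
+
+ ```python
+ import os
+ import pickle
+ import torch
+
+ data_dir = '../data'  # folder holding ingr_vocab.pkl, instr_vocab.pkl and modelbest.ckpt
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ map_loc = None if torch.cuda.is_available() else 'cpu'
+
+ ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'ingr_vocab.pkl'), 'rb'))
+ instrs_vocab = pickle.load(open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb'))
+ # the checkpoint is loaded onto the selected device
+ state = torch.load(os.path.join(data_dir, 'modelbest.ckpt'), map_location=map_loc)
+ ```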
+
+ ### Data
+
+ - Download [Recipe1M](http://im2recipe.csail.mit.edu/dataset/download) (registration required).
+ - Extract the files somewhere (we refer to this path as ```path_to_dataset```).
+ - The contents of ```path_to_dataset``` should be the following:
+ ```
+ det_ingrs.json
+ layer1.json
+ layer2.json
+ images/
+ images/train
+ images/val
+ images/test
+ ```
+
+ *Note: all Python calls below must be run from ```./src```.*
+
+ ### Build vocabularies
+
+ ```bash
+ $ python build_vocab.py --recipe1m_path path_to_dataset
+ ```
+
+ ### Images to LMDB (Optional, but recommended)
+
+ For fast loading during training, pack the images into an LMDB file:
+
+ ```bash
+ $ python utils/ims2file.py --recipe1m_path path_to_dataset
+ ```
+
+ If you decide not to create this file, use the flag ```--load_jpeg``` when training the model (the sketch below shows how images are read back from the LMDB).
+
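+ The training loader (```src/data_loader.py```) looks each image up by its Recipe1M image id and decodes the stored buffer. A minimal sketch of that read path, assuming the LMDB written by ```utils/ims2file.py``` stores raw 256x256x3 uint8 buffers keyed by image id (the key below is a hypothetical example):
+
+ ```python
+ import lmdb
+ import numpy as np
+ from PIL import Image
+
+ env = lmdb.open('../data/lmdb_train', readonly=True, lock=False, readahead=False, meminit=False)
+ with env.begin(write=False) as txn:
+     buf = txn.get('0000a1b2c3d4.jpg'.encode())      # key: image id string
+ img = np.frombuffer(buf, dtype=np.uint8).reshape(256, 256, 3)
+ img = Image.fromarray(img, 'RGB')                   # ready for the torchvision transforms
+ ```
+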
+ ### Training
+
+ Create a directory to store checkpoints for all the models you train
+ (e.g. ```../checkpoints```) and point ```--save_dir``` to it.
+
+ We train our model in two stages:
+
+ 1. Ingredient prediction from images
+
+ ```bash
+ python train.py --model_name im2ingr --batch_size 150 --finetune_after 0 --ingrs_only \
+ --es_metric iou_sample --loss_weight 0 1000.0 1.0 1.0 \
+ --learning_rate 1e-4 --scale_learning_rate_cnn 1.0 \
+ --save_dir ../checkpoints --recipe1m_dir path_to_dataset
+ ```
+
+ 2. Recipe generation from images and ingredients (loading the weights from stage 1)
+
+ ```bash
+ python train.py --model_name model --batch_size 256 --recipe_only --transfer_from im2ingr \
+ --save_dir ../checkpoints --recipe1m_dir path_to_dataset
+ ```
+
+ Check training progress with TensorBoard from ```../checkpoints```:
+
+ ```bash
+ $ tensorboard --logdir='../tb_logs' --port=6006
+ ```
+
+ ### Evaluation
+
+ - Save generated recipes to disk with
+ ```python sample.py --model_name model --save_dir ../checkpoints --recipe1m_dir path_to_dataset --greedy --eval_split test```.
+ - This script will also return ingredient metrics (F1 and IoU); an illustrative computation of these set metrics is sketched below.
+
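+ These metrics compare the predicted ingredient set against the ground-truth set for each sample. A minimal sketch of such set-level IoU and F1 (illustrative only; the project's own implementation lives in ```src/utils/metrics.py```):
+
+ ```python
+ def ingredient_metrics(pred, true):
+     """Set-level IoU and F1 between predicted and ground-truth ingredient lists."""
+     pred, true = set(pred), set(true)
+     inter = len(pred & true)
+     union = len(pred | true)
+     iou = inter / union if union else 0.0
+     prec = inter / len(pred) if pred else 0.0
+     rec = inter / len(true) if true else 0.0
+     f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
+     return iou, f1
+
+ # ingredient_metrics(['salt', 'butter', 'flour'], ['flour', 'butter', 'sugar']) -> (0.5, 0.666...)
+ ```
+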
+ ### License
+
+ inversecooking is released under the MIT license; see [LICENSE](LICENSE.md) for details.
app.py ADDED
@@ -0,0 +1,121 @@
+ from PIL import Image
+ import requests
+ import pickle
+ from io import BytesIO
+ import gradio as gr
+ from src.args import get_parser
+ from src.model import get_model
+ import torch
+ import os
+ from src.model1_inf import im2ingr
+ import numpy as np
+
+ response = requests.get("https://i.imgur.com/DwR24EM.jpeg")
+ dog_img = Image.open(BytesIO(response.content))
+
+ def img2ingr(image):
+     # img_file = '../data/demo_imgs/1.jpg'
+     # image = Image.open(img_file).convert('RGB')
+     img = Image.fromarray(np.uint8(image)).convert('RGB')
+     ingr = im2ingr(img, ingrs_vocab, model)
+     return ' '.join(ingr)
+
+ def img_ingr2recipe(image, ingr):
+     # placeholder output until the recipe-generation model (model2) is wired in
+     print(image.shape, ingr)
+     return dog_img, "A delicious meme dog \n--------\n1. Cook it!\n2. GL&HF"
+
+ def change_checkbox(predicted_ingr):
+     return gr.update(label="Ingredients required", interactive=True, choices=predicted_ingr.split(), value=predicted_ingr.split())
+
+ def add_ingr(new_ingr):
+     print(new_ingr)
+     return "hello"
+
+ def add_to_checkbox(old_ingr, new_ingr):
+     # check whether the new ingredient is already in the list or not
+     return gr.update(label="Ingredients required", interactive=True, choices=[*old_ingr, new_ingr], value=[*old_ingr, new_ingr])
+
+
+ """ load model1 """
+ args = get_parser()
+
+ # basic parameters
+ model_dir = './data'
+ data_dir = './data'
+ example_dir = './data/demo_imgs/'
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ map_loc = None if torch.cuda.is_available() else 'cpu'
+
+ # load ingredient and instruction vocabularies
+ ingrs_vocab = pickle.load(open(os.path.join(model_dir, 'ingr_vocab.pkl'), 'rb'))
+ vocab = pickle.load(open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb'))
+
+ ingr_vocab_size = len(ingrs_vocab)
+ instrs_vocab_size = len(vocab)
+
+ # model setting and loading
+ args.maxseqlen = 15
+ args.ingrs_only = True
+ model = get_model(args, ingr_vocab_size, instrs_vocab_size)
+ model_path = os.path.join(model_dir, 'modelbest.ckpt')
+ model.load_state_dict(torch.load(model_path, map_location=map_loc))
+ model.to(device)
+ model.eval()
+ model.ingrs_only = True
+ model.recipe_only = False
+
+ """ load model2 """
+
+
+
+
+ """ gradio """
+ # input image -> list all required ingredients -> checkboxes for selecting ingredients / text box for adding more -> output: recipe and its image
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # Recipedia
+         Start finding the yummy recipe ...
+         """)
+     with gr.Tabs():
+         with gr.TabItem("User"):
+             # input image (passed as a numpy array, as img2ingr and img_ingr2recipe expect)
+             image_input = gr.Image(label="Upload the image of your yummy food", type='numpy')
+             gr.Examples(examples=[example_dir+"1.jpg", example_dir+"2.jpg", example_dir+"3.jpg", example_dir+"4.jpg", example_dir+"5.jpg", example_dir+"6.jpg"], inputs=image_input)
+             with gr.Row():
+                 # clear_img_btn = gr.Button("Clear")
+                 image_btn = gr.Button("Upload", variant="primary")
+             # list all required ingredients -> checkboxes for selecting them / text box for adding more
+             predicted_ingr = gr.Textbox(visible=False)
+
+             with gr.Row():
+                 checkboxes = gr.CheckboxGroup(label="Ingredients required", interactive=True)
+                 new_ingr = gr.Textbox(label="Additional ingredients", max_lines=1)
+             # with gr.Row():
+             #     new_btn_clear = gr.Button("Clear")
+             #     new_btn = gr.Button("Add", variant="primary")
+
+             add_ingr = gr.Textbox(visible=False)
+
+             with gr.Row():
+                 clear_ingr_btn = gr.Button("Reset")
+                 ingr_btn = gr.Button("Confirm", variant="primary")
+
+             # output: recipe and its image
+             with gr.Row():
+                 out_recipe = gr.Textbox(label="Your recipe", value="Spaghetti ---\n1. cook it!")
+                 out_image = gr.Image(label="Looks yummy ><")
+
+         with gr.TabItem("Example"):
+             image_button = gr.Button("Flip")
+
+     image_btn.click(img2ingr, inputs=image_input, outputs=predicted_ingr)
+     predicted_ingr.change(fn=change_checkbox, inputs=predicted_ingr, outputs=checkboxes)
+
+     # new_btn.click(img2ingr, inputs=new_ingr, outputs=predicted_ingr)
+     new_ingr.submit(fn=add_to_checkbox, inputs=[checkboxes, new_ingr], outputs=checkboxes)
+
+     ingr_btn.click(img_ingr2recipe, inputs=[image_input, checkboxes], outputs=[out_image, out_recipe])
+
+
+ demo.launch(debug=True, share=True)
data/README.md ADDED
@@ -0,0 +1 @@
+ # Vocabulary file will be saved here
data/demo_imgs/1.jpg ADDED
data/demo_imgs/2.jpg ADDED
data/demo_imgs/3.jpg ADDED
data/demo_imgs/4.jpg ADDED
data/demo_imgs/5.jpg ADDED
data/demo_imgs/6.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ numpy
+ scipy
+ matplotlib
+ # torch==0.4.1
+ # torchvision==0.2.1
+ nltk
+ Pillow
+ tqdm
+ lmdb
+ tensorflow
+ tensorboardX
src/args.py ADDED
@@ -0,0 +1,168 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+ import argparse
+ import os
+
+
+ def get_parser():
+
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--save_dir', type=str, default='path/to/save/models',
+                         help='path where checkpoints will be saved')
+
+     parser.add_argument('--project_name', type=str, default='inversecooking',
+                         help='name of the directory where models will be saved within save_dir')
+
+     parser.add_argument('--model_name', type=str, default='model',
+                         help='save_dir/project_name/model_name will be the path where logs and checkpoints are stored')
+
+     parser.add_argument('--transfer_from', type=str, default='',
+                         help='specify model name to transfer from')
+
+     parser.add_argument('--suff', type=str, default='',
+                         help='the id of the dictionary to load for training')
+
+     parser.add_argument('--image_model', type=str, default='resnet50',
+                         choices=['resnet18', 'resnet50', 'resnet101', 'resnet152', 'inception_v3'])
+
+     parser.add_argument('--recipe1m_dir', type=str, default='path/to/recipe1m',
+                         help='directory where recipe1m dataset is extracted')
+
+     parser.add_argument('--aux_data_dir', type=str, default='../data',
+                         help='path to other necessary data files (eg. vocabularies)')
+
+     parser.add_argument('--crop_size', type=int, default=224, help='size for randomly or center cropping images')
+
+     parser.add_argument('--image_size', type=int, default=256, help='size to rescale images')
+
+     parser.add_argument('--log_step', type=int, default=10, help='step size for printing log info')
+
+     parser.add_argument('--learning_rate', type=float, default=0.001,
+                         help='base learning rate')
+
+     parser.add_argument('--scale_learning_rate_cnn', type=float, default=0.01,
+                         help='lr multiplier for cnn weights')
+
+     parser.add_argument('--lr_decay_rate', type=float, default=0.99,
+                         help='learning rate decay factor')
+
+     parser.add_argument('--lr_decay_every', type=int, default=1,
+                         help='frequency of learning rate decay (default is every epoch)')
+
+     parser.add_argument('--weight_decay', type=float, default=0.)
+
+     parser.add_argument('--embed_size', type=int, default=512,
+                         help='hidden size for all projections')
+
+     parser.add_argument('--n_att', type=int, default=8,
+                         help='number of attention heads in the instruction decoder')
+
+     parser.add_argument('--n_att_ingrs', type=int, default=4,
+                         help='number of attention heads in the ingredient decoder')
+
+     parser.add_argument('--transf_layers', type=int, default=16,
+                         help='number of transformer layers in the instruction decoder')
+
+     parser.add_argument('--transf_layers_ingrs', type=int, default=4,
+                         help='number of transformer layers in the ingredient decoder')
+
+     parser.add_argument('--num_epochs', type=int, default=400,
+                         help='maximum number of epochs')
+
+     parser.add_argument('--batch_size', type=int, default=128)
+
+     parser.add_argument('--num_workers', type=int, default=8)
+
+     parser.add_argument('--dropout_encoder', type=float, default=0.3,
+                         help='dropout ratio for the image and ingredient encoders')
+
+     parser.add_argument('--dropout_decoder_r', type=float, default=0.3,
+                         help='dropout ratio in the instruction decoder')
+
+     parser.add_argument('--dropout_decoder_i', type=float, default=0.3,
+                         help='dropout ratio in the ingredient decoder')
+
+     parser.add_argument('--finetune_after', type=int, default=-1,
+                         help='epoch to start training cnn. -1 is never, 0 is from the beginning')
+
+     parser.add_argument('--loss_weight', nargs='+', type=float, default=[1.0, 0.0, 0.0, 0.0],
+                         help='training loss weights. 1) instruction, 2) ingredient, 3) eos 4) cardinality')
+
+     parser.add_argument('--max_eval', type=int, default=4096,
+                         help='number of validation samples to evaluate during training')
+
+     parser.add_argument('--label_smoothing_ingr', type=float, default=0.1,
+                         help='label smoothing for bce loss for ingredients')
+
+     parser.add_argument('--patience', type=int, default=50,
+                         help='maximum number of epochs to allow before early stopping')
+
+     parser.add_argument('--maxseqlen', type=int, default=15,
+                         help='maximum length of each instruction')
+
+     parser.add_argument('--maxnuminstrs', type=int, default=10,
+                         help='maximum number of instructions')
+
+     parser.add_argument('--maxnumims', type=int, default=5,
+                         help='maximum number of images per sample')
+
+     parser.add_argument('--maxnumlabels', type=int, default=20,
+                         help='maximum number of ingredients per sample')
+
+     parser.add_argument('--es_metric', type=str, default='loss', choices=['loss', 'iou_sample'],
+                         help='early stopping metric to track')
+
+     parser.add_argument('--eval_split', type=str, default='val')
+
+     parser.add_argument('--numgens', type=int, default=3)
+
+     parser.add_argument('--greedy', dest='greedy', action='store_true',
+                         help='enables greedy sampling (inference only)')
+     parser.set_defaults(greedy=False)
+
+     parser.add_argument('--temperature', type=float, default=1.0,
+                         help='sampling temperature (when greedy is False)')
+
+     parser.add_argument('--beam', type=int, default=-1,
+                         help='beam size. -1 means no beam search (either greedy or sampling)')
+
+     parser.add_argument('--ingrs_only', dest='ingrs_only', action='store_true',
+                         help='train or evaluate the model only for ingredient prediction')
+     parser.set_defaults(ingrs_only=False)
+
+     parser.add_argument('--recipe_only', dest='recipe_only', action='store_true',
+                         help='train or evaluate the model only for instruction generation')
+     parser.set_defaults(recipe_only=False)
+
+     parser.add_argument('--log_term', dest='log_term', action='store_true',
+                         help='if used, shows training log in stdout instead of saving it to a file.')
+     parser.set_defaults(log_term=False)
+
+     parser.add_argument('--notensorboard', dest='tensorboard', action='store_false',
+                         help='if used, tensorboard logs will not be saved')
+     parser.set_defaults(tensorboard=True)
+
+     parser.add_argument('--resume', dest='resume', action='store_true',
+                         help='resume training from the checkpoint in model_name')
+     parser.set_defaults(resume=False)
+
+     parser.add_argument('--nodecay_lr', dest='decay_lr', action='store_false',
+                         help='disables learning rate decay')
+     parser.set_defaults(decay_lr=True)
+
+     parser.add_argument('--load_jpeg', dest='use_lmdb', action='store_false',
+                         help='if used, images are loaded from jpg files instead of lmdb')
+     parser.set_defaults(use_lmdb=True)
+
+     parser.add_argument('--get_perplexity', dest='get_perplexity', action='store_true',
+                         help='used to get perplexity in evaluation')
+     parser.set_defaults(get_perplexity=False)
+
+     parser.add_argument('--use_true_ingrs', dest='use_true_ingrs', action='store_true',
+                         help='if used, true ingredients will be used as input to obtain the recipe in evaluation')
+     parser.set_defaults(use_true_ingrs=False)
+
+     args = parser.parse_args()
+
+     return args
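
When the parser is reused programmatically (as in ```src/demo.ipynb``` and ```app.py``` in this commit), the parsed defaults are simply overridden in code. A minimal sketch, assuming it is run from ```./src```:

```python
import sys
from args import get_parser

sys.argv = ['']          # keep argparse from consuming notebook/server arguments
args = get_parser()
args.maxseqlen = 15      # shorter sequences at inference time
args.ingrs_only = True   # only ingredient prediction is needed for the demo app
```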
src/build_vocab.py ADDED
@@ -0,0 +1,409 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import nltk
4
+ import pickle
5
+ import argparse
6
+ from collections import Counter
7
+ import json
8
+ import os
9
+ from tqdm import *
10
+ import numpy as np
11
+ import re
12
+
13
+
14
+ class Vocabulary(object):
15
+ """Simple vocabulary wrapper."""
16
+ def __init__(self):
17
+ self.word2idx = {}
18
+ self.idx2word = {}
19
+ self.idx = 0
20
+
21
+ def add_word(self, word, idx=None):
22
+ if idx is None:
23
+ if not word in self.word2idx:
24
+ self.word2idx[word] = self.idx
25
+ self.idx2word[self.idx] = word
26
+ self.idx += 1
27
+ return self.idx
28
+ else:
29
+ if not word in self.word2idx:
30
+ self.word2idx[word] = idx
31
+ if idx in self.idx2word.keys():
32
+ self.idx2word[idx].append(word)
33
+ else:
34
+ self.idx2word[idx] = [word]
35
+
36
+ return idx
37
+
38
+ def __call__(self, word):
39
+ if not word in self.word2idx:
40
+ return self.word2idx['<pad>']
41
+ return self.word2idx[word]
42
+
43
+ def __len__(self):
44
+ return len(self.idx2word)
45
+
46
+
47
+ def get_ingredient(det_ingr, replace_dict):
48
+ det_ingr_undrs = det_ingr['text'].lower()
49
+ det_ingr_undrs = ''.join(i for i in det_ingr_undrs if not i.isdigit())
50
+
51
+ for rep, char_list in replace_dict.items():
52
+ for c_ in char_list:
53
+ if c_ in det_ingr_undrs:
54
+ det_ingr_undrs = det_ingr_undrs.replace(c_, rep)
55
+ det_ingr_undrs = det_ingr_undrs.strip()
56
+ det_ingr_undrs = det_ingr_undrs.replace(' ', '_')
57
+
58
+ return det_ingr_undrs
59
+
60
+
61
+ def get_instruction(instruction, replace_dict, instruction_mode=True):
62
+ instruction = instruction.lower()
63
+
64
+ for rep, char_list in replace_dict.items():
65
+ for c_ in char_list:
66
+ if c_ in instruction:
67
+ instruction = instruction.replace(c_, rep)
68
+ instruction = instruction.strip()
69
+ # remove sentences starting with "1.", "2.", ... from the targets
70
+ if len(instruction) > 0 and instruction[0].isdigit() and instruction_mode:
71
+ instruction = ''
72
+ return instruction
73
+
74
+
75
+ def remove_plurals(counter_ingrs, ingr_clusters):
76
+ del_ingrs = []
77
+
78
+ for k, v in counter_ingrs.items():
79
+
80
+ if len(k) == 0:
81
+ del_ingrs.append(k)
82
+ continue
83
+
84
+ gotit = 0
85
+ if k[-2:] == 'es':
86
+ if k[:-2] in counter_ingrs.keys():
87
+ counter_ingrs[k[:-2]] += v
88
+ ingr_clusters[k[:-2]].extend(ingr_clusters[k])
89
+ del_ingrs.append(k)
90
+ gotit = 1
91
+
92
+ if k[-1] == 's' and gotit == 0:
93
+ if k[:-1] in counter_ingrs.keys():
94
+ counter_ingrs[k[:-1]] += v
95
+ ingr_clusters[k[:-1]].extend(ingr_clusters[k])
96
+ del_ingrs.append(k)
97
+ for item in del_ingrs:
98
+ del counter_ingrs[item]
99
+ del ingr_clusters[item]
100
+ return counter_ingrs, ingr_clusters
101
+
102
+
103
+ def cluster_ingredients(counter_ingrs):
104
+ mydict = dict()
105
+ mydict_ingrs = dict()
106
+
107
+ for k, v in counter_ingrs.items():
108
+
109
+ w1 = k.split('_')[-1]
110
+ w2 = k.split('_')[0]
111
+ lw = [w1, w2]
112
+ if len(k.split('_')) > 1:
113
+ w3 = k.split('_')[0] + '_' + k.split('_')[1]
114
+ w4 = k.split('_')[-2] + '_' + k.split('_')[-1]
115
+
116
+ lw = [w1, w2, w4, w3]
117
+
118
+ gotit = 0
119
+ for w in lw:
120
+ if w in counter_ingrs.keys():
121
+ # check if its parts are
122
+ parts = w.split('_')
123
+ if len(parts) > 0:
124
+ if parts[0] in counter_ingrs.keys():
125
+ w = parts[0]
126
+ elif parts[1] in counter_ingrs.keys():
127
+ w = parts[1]
128
+ if w in mydict.keys():
129
+ mydict[w] += v
130
+ mydict_ingrs[w].append(k)
131
+ else:
132
+ mydict[w] = v
133
+ mydict_ingrs[w] = [k]
134
+ gotit = 1
135
+ break
136
+ if gotit == 0:
137
+ mydict[k] = v
138
+ mydict_ingrs[k] = [k]
139
+
140
+ return mydict, mydict_ingrs
141
+
142
+
143
+ def update_counter(list_, counter_toks, istrain=False):
144
+ for sentence in list_:
145
+ tokens = nltk.tokenize.word_tokenize(sentence)
146
+ if istrain:
147
+ counter_toks.update(tokens)
148
+
149
+
150
+ def build_vocab_recipe1m(args):
151
+ print ("Loading data...")
152
+ dets = json.load(open(os.path.join(args.recipe1m_path, 'det_ingrs.json'), 'r'))
153
+ layer1 = json.load(open(os.path.join(args.recipe1m_path, 'layer1.json'), 'r'))
154
+ layer2 = json.load(open(os.path.join(args.recipe1m_path, 'layer2.json'), 'r'))
155
+
156
+ id2im = {}
157
+
158
+ for i, entry in enumerate(layer2):
159
+ id2im[entry['id']] = i
160
+
161
+ print("Loaded data.")
162
+ print("Found %d recipes in the dataset." % (len(layer1)))
163
+ replace_dict_ingrs = {'and': ['&', "'n"], '': ['%', ',', '.', '#', '[', ']', '!', '?']}
164
+ replace_dict_instrs = {'and': ['&', "'n"], '': ['#', '[', ']']}
165
+
166
+ idx2ind = {}
167
+ for i, entry in enumerate(dets):
168
+ idx2ind[entry['id']] = i
169
+
170
+ ingrs_file = args.save_path + 'allingrs_count.pkl'
171
+ instrs_file = args.save_path + 'allwords_count.pkl'
172
+
173
+ #####
174
+ # 1. Count words in dataset and clean
175
+ #####
176
+ if os.path.exists(ingrs_file) and os.path.exists(instrs_file) and not args.forcegen:
177
+ print ("loading pre-extracted word counters")
178
+ counter_ingrs = pickle.load(open(args.save_path + 'allingrs_count.pkl', 'rb'))
179
+ counter_toks = pickle.load(open(args.save_path + 'allwords_count.pkl', 'rb'))
180
+ else:
181
+ counter_toks = Counter()
182
+ counter_ingrs = Counter()
183
+ counter_ingrs_raw = Counter()
184
+
185
+ for i, entry in tqdm(enumerate(layer1)):
186
+
187
+ # get all instructions for this recipe
188
+ instrs = entry['instructions']
189
+
190
+ instrs_list = []
191
+ ingrs_list = []
192
+
193
+ # retrieve pre-detected ingredients for this entry
194
+ det_ingrs = dets[idx2ind[entry['id']]]['ingredients']
195
+
196
+ valid = dets[idx2ind[entry['id']]]['valid']
197
+ det_ingrs_filtered = []
198
+
199
+ for j, det_ingr in enumerate(det_ingrs):
200
+ if len(det_ingr) > 0 and valid[j]:
201
+ det_ingr_undrs = get_ingredient(det_ingr, replace_dict_ingrs)
202
+ det_ingrs_filtered.append(det_ingr_undrs)
203
+ ingrs_list.append(det_ingr_undrs)
204
+
205
+ # get raw text for instructions of this entry
206
+ acc_len = 0
207
+ for instr in instrs:
208
+ instr = instr['text']
209
+ instr = get_instruction(instr, replace_dict_instrs)
210
+ if len(instr) > 0:
211
+ instrs_list.append(instr)
212
+ acc_len += len(instr)
213
+
214
+ # discard recipes with too few or too many ingredients or instruction words
215
+ if len(ingrs_list) < args.minnumingrs or len(instrs_list) < args.minnuminstrs \
216
+ or len(instrs_list) >= args.maxnuminstrs or len(ingrs_list) >= args.maxnumingrs \
217
+ or acc_len < args.minnumwords:
218
+ continue
219
+
220
+ # tokenize sentences and update counter
221
+ update_counter(instrs_list, counter_toks, istrain=entry['partition'] == 'train')
222
+ title = nltk.tokenize.word_tokenize(entry['title'].lower())
223
+ if entry['partition'] == 'train':
224
+ counter_toks.update(title)
225
+ if entry['partition'] == 'train':
226
+ counter_ingrs.update(ingrs_list)
227
+
228
+ pickle.dump(counter_ingrs, open(args.save_path + 'allingrs_count.pkl', 'wb'))
229
+ pickle.dump(counter_toks, open(args.save_path + 'allwords_count.pkl', 'wb'))
230
+ pickle.dump(counter_ingrs_raw, open(args.save_path + 'allingrs_raw_count.pkl', 'wb'))
231
+
232
+ # manually add missing entries for better clustering
233
+ base_words = ['peppers', 'tomato', 'spinach_leaves', 'turkey_breast', 'lettuce_leaf',
234
+ 'chicken_thighs', 'milk_powder', 'bread_crumbs', 'onion_flakes',
235
+ 'red_pepper', 'pepper_flakes', 'juice_concentrate', 'cracker_crumbs', 'hot_chili',
236
+ 'seasoning_mix', 'dill_weed', 'pepper_sauce', 'sprouts', 'cooking_spray', 'cheese_blend',
237
+ 'basil_leaves', 'pineapple_chunks', 'marshmallow', 'chile_powder',
238
+ 'cheese_blend', 'corn_kernels', 'tomato_sauce', 'chickens', 'cracker_crust',
239
+ 'lemonade_concentrate', 'red_chili', 'mushroom_caps', 'mushroom_cap', 'breaded_chicken',
240
+ 'frozen_pineapple', 'pineapple_chunks', 'seasoning_mix', 'seaweed', 'onion_flakes',
241
+ 'bouillon_granules', 'lettuce_leaf', 'stuffing_mix', 'parsley_flakes', 'chicken_breast',
242
+ 'basil_leaves', 'baguettes', 'green_tea', 'peanut_butter', 'green_onion', 'fresh_cilantro',
243
+ 'breaded_chicken', 'hot_pepper', 'dried_lavender', 'white_chocolate',
244
+ 'dill_weed', 'cake_mix', 'cheese_spread', 'turkey_breast', 'chucken_thighs', 'basil_leaves',
245
+ 'mandarin_orange', 'laurel', 'cabbage_head', 'pistachio', 'cheese_dip',
246
+ 'thyme_leave', 'boneless_pork', 'red_pepper', 'onion_dip', 'skinless_chicken', 'dark_chocolate',
247
+ 'canned_corn', 'muffin', 'cracker_crust', 'bread_crumbs', 'frozen_broccoli',
248
+ 'philadelphia', 'cracker_crust', 'chicken_breast']
249
+
250
+ for base_word in base_words:
251
+
252
+ if base_word not in counter_ingrs.keys():
253
+ counter_ingrs[base_word] = 1
254
+
255
+ counter_ingrs, cluster_ingrs = cluster_ingredients(counter_ingrs)
256
+ counter_ingrs, cluster_ingrs = remove_plurals(counter_ingrs, cluster_ingrs)
257
+
258
+ # If the word frequency is less than 'threshold', then the word is discarded.
259
+ words = [word for word, cnt in counter_toks.items() if cnt >= args.threshold_words]
260
+ ingrs = {word: cnt for word, cnt in counter_ingrs.items() if cnt >= args.threshold_ingrs}
261
+
262
+ # Recipe vocab
263
+ # Create a vocab wrapper and add some special tokens.
264
+ vocab_toks = Vocabulary()
265
+ vocab_toks.add_word('<start>')
266
+ vocab_toks.add_word('<end>')
267
+ vocab_toks.add_word('<eoi>')
268
+
269
+ # Add the words to the vocabulary.
270
+ for i, word in enumerate(words):
271
+ vocab_toks.add_word(word)
272
+ vocab_toks.add_word('<pad>')
273
+
274
+ # Ingredient vocab
275
+ # Create a vocab wrapper for ingredients
276
+ vocab_ingrs = Vocabulary()
277
+ idx = vocab_ingrs.add_word('<end>')
278
+ # this returns the next idx to add words to
279
+ # Add the ingredients to the vocabulary.
280
+ for k, _ in ingrs.items():
281
+ for ingr in cluster_ingrs[k]:
282
+ idx = vocab_ingrs.add_word(ingr, idx)
283
+ idx += 1
284
+ _ = vocab_ingrs.add_word('<pad>', idx)
285
+
286
+ print("Total ingr vocabulary size: {}".format(len(vocab_ingrs)))
287
+ print("Total token vocabulary size: {}".format(len(vocab_toks)))
288
+
289
+ dataset = {'train': [], 'val': [], 'test': []}
290
+
291
+ ######
292
+ # 2. Tokenize and build dataset based on vocabularies.
293
+ ######
294
+ for i, entry in tqdm(enumerate(layer1)):
295
+
296
+ # get all instructions for this recipe
297
+ instrs = entry['instructions']
298
+
299
+ instrs_list = []
300
+ ingrs_list = []
301
+ images_list = []
302
+
303
+ # retrieve pre-detected ingredients for this entry
304
+ det_ingrs = dets[idx2ind[entry['id']]]['ingredients']
305
+ valid = dets[idx2ind[entry['id']]]['valid']
306
+ labels = []
307
+
308
+ for j, det_ingr in enumerate(det_ingrs):
309
+ if len(det_ingr) > 0 and valid[j]:
310
+ det_ingr_undrs = get_ingredient(det_ingr, replace_dict_ingrs)
311
+ ingrs_list.append(det_ingr_undrs)
312
+ label_idx = vocab_ingrs(det_ingr_undrs)
313
+ if label_idx is not vocab_ingrs('<pad>') and label_idx not in labels:
314
+ labels.append(label_idx)
315
+
316
+ # get raw text for instructions of this entry
317
+ acc_len = 0
318
+ for instr in instrs:
319
+ instr = instr['text']
320
+ instr = get_instruction(instr, replace_dict_instrs)
321
+ if len(instr) > 0:
322
+ acc_len += len(instr)
323
+ instrs_list.append(instr)
324
+
325
+ # we discard recipes with too many or too few ingredients or instruction words
326
+ if len(labels) < args.minnumingrs or len(instrs_list) < args.minnuminstrs \
327
+ or len(instrs_list) >= args.maxnuminstrs or len(labels) >= args.maxnumingrs \
328
+ or acc_len < args.minnumwords:
329
+ continue
330
+
331
+ if entry['id'] in id2im.keys():
332
+ ims = layer2[id2im[entry['id']]]
333
+
334
+ # copy image paths for this recipe
335
+ for im in ims['images']:
336
+ images_list.append(im['id'])
337
+
338
+ # tokenize sentences
339
+ toks = []
340
+
341
+ for instr in instrs_list:
342
+ tokens = nltk.tokenize.word_tokenize(instr)
343
+ toks.append(tokens)
344
+
345
+ title = nltk.tokenize.word_tokenize(entry['title'].lower())
346
+
347
+ newentry = {'id': entry['id'], 'instructions': instrs_list, 'tokenized': toks,
348
+ 'ingredients': ingrs_list, 'images': images_list, 'title': title}
349
+ dataset[entry['partition']].append(newentry)
350
+
351
+ print('Dataset size:')
352
+ for split in dataset.keys():
353
+ print(split, ':', len(dataset[split]))
354
+
355
+ return vocab_ingrs, vocab_toks, dataset
356
+
357
+
358
+ def main(args):
359
+
360
+ vocab_ingrs, vocab_toks, dataset = build_vocab_recipe1m(args)
361
+
362
+ with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_ingrs.pkl'), 'wb') as f:
363
+ pickle.dump(vocab_ingrs, f)
364
+ with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_toks.pkl'), 'wb') as f:
365
+ pickle.dump(vocab_toks, f)
366
+
367
+ for split in dataset.keys():
368
+ with open(os.path.join(args.save_path, args.suff+'recipe1m_' + split + '.pkl'), 'wb') as f:
369
+ pickle.dump(dataset[split], f)
370
+
371
+
372
+ if __name__ == '__main__':
373
+
374
+ parser = argparse.ArgumentParser()
375
+ parser.add_argument('--recipe1m_path', type=str,
376
+ default='path/to/recipe1m',
377
+ help='recipe1m path')
378
+
379
+ parser.add_argument('--save_path', type=str, default='../data/',
380
+ help='path for saving vocabulary wrapper')
381
+
382
+ parser.add_argument('--suff', type=str, default='')
383
+
384
+ parser.add_argument('--threshold_ingrs', type=int, default=10,
385
+ help='minimum ingr count threshold')
386
+
387
+ parser.add_argument('--threshold_words', type=int, default=10,
388
+ help='minimum word count threshold')
389
+
390
+ parser.add_argument('--maxnuminstrs', type=int, default=20,
391
+ help='max number of instructions (sentences)')
392
+
393
+ parser.add_argument('--maxnumingrs', type=int, default=20,
394
+ help='max number of ingredients')
395
+
396
+ parser.add_argument('--minnuminstrs', type=int, default=2,
397
+ help='max number of instructions (sentences)')
398
+
399
+ parser.add_argument('--minnumingrs', type=int, default=2,
400
+ help='max number of ingredients')
401
+
402
+ parser.add_argument('--minnumwords', type=int, default=20,
403
+ help='minimum number of characters in recipe')
404
+
405
+ parser.add_argument('--forcegen', dest='forcegen', action='store_true')
406
+ parser.set_defaults(forcegen=False)
407
+
408
+ args = parser.parse_args()
409
+ main(args)
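
For reference, the ```Vocabulary``` wrapper above lets several clustered spellings share one index when ```idx``` is passed explicitly, and maps unseen words to ```<pad>```. A minimal sketch (run from ```./src```):

```python
from build_vocab import Vocabulary

vocab = Vocabulary()
idx = vocab.add_word('<end>')          # '<end>' takes index 0; the call returns the next free index
idx = vocab.add_word('tomato', idx)    # explicit index: idx2word[1] == ['tomato']
idx = vocab.add_word('tomatoes', idx)  # same index; the clustered spelling is appended
idx += 1
vocab.add_word('<pad>', idx)

print(vocab('tomatoes'))  # -> 1
print(vocab('unknown'))   # -> index of '<pad>' (unseen words fall back to padding)
print(len(vocab))         # -> 3 distinct indices
```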
src/data_loader.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ import torch.utils.data as data
6
+ import os
7
+ import pickle
8
+ import numpy as np
9
+ import nltk
10
+ from PIL import Image
11
+ from build_vocab import Vocabulary
12
+ import random
13
+ import json
14
+ import lmdb
15
+
16
+
17
+ class Recipe1MDataset(data.Dataset):
18
+
19
+ def __init__(self, data_dir, aux_data_dir, split, maxseqlen, maxnuminstrs, maxnumlabels, maxnumims,
20
+ transform=None, max_num_samples=-1, use_lmdb=False, suff=''):
21
+
22
+ self.ingrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_ingrs.pkl'), 'rb'))
23
+ self.instrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_toks.pkl'), 'rb'))
24
+ self.dataset = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_'+split+'.pkl'), 'rb'))
25
+
26
+ self.label2word = self.get_ingrs_vocab()
27
+
28
+ self.use_lmdb = use_lmdb
29
+ if use_lmdb:
30
+ self.image_file = lmdb.open(os.path.join(aux_data_dir, 'lmdb_' + split), max_readers=1, readonly=True,
31
+ lock=False, readahead=False, meminit=False)
32
+
33
+ self.ids = []
34
+ self.split = split
35
+ for i, entry in enumerate(self.dataset):
36
+ if len(entry['images']) == 0:
37
+ continue
38
+ self.ids.append(i)
39
+
40
+ self.root = os.path.join(data_dir, 'images', split)
41
+ self.transform = transform
42
+ self.max_num_labels = maxnumlabels
43
+ self.maxseqlen = maxseqlen
44
+ self.max_num_instrs = maxnuminstrs
45
+ self.maxseqlen = maxseqlen*maxnuminstrs
46
+ self.maxnumims = maxnumims
47
+ if max_num_samples != -1:
48
+ random.shuffle(self.ids)
49
+ self.ids = self.ids[:max_num_samples]
50
+
51
+ def get_instrs_vocab(self):
52
+ return self.instrs_vocab
53
+
54
+ def get_instrs_vocab_size(self):
55
+ return len(self.instrs_vocab)
56
+
57
+ def get_ingrs_vocab(self):
58
+ return [min(w, key=len) if not isinstance(w, str) else w for w in
59
+ self.ingrs_vocab.idx2word.values()] # includes 'pad' ingredient
60
+
61
+ def get_ingrs_vocab_size(self):
62
+ return len(self.ingrs_vocab)
63
+
64
+ def __getitem__(self, index):
65
+ """Returns one data pair (image and caption)."""
66
+
67
+ sample = self.dataset[self.ids[index]]
68
+ img_id = sample['id']
69
+ captions = sample['tokenized']
70
+ paths = sample['images'][0:self.maxnumims]
71
+
72
+ idx = index
73
+
74
+ labels = self.dataset[self.ids[idx]]['ingredients']
75
+ title = sample['title']
76
+
77
+ tokens = []
78
+ tokens.extend(title)
79
+ # add fake token to separate title from recipe
80
+ tokens.append('<eoi>')
81
+ for c in captions:
82
+ tokens.extend(c)
83
+ tokens.append('<eoi>')
84
+
85
+ ilabels_gt = np.ones(self.max_num_labels) * self.ingrs_vocab('<pad>')
86
+ pos = 0
87
+
88
+ true_ingr_idxs = []
89
+ for i in range(len(labels)):
90
+ true_ingr_idxs.append(self.ingrs_vocab(labels[i]))
91
+
92
+ for i in range(self.max_num_labels):
93
+ if i >= len(labels):
94
+ label = '<pad>'
95
+ else:
96
+ label = labels[i]
97
+ label_idx = self.ingrs_vocab(label)
98
+ if label_idx not in ilabels_gt:
99
+ ilabels_gt[pos] = label_idx
100
+ pos += 1
101
+
102
+ ilabels_gt[pos] = self.ingrs_vocab('<end>')
103
+ ingrs_gt = torch.from_numpy(ilabels_gt).long()
104
+
105
+ if len(paths) == 0:
106
+ path = None
107
+ image_input = torch.zeros((3, 224, 224))
108
+ else:
109
+ if self.split == 'train':
110
+ img_idx = np.random.randint(0, len(paths))
111
+ else:
112
+ img_idx = 0
113
+ path = paths[img_idx]
114
+ if self.use_lmdb:
115
+ try:
116
+ with self.image_file.begin(write=False) as txn:
117
+ image = txn.get(path.encode())
118
+ image = np.fromstring(image, dtype=np.uint8)
119
+ image = np.reshape(image, (256, 256, 3))
120
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
121
+ except:
122
+ print ("Image id not found in lmdb. Loading jpeg file...")
123
+ image = Image.open(os.path.join(self.root, path[0], path[1],
124
+ path[2], path[3], path)).convert('RGB')
125
+ else:
126
+ image = Image.open(os.path.join(self.root, path[0], path[1], path[2], path[3], path)).convert('RGB')
127
+ if self.transform is not None:
128
+ image = self.transform(image)
129
+ image_input = image
130
+
131
+ # Convert caption (string) to word ids.
132
+ caption = []
133
+
134
+ caption = self.caption_to_idxs(tokens, caption)
135
+ caption.append(self.instrs_vocab('<end>'))
136
+
137
+ caption = caption[0:self.maxseqlen]
138
+ target = torch.Tensor(caption)
139
+
140
+ return image_input, target, ingrs_gt, img_id, path, self.instrs_vocab('<pad>')
141
+
142
+ def __len__(self):
143
+ return len(self.ids)
144
+
145
+ def caption_to_idxs(self, tokens, caption):
146
+
147
+ caption.append(self.instrs_vocab('<start>'))
148
+ for token in tokens:
149
+ caption.append(self.instrs_vocab(token))
150
+ return caption
151
+
152
+
153
+ def collate_fn(data):
154
+
155
+ # Sort a data list by caption length (descending order).
156
+ # data.sort(key=lambda x: len(x[2]), reverse=True)
157
+ image_input, captions, ingrs_gt, img_id, path, pad_value = zip(*data)
158
+
159
+ # Merge images (from tuple of 3D tensor to 4D tensor).
160
+
161
+ image_input = torch.stack(image_input, 0)
162
+ ingrs_gt = torch.stack(ingrs_gt, 0)
163
+
164
+ # Merge captions (from tuple of 1D tensor to 2D tensor).
165
+ lengths = [len(cap) for cap in captions]
166
+ targets = torch.ones(len(captions), max(lengths)).long()*pad_value[0]
167
+
168
+ for i, cap in enumerate(captions):
169
+ end = lengths[i]
170
+ targets[i, :end] = cap[:end]
171
+
172
+ return image_input, targets, ingrs_gt, img_id, path
173
+
174
+
175
+ def get_loader(data_dir, aux_data_dir, split, maxseqlen,
176
+ maxnuminstrs, maxnumlabels, maxnumims, transform, batch_size,
177
+ shuffle, num_workers, drop_last=False,
178
+ max_num_samples=-1,
179
+ use_lmdb=False,
180
+ suff=''):
181
+
182
+ dataset = Recipe1MDataset(data_dir=data_dir, aux_data_dir=aux_data_dir, split=split,
183
+ maxseqlen=maxseqlen, maxnumlabels=maxnumlabels, maxnuminstrs=maxnuminstrs,
184
+ maxnumims=maxnumims,
185
+ transform=transform,
186
+ max_num_samples=max_num_samples,
187
+ use_lmdb=use_lmdb,
188
+ suff=suff)
189
+
190
+ data_loader = torch.utils.data.DataLoader(dataset=dataset,
191
+ batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
192
+ drop_last=drop_last, collate_fn=collate_fn, pin_memory=True)
193
+ return data_loader, dataset
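
A sketch of how ```get_loader``` above might be called for the training split (the values are illustrative; the real ones come from ```args``` in train.py, and the vocabulary pickles built by ```build_vocab.py``` must already exist under ```aux_data_dir```):

```python
from torchvision import transforms
from data_loader import get_loader

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

loader, dataset = get_loader('path/to/recipe1m', '../data', 'train',
                             maxseqlen=15, maxnuminstrs=10, maxnumlabels=20, maxnumims=5,
                             transform=transform, batch_size=128, shuffle=True,
                             num_workers=8, use_lmdb=True)
```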
src/demo.ipynb ADDED
@@ -0,0 +1,271 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Inverse Cooking: Recipe Generation from Food Images"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import matplotlib.pyplot as plt\n",
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import numpy as np\n",
20
+ "import os\n",
21
+ "from args import get_parser\n",
22
+ "import pickle\n",
23
+ "from model import get_model\n",
24
+ "from torchvision import transforms\n",
25
+ "from utils.output_utils import prepare_output\n",
26
+ "from PIL import Image\n",
27
+ "import time"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "metadata": {},
33
+ "source": [
34
+ "Set ```data_dir``` to the path including vocabularies and model checkpoint"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "data_dir = '../data'"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "# code will run in gpu if available and if the flag is set to True, else it will run on cpu\n",
53
+ "use_gpu = False\n",
54
+ "device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')\n",
55
+ "map_loc = None if torch.cuda.is_available() and use_gpu else 'cpu'"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "# code below was used to save vocab files so that they can be loaded without Vocabulary class\n",
65
+ "#ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_ingrs.pkl'), 'rb'))\n",
66
+ "#ingrs_vocab = [min(w, key=len) if not isinstance(w, str) else w for w in ingrs_vocab.idx2word.values()]\n",
67
+ "#vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_toks.pkl'), 'rb')).idx2word\n",
68
+ "#pickle.dump(ingrs_vocab, open('../demo/ingr_vocab.pkl', 'wb'))\n",
69
+ "#pickle.dump(vocab, open('../demo/instr_vocab.pkl', 'wb'))\n",
70
+ "\n",
71
+ "ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'ingr_vocab.pkl'), 'rb'))\n",
72
+ "vocab = pickle.load(open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb'))\n",
73
+ "\n",
74
+ "ingr_vocab_size = len(ingrs_vocab)\n",
75
+ "instrs_vocab_size = len(vocab)\n",
76
+ "output_dim = instrs_vocab_size"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "print (instrs_vocab_size, ingr_vocab_size)"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "t = time.time()\n",
95
+ "import sys; sys.argv=['']; del sys\n",
96
+ "args = get_parser()\n",
97
+ "args.maxseqlen = 15\n",
98
+ "args.ingrs_only=False\n",
99
+ "model = get_model(args, ingr_vocab_size, instrs_vocab_size)\n",
100
+ "# Load the trained model parameters\n",
101
+ "model_path = os.path.join(data_dir, 'modelbest.ckpt')\n",
102
+ "model.load_state_dict(torch.load(model_path, map_location=map_loc))\n",
103
+ "model.to(device)\n",
104
+ "model.eval()\n",
105
+ "model.ingrs_only = False\n",
106
+ "model.recipe_only = False\n",
107
+ "print ('loaded model')\n",
108
+ "print (\"Elapsed time:\", time.time() -t)\n"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "transf_list_batch = []\n",
118
+ "transf_list_batch.append(transforms.ToTensor())\n",
119
+ "transf_list_batch.append(transforms.Normalize((0.485, 0.456, 0.406), \n",
120
+ " (0.229, 0.224, 0.225)))\n",
121
+ "to_input_transf = transforms.Compose(transf_list_batch)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "greedy = [True, False, False, False]\n",
131
+ "beam = [-1, -1, -1, -1]\n",
132
+ "temperature = 1.0\n",
133
+ "numgens = len(greedy)"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "metadata": {},
139
+ "source": [
140
+ "Set ```use_urls = True``` to get recipes for images in ```demo_urls```. \n",
141
+ "\n",
142
+ "You can also set ```use_urls = False``` and get recipes for images in the path in ```data_dir/test_imgs```."
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "metadata": {
149
+ "scrolled": true
150
+ },
151
+ "outputs": [],
152
+ "source": [
153
+ "import requests\n",
154
+ "from io import BytesIO\n",
155
+ "import random\n",
156
+ "from collections import Counter\n",
157
+ "use_urls = False # set to true to load images from demo_urls instead of those in test_imgs folder\n",
158
+ "show_anyways = False #if True, it will show the recipe even if it's not valid\n",
159
+ "image_folder = os.path.join(data_dir, 'demo_imgs')\n",
160
+ "\n",
161
+ "if not use_urls:\n",
162
+ " demo_imgs = os.listdir(image_folder)\n",
163
+ " random.shuffle(demo_imgs)\n",
164
+ "\n",
165
+ "demo_urls = ['https://food.fnr.sndimg.com/content/dam/images/food/fullset/2013/12/9/0/FNK_Cheesecake_s4x3.jpg.rend.hgtvcom.826.620.suffix/1387411272847.jpeg',\n",
166
+ " 'https://www.196flavors.com/wp-content/uploads/2014/10/california-roll-3-FP.jpg']\n",
167
+ "\n",
168
+ "demo_files = demo_urls if use_urls else demo_imgs"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "for img_file in demo_files:\n",
178
+ " \n",
179
+ " if use_urls:\n",
180
+ " response = requests.get(img_file)\n",
181
+ " image = Image.open(BytesIO(response.content))\n",
182
+ " else:\n",
183
+ " image_path = os.path.join(image_folder, img_file)\n",
184
+ " image = Image.open(image_path).convert('RGB')\n",
185
+ " \n",
186
+ " transf_list = []\n",
187
+ " transf_list.append(transforms.Resize(256))\n",
188
+ " transf_list.append(transforms.CenterCrop(224))\n",
189
+ " transform = transforms.Compose(transf_list)\n",
190
+ " \n",
191
+ " image_transf = transform(image)\n",
192
+ " image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)\n",
193
+ " \n",
194
+ " plt.imshow(image_transf)\n",
195
+ " plt.axis('off')\n",
196
+ " plt.show()\n",
197
+ " plt.close()\n",
198
+ " \n",
199
+ " num_valid = 1\n",
200
+ " for i in range(numgens):\n",
201
+ " with torch.no_grad():\n",
202
+ " outputs = model.sample(image_tensor, greedy=greedy[i], \n",
203
+ " temperature=temperature, beam=beam[i], true_ingrs=None)\n",
204
+ " \n",
205
+ " ingr_ids = outputs['ingr_ids'].cpu().numpy()\n",
206
+ " recipe_ids = outputs['recipe_ids'].cpu().numpy()\n",
207
+ " \n",
208
+ " outs, valid = prepare_output(recipe_ids[0], ingr_ids[0], ingrs_vocab, vocab)\n",
209
+ " \n",
210
+ " if valid['is_valid'] or show_anyways:\n",
211
+ " \n",
212
+ " print ('RECIPE', num_valid)\n",
213
+ " num_valid+=1\n",
214
+ " #print (\"greedy:\", greedy[i], \"beam:\", beam[i])\n",
215
+ " \n",
216
+ " BOLD = '\\033[1m'\n",
217
+ " END = '\\033[0m'\n",
218
+ " print (BOLD + '\\nTitle:' + END,outs['title'])\n",
219
+ "\n",
220
+ " print (BOLD + '\\nIngredients:'+ END)\n",
221
+ " print (', '.join(outs['ingrs']))\n",
222
+ "\n",
223
+ " print (BOLD + '\\nInstructions:'+END)\n",
224
+ " print ('-'+'\\n-'.join(outs['recipe']))\n",
225
+ "\n",
226
+ " print ('='*20)\n",
227
+ "\n",
228
+ " else:\n",
229
+ " pass\n",
230
+ " print (\"Not a valid recipe!\")\n",
231
+ " print (\"Reason: \", valid['reason'])\n",
232
+ " "
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": []
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": []
248
+ }
249
+ ],
250
+ "metadata": {
251
+ "kernelspec": {
252
+ "display_name": "Python 3",
253
+ "language": "python",
254
+ "name": "python3"
255
+ },
256
+ "language_info": {
257
+ "codemirror_mode": {
258
+ "name": "ipython",
259
+ "version": 3
260
+ },
261
+ "file_extension": ".py",
262
+ "mimetype": "text/x-python",
263
+ "name": "python",
264
+ "nbconvert_exporter": "python",
265
+ "pygments_lexer": "ipython3",
266
+ "version": "3.6.5"
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 2
271
+ }
src/demo.py ADDED
@@ -0,0 +1,133 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import os
5
+ from args import get_parser
6
+ import pickle
7
+ from model import get_model
8
+ from torchvision import transforms
9
+ from utils.output_ing import prepare_output
10
+ from PIL import Image
11
+ from tqdm import tqdm
12
+ import time
13
+ import glob
14
+
15
+
16
+ # Set ```data_dir``` to the path including vocabularies and model checkpoint
17
+ model_dir = '../data'
18
+ image_folder = '../data/demo_imgs'
19
+ output_file = "../data/predicted_ingr.pkl"
20
+
21
+ # code will run in gpu if available and if the flag is set to True, else it will run on cpu
22
+ use_gpu = False
23
+ device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
24
+ map_loc = None if torch.cuda.is_available() and use_gpu else 'cpu'
25
+
26
+ # code below was used to save vocab files so that they can be loaded without Vocabulary class
27
+ #ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_ingrs.pkl'), 'rb'))
28
+ #ingrs_vocab = [min(w, key=len) if not isinstance(w, str) else w for w in ingrs_vocab.idx2word.values()]
29
+ #vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_toks.pkl'), 'rb')).idx2word
30
+ #pickle.dump(ingrs_vocab, open('../demo/ingr_vocab.pkl', 'wb'))
31
+ #pickle.dump(vocab, open('../demo/instr_vocab.pkl', 'wb'))
32
+
33
+ ingrs_vocab = pickle.load(open(os.path.join(model_dir, 'ingr_vocab.pkl'), 'rb'))
34
+ vocab = pickle.load(open(os.path.join(model_dir, 'instr_vocab.pkl'), 'rb'))
35
+
36
+ ingr_vocab_size = len(ingrs_vocab)
37
+ instrs_vocab_size = len(vocab)
38
+ output_dim = instrs_vocab_size
39
+
40
+ print (instrs_vocab_size, ingr_vocab_size)
41
+
42
+ t = time.time()
43
+
44
+ args = get_parser()
45
+ args.maxseqlen = 15
46
+ args.ingrs_only=True
47
+ model = get_model(args, ingr_vocab_size, instrs_vocab_size)
48
+ # Load the trained model parameters
49
+ model_path = os.path.join(model_dir, 'modelbest.ckpt')
50
+ model.load_state_dict(torch.load(model_path, map_location=map_loc))
51
+ model.to(device)
52
+ model.eval()
53
+ model.ingrs_only = True
54
+ model.recipe_only = False
55
+ print ('loaded model')
56
+ print ("Elapsed time:", time.time() -t)
57
+
58
+ transf_list_batch = []
59
+ transf_list_batch.append(transforms.ToTensor())
60
+ transf_list_batch.append(transforms.Normalize((0.485, 0.456, 0.406),
61
+ (0.229, 0.224, 0.225)))
62
+ to_input_transf = transforms.Compose(transf_list_batch)
63
+
64
+
65
+ greedy = True
66
+ beam = -1
67
+ temperature = 1.0
68
+
69
+ # import requests
70
+ # from io import BytesIO
71
+ # import random
72
+ # from collections import Counter
73
+ # use_urls = False # set to true to load images from demo_urls instead of those in test_imgs folder
74
+ # show_anyways = False #if True, it will show the recipe even if it's not valid
75
+ # image_folder = os.path.join(data_dir, 'demo_imgs')
76
+
77
+ # if not use_urls:
78
+ # demo_imgs = os.listdir(image_folder)
79
+ # random.shuffle(demo_imgs)
80
+
81
+ # demo_urls = ['https://food.fnr.sndimg.com/content/dam/images/food/fullset/2013/12/9/0/FNK_Cheesecake_s4x3.jpg.rend.hgtvcom.826.620.suffix/1387411272847.jpeg',
82
+ # 'https://www.196flavors.com/wp-content/uploads/2014/10/california-roll-3-FP.jpg']
83
+
84
+ files_path = glob.glob(f"{image_folder}/*/*/*.jpg")
85
+ print(f"total data: {len(files_path)}")
86
+
87
+ res = []
88
+ for idx, img_file in tqdm(enumerate(files_path), total=len(files_path)):
89
+ # if use_urls:
90
+ # response = requests.get(img_file)
91
+ # image = Image.open(BytesIO(response.content))
92
+ # else:
93
+ image = Image.open(img_file).convert('RGB')
94
+
95
+ transf_list = []
96
+ transf_list.append(transforms.Resize(256))
97
+ transf_list.append(transforms.CenterCrop(224))
98
+ transform = transforms.Compose(transf_list)
99
+
100
+ image_transf = transform(image)
101
+ image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)
102
+
103
+ # plt.imshow(image_transf)
104
+ # plt.axis('off')
105
+ # plt.show()
106
+ # plt.close()
107
+
108
+ with torch.no_grad():
109
+ outputs = model.sample(image_tensor, greedy=greedy,
110
+ temperature=temperature, beam=beam, true_ingrs=None)
111
+
112
+ ingr_ids = outputs['ingr_ids'].cpu().numpy()
113
+ print(ingr_ids)
114
+
115
+ outs = prepare_output(ingr_ids[0], ingrs_vocab)
116
+ # print(ingrs_vocab.idx2word)
117
+
118
+ print(outs)
119
+
120
+ # print ('Pic ' + str(idx+1) + ':')
121
+
122
+ # print ('\nIngredients:')
123
+ # print (', '.join(outs['ingrs']))
124
+
125
+ # print ('='*20)
126
+
127
+ res.append({
128
+ "id": img_file,
129
+ "ingredients": outs['ingrs']
130
+ })
131
+
132
+ with open(output_file, "wb") as fp: #Pickling
133
+ pickle.dump(res, fp)
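For reference, `res` above is a list of dicts, each holding the image path under "id" and the predicted ingredient list under "ingredients". A minimal sketch for loading the pickle back (essentially what src/read_pkl.py further down does):

import pickle

with open("../data/predicted_ingr.pkl", "rb") as fp:
    predictions = pickle.load(fp)

# Each entry: {"id": <image path>, "ingredients": [<ingredient>, ...]}
for entry in predictions[:3]:
    print(entry["id"], "->", ", ".join(entry["ingredients"]))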
src/model.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import random
6
+ import numpy as np
7
+ from src.modules.encoder import EncoderCNN, EncoderLabels
8
+ from src.modules.transformer_decoder import DecoderTransformer
9
+ from src.modules.multihead_attention import MultiheadAttention
10
+ from src.utils.metrics import softIoU, MaskedCrossEntropyCriterion
11
+ import pickle
12
+ import os
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+
15
+
16
+ def label2onehot(labels, pad_value):
17
+
18
+ # input labels to one hot vector
19
+ inp_ = torch.unsqueeze(labels, 2)
20
+ one_hot = torch.FloatTensor(labels.size(0), labels.size(1), pad_value + 1).zero_().to(device)
21
+ one_hot.scatter_(2, inp_, 1)
22
+ one_hot, _ = one_hot.max(dim=1)
23
+ # remove pad position
24
+ one_hot = one_hot[:, :-1]
25
+ # eos position is always 0
26
+ one_hot[:, 0] = 0
27
+
28
+ return one_hot
29
+
30
+
31
+ def mask_from_eos(ids, eos_value, mult_before=True):
32
+ mask = torch.ones(ids.size()).to(device).byte()
33
+ mask_aux = torch.ones(ids.size(0)).to(device).byte()
34
+
35
+ # find eos in ingredient prediction
36
+ for idx in range(ids.size(1)):
37
+ # force mask to have 1s in the first position to avoid division by 0 when predictions start with eos
38
+ if idx == 0:
39
+ continue
40
+ if mult_before:
41
+ mask[:, idx] = mask[:, idx] * mask_aux
42
+ mask_aux = mask_aux * (ids[:, idx] != eos_value)
43
+ else:
44
+ mask_aux = mask_aux * (ids[:, idx] != eos_value)
45
+ mask[:, idx] = mask[:, idx] * mask_aux
46
+ return mask
47
+
48
+
49
+ def get_model(args, ingr_vocab_size, instrs_vocab_size):
50
+
51
+ # build ingredients embedding
52
+ encoder_ingrs = EncoderLabels(args.embed_size, ingr_vocab_size,
53
+ args.dropout_encoder, scale_grad=False).to(device)
54
+ # build image model
55
+ encoder_image = EncoderCNN(args.embed_size, args.dropout_encoder, args.image_model)
56
+
57
+ decoder = DecoderTransformer(args.embed_size, instrs_vocab_size,
58
+ dropout=args.dropout_decoder_r, seq_length=args.maxseqlen,
59
+ num_instrs=args.maxnuminstrs,
60
+ attention_nheads=args.n_att, num_layers=args.transf_layers,
61
+ normalize_before=True,
62
+ normalize_inputs=False,
63
+ last_ln=False,
64
+ scale_embed_grad=False)
65
+
66
+ ingr_decoder = DecoderTransformer(args.embed_size, ingr_vocab_size, dropout=args.dropout_decoder_i,
67
+ seq_length=args.maxnumlabels,
68
+ num_instrs=1, attention_nheads=args.n_att_ingrs,
69
+ pos_embeddings=False,
70
+ num_layers=args.transf_layers_ingrs,
71
+ learned=False,
72
+ normalize_before=True,
73
+ normalize_inputs=True,
74
+ last_ln=True,
75
+ scale_embed_grad=False)
76
+ # recipe loss
77
+ criterion = MaskedCrossEntropyCriterion(ignore_index=[instrs_vocab_size-1], reduce=False)
78
+
79
+ # ingredients loss
80
+ label_loss = nn.BCELoss(reduce=False)
81
+ eos_loss = nn.BCELoss(reduce=False)
82
+
83
+ model = InverseCookingModel(encoder_ingrs, decoder, ingr_decoder, encoder_image,
84
+ crit=criterion, crit_ingr=label_loss, crit_eos=eos_loss,
85
+ pad_value=ingr_vocab_size-1,
86
+ ingrs_only=args.ingrs_only, recipe_only=args.recipe_only,
87
+ label_smoothing=args.label_smoothing_ingr)
88
+
89
+ return model
90
+
91
+
92
+ class InverseCookingModel(nn.Module):
93
+ def __init__(self, ingredient_encoder, recipe_decoder, ingr_decoder, image_encoder,
94
+ crit=None, crit_ingr=None, crit_eos=None,
95
+ pad_value=0, ingrs_only=True,
96
+ recipe_only=False, label_smoothing=0.0):
97
+
98
+ super(InverseCookingModel, self).__init__()
99
+
100
+ self.ingredient_encoder = ingredient_encoder
101
+ self.recipe_decoder = recipe_decoder
102
+ self.image_encoder = image_encoder
103
+ self.ingredient_decoder = ingr_decoder
104
+ self.crit = crit
105
+ self.crit_ingr = crit_ingr
106
+ self.pad_value = pad_value
107
+ self.ingrs_only = ingrs_only
108
+ self.recipe_only = recipe_only
109
+ self.crit_eos = crit_eos
110
+ self.label_smoothing = label_smoothing
111
+
112
+ def forward(self, img_inputs, captions, target_ingrs,
113
+ sample=False, keep_cnn_gradients=False):
114
+
115
+ if sample:
116
+ return self.sample(img_inputs, greedy=True)
117
+
118
+ targets = captions[:, 1:]
119
+ targets = targets.contiguous().view(-1)
120
+
121
+ img_features = self.image_encoder(img_inputs, keep_cnn_gradients)
122
+
123
+ losses = {}
124
+ target_one_hot = label2onehot(target_ingrs, self.pad_value)
125
+ target_one_hot_smooth = label2onehot(target_ingrs, self.pad_value)
126
+
127
+ # ingredient prediction
128
+ if not self.recipe_only:
129
+ target_one_hot_smooth[target_one_hot_smooth == 1] = (1-self.label_smoothing)
130
+ target_one_hot_smooth[target_one_hot_smooth == 0] = self.label_smoothing / target_one_hot_smooth.size(-1)
131
+
132
+ # decode ingredients with transformer
133
+ # autoregressive mode for ingredient decoder
134
+ ingr_ids, ingr_logits = self.ingredient_decoder.sample(None, None, greedy=True,
135
+ temperature=1.0, img_features=img_features,
136
+ first_token_value=0, replacement=False)
137
+
138
+ ingr_logits = torch.nn.functional.softmax(ingr_logits, dim=-1)
139
+
140
+ # find idxs for eos ingredient
141
+ # eos probability is the one assigned to the first position of the softmax
142
+ eos = ingr_logits[:, :, 0]
143
+ target_eos = ((target_ingrs == 0) ^ (target_ingrs == self.pad_value))
144
+
145
+ eos_pos = (target_ingrs == 0)
146
+ eos_head = ((target_ingrs != self.pad_value) & (target_ingrs != 0))
147
+
148
+ # select transformer steps to pool from
149
+ mask_perminv = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
150
+ ingr_probs = ingr_logits * mask_perminv.float().unsqueeze(-1)
151
+
152
+ ingr_probs, _ = torch.max(ingr_probs, dim=1)
153
+
154
+ # ignore predicted ingredients after eos in ground truth
155
+ ingr_ids[mask_perminv == 0] = self.pad_value
156
+
157
+ ingr_loss = self.crit_ingr(ingr_probs, target_one_hot_smooth)
158
+ ingr_loss = torch.mean(ingr_loss, dim=-1)
159
+
160
+ losses['ingr_loss'] = ingr_loss
161
+
162
+ # cardinality penalty
163
+ losses['card_penalty'] = torch.abs((ingr_probs*target_one_hot).sum(1) - target_one_hot.sum(1)) + \
164
+ torch.abs((ingr_probs*(1-target_one_hot)).sum(1))
165
+
166
+ eos_loss = self.crit_eos(eos, target_eos.float())
167
+
168
+ mult = 1/2
169
+ # eos loss is only computed for timesteps <= t_eos and equally penalizes 0s and 1s
170
+ losses['eos_loss'] = mult*(eos_loss * eos_pos.float()).sum(1) / (eos_pos.float().sum(1) + 1e-6) + \
171
+ mult*(eos_loss * eos_head.float()).sum(1) / (eos_head.float().sum(1) + 1e-6)
172
+ # iou
173
+ pred_one_hot = label2onehot(ingr_ids, self.pad_value)
174
+ # iou sample during training is computed using the true eos position
175
+ losses['iou'] = softIoU(pred_one_hot, target_one_hot)
176
+
177
+ if self.ingrs_only:
178
+ return losses
179
+
180
+ # encode ingredients
181
+ target_ingr_feats = self.ingredient_encoder(target_ingrs)
182
+ target_ingr_mask = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
183
+
184
+ target_ingr_mask = target_ingr_mask.float().unsqueeze(1)
185
+
186
+ outputs, ids = self.recipe_decoder(target_ingr_feats, target_ingr_mask, captions, img_features)
187
+
188
+ outputs = outputs[:, :-1, :].contiguous()
189
+ outputs = outputs.view(outputs.size(0) * outputs.size(1), -1)
190
+
191
+ loss = self.crit(outputs, targets)
192
+
193
+ losses['recipe_loss'] = loss
194
+
195
+ return losses
196
+
197
+ def sample(self, img_inputs, greedy=True, temperature=1.0, beam=-1, true_ingrs=None):
198
+
199
+ outputs = dict()
200
+
201
+ img_features = self.image_encoder(img_inputs)
202
+
203
+ if not self.recipe_only:
204
+ ingr_ids, ingr_probs = self.ingredient_decoder.sample(None, None, greedy=True, temperature=temperature,
205
+ beam=-1,
206
+ img_features=img_features, first_token_value=0,
207
+ replacement=False)
208
+
209
+ # mask ingredients after finding eos
210
+ sample_mask = mask_from_eos(ingr_ids, eos_value=0, mult_before=False)
211
+ ingr_ids[sample_mask == 0] = self.pad_value
212
+
213
+ outputs['ingr_ids'] = ingr_ids
214
+ outputs['ingr_probs'] = ingr_probs.data
215
+
216
+ mask = sample_mask
217
+ input_mask = mask.float().unsqueeze(1)
218
+ input_feats = self.ingredient_encoder(ingr_ids)
219
+
220
+ if self.ingrs_only:
221
+ return outputs
222
+
223
+ # option during sampling to use the real ingredients and not the predicted ones to infer the recipe
224
+ if true_ingrs is not None:
225
+ input_mask = mask_from_eos(true_ingrs, eos_value=0, mult_before=False)
226
+ true_ingrs[input_mask == 0] = self.pad_value
227
+ input_feats = self.ingredient_encoder(true_ingrs)
228
+ input_mask = input_mask.unsqueeze(1)
229
+
230
+ ids, probs = self.recipe_decoder.sample(input_feats, input_mask, greedy, temperature, beam, img_features, 0,
231
+ last_token_value=1)
232
+
233
+ outputs['recipe_probs'] = probs.data
234
+ outputs['recipe_ids'] = ids
235
+
236
+ return outputs
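A minimal sketch (values chosen purely for illustration) of how `mask_from_eos` above behaves: it keeps positions up to the first eos token and zeroes everything after it, with `mult_before` deciding whether the eos position itself survives.

import torch

ids = torch.tensor([[3, 7, 0, 5, 2]])  # token 0 plays the role of eos, as in the calls above

print(mask_from_eos(ids, eos_value=0, mult_before=True))   # [[1, 1, 1, 0, 0]] - eos position kept
print(mask_from_eos(ids, eos_value=0, mult_before=False))  # [[1, 1, 0, 0, 0]] - eos position masked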
src/model1_inf.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import os
5
+ from src.args import get_parser
6
+ import pickle
7
+ from src.model import get_model
8
+ from torchvision import transforms
9
+ from src.utils.output_ing import prepare_output
10
+ from PIL import Image
11
+
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ map_loc = None if torch.cuda.is_available() else 'cpu'
14
+
15
+ def im2ingr(image, ingrs_vocab, model):
16
+ transf_list_batch = []
17
+ transf_list_batch.append(transforms.ToTensor())
18
+ transf_list_batch.append(transforms.Normalize((0.485, 0.456, 0.406),
19
+ (0.229, 0.224, 0.225)))
20
+ to_input_transf = transforms.Compose(transf_list_batch)
21
+
22
+ greedy = True
23
+ beam = -1
24
+ temperature = 1.0
25
+
26
+ transf_list = []
27
+ transf_list.append(transforms.Resize(256))
28
+ transf_list.append(transforms.CenterCrop(224))
29
+ transform = transforms.Compose(transf_list)
30
+
31
+ image_transf = transform(image)
32
+ image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)
33
+
34
+ with torch.no_grad():
35
+ outputs = model.sample(image_tensor, greedy=greedy,
36
+ temperature=temperature, beam=beam, true_ingrs=None)
37
+
38
+ ingr_ids = outputs['ingr_ids'].cpu().numpy()
39
+ outs = prepare_output(ingr_ids[0], ingrs_vocab)
40
+
41
+ return outs['ingrs']
42
+
43
+
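A hedged usage sketch for `im2ingr`, assuming the names defined in this file (get_parser, get_model, im2ingr, device) are in scope; the data paths and the image file name are assumptions mirroring src/demo.py.

import pickle
import torch
from PIL import Image

# Assumed paths; adjust to the real data directory.
ingrs_vocab = pickle.load(open('../data/ingr_vocab.pkl', 'rb'))
instr_vocab = pickle.load(open('../data/instr_vocab.pkl', 'rb'))

args = get_parser()
args.maxseqlen = 15
args.ingrs_only = True
model = get_model(args, len(ingrs_vocab), len(instr_vocab))
model.load_state_dict(torch.load('../data/modelbest.ckpt', map_location='cpu'))
model.to(device).eval()

image = Image.open('../data/demo_imgs/example.jpg').convert('RGB')  # placeholder file name
print(im2ingr(image, ingrs_vocab, model))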
src/modules/encoder.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ from torchvision.models import resnet18, resnet50, resnet101, resnet152, vgg16, vgg19, inception_v3
4
+ import torch
5
+ import torch.nn as nn
6
+ import random
7
+ import numpy as np
8
+
9
+
10
+ class EncoderCNN(nn.Module):
11
+ def __init__(self, embed_size, dropout=0.5, image_model='resnet101', pretrained=True):
12
+ """Load the pretrained ResNet-152 and replace top fc layer."""
13
+ super(EncoderCNN, self).__init__()
14
+ resnet = globals()[image_model](pretrained=pretrained)
15
+ modules = list(resnet.children())[:-2] # drop the avgpool and fc layers, keeping the conv feature map
16
+ self.resnet = nn.Sequential(*modules)
17
+
18
+ self.linear = nn.Sequential(nn.Conv2d(resnet.fc.in_features, embed_size, kernel_size=1, padding=0),
19
+ nn.Dropout2d(dropout))
20
+
21
+ def forward(self, images, keep_cnn_gradients=False):
22
+ """Extract feature vectors from input images."""
23
+
24
+ if keep_cnn_gradients:
25
+ raw_conv_feats = self.resnet(images)
26
+ else:
27
+ with torch.no_grad():
28
+ raw_conv_feats = self.resnet(images)
29
+ features = self.linear(raw_conv_feats)
30
+ features = features.view(features.size(0), features.size(1), -1)
31
+
32
+ return features
33
+
34
+
35
+ class EncoderLabels(nn.Module):
36
+ def __init__(self, embed_size, num_classes, dropout=0.5, embed_weights=None, scale_grad=False):
37
+
38
+ super(EncoderLabels, self).__init__()
39
+ embeddinglayer = nn.Embedding(num_classes, embed_size, padding_idx=num_classes-1, scale_grad_by_freq=scale_grad)
40
+ if embed_weights is not None:
41
+ embeddinglayer.weight.data.copy_(embed_weights)
42
+ self.pad_value = num_classes - 1
43
+ self.linear = embeddinglayer
44
+ self.dropout = dropout
45
+ self.embed_size = embed_size
46
+
47
+ def forward(self, x, onehot_flag=False):
48
+
49
+ if onehot_flag:
50
+ embeddings = torch.matmul(x, self.linear.weight)
51
+ else:
52
+ embeddings = self.linear(x)
53
+
54
+ embeddings = nn.functional.dropout(embeddings, p=self.dropout, training=self.training)
55
+ embeddings = embeddings.permute(0, 2, 1).contiguous()
56
+
57
+ return embeddings
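A small sanity-check sketch of the image encoder's output shape (shapes assume a 224x224 input and the default ResNet-101 backbone):

import torch

encoder = EncoderCNN(embed_size=512, dropout=0.3, image_model='resnet101', pretrained=False)
encoder.eval()
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 224, 224))
# Conv features are projected to embed_size and flattened spatially:
print(feats.shape)  # torch.Size([1, 512, 49]) for a 224x224 input (7x7 feature map)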
src/modules/multihead_attention.py ADDED
@@ -0,0 +1,203 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # https://github.com/pytorch/fairseq. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import Parameter
12
+ import torch.nn.functional as F
13
+
14
+ from src.modules.utils import fill_with_neg_inf, get_incremental_state, set_incremental_state
15
+
16
+
17
+ class MultiheadAttention(nn.Module):
18
+ """Multi-headed attention.
19
+ See "Attention Is All You Need" for more details.
20
+ """
21
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
22
+ super().__init__()
23
+ self.embed_dim = embed_dim
24
+ self.num_heads = num_heads
25
+ self.dropout = dropout
26
+ self.head_dim = embed_dim // num_heads
27
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
28
+ self.scaling = self.head_dim**-0.5
29
+ self._mask = None
30
+
31
+ self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
32
+ if bias:
33
+ self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
34
+ else:
35
+ self.register_parameter('in_proj_bias', None)
36
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
37
+
38
+ self.reset_parameters()
39
+
40
+ def reset_parameters(self):
41
+ nn.init.xavier_uniform_(self.in_proj_weight)
42
+ nn.init.xavier_uniform_(self.out_proj.weight)
43
+ if self.in_proj_bias is not None:
44
+ nn.init.constant_(self.in_proj_bias, 0.)
45
+ nn.init.constant_(self.out_proj.bias, 0.)
46
+
47
+ def forward(self, query, key, value, mask_future_timesteps=False,
48
+ key_padding_mask=None, incremental_state=None,
49
+ need_weights=True, static_kv=False):
50
+ """Input shape: Time x Batch x Channel
51
+ Self-attention can be implemented by passing in the same arguments for
52
+ query, key and value. Future timesteps can be masked with the
53
+ `mask_future_timesteps` argument. Padding elements can be excluded from
54
+ the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
55
+ batch x src_len, where padding elements are indicated by 1s.
56
+ """
57
+
58
+ qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
59
+ kv_same = key.data_ptr() == value.data_ptr()
60
+
61
+ tgt_len, bsz, embed_dim = query.size()
62
+ assert embed_dim == self.embed_dim
63
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
64
+ assert key.size() == value.size()
65
+
66
+ if incremental_state is not None:
67
+ saved_state = self._get_input_buffer(incremental_state)
68
+ if 'prev_key' in saved_state:
69
+ # previous time steps are cached - no need to recompute
70
+ # key and value if they are static
71
+ if static_kv:
72
+ assert kv_same and not qkv_same
73
+ key = value = None
74
+ else:
75
+ saved_state = None
76
+
77
+ if qkv_same:
78
+ # self-attention
79
+ q, k, v = self.in_proj_qkv(query)
80
+ elif kv_same:
81
+ # encoder-decoder attention
82
+ q = self.in_proj_q(query)
83
+ if key is None:
84
+ assert value is None
85
+ # setting k and v to empty here allows concatenating with the cached
86
+ # previous key/value, so that only the previous value is actually used
87
+ k = v = q.new(0)
88
+ else:
89
+ k, v = self.in_proj_kv(key)
90
+ else:
91
+ q = self.in_proj_q(query)
92
+ k = self.in_proj_k(key)
93
+ v = self.in_proj_v(value)
94
+ q *= self.scaling
95
+
96
+ if saved_state is not None:
97
+ if 'prev_key' in saved_state:
98
+ k = torch.cat((saved_state['prev_key'], k), dim=0)
99
+ if 'prev_value' in saved_state:
100
+ v = torch.cat((saved_state['prev_value'], v), dim=0)
101
+ saved_state['prev_key'] = k
102
+ saved_state['prev_value'] = v
103
+ self._set_input_buffer(incremental_state, saved_state)
104
+
105
+ src_len = k.size(0)
106
+
107
+ if key_padding_mask is not None:
108
+ assert key_padding_mask.size(0) == bsz
109
+ assert key_padding_mask.size(1) == src_len
110
+
111
+ q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
112
+ k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
113
+ v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
114
+
115
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
116
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
117
+
118
+ # only apply masking at training time (when incremental state is None)
119
+ if mask_future_timesteps and incremental_state is None:
120
+ assert query.size() == key.size(), \
121
+ 'mask_future_timesteps only applies to self-attention'
122
+ attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
123
+ if key_padding_mask is not None:
124
+ # don't attend to padding symbols
125
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
126
+ attn_weights = attn_weights.float().masked_fill(
127
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
128
+ float('-inf'),
129
+ ).type_as(attn_weights) # FP16 support: cast to float and back
130
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
131
+
132
+ attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
133
+ attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
134
+
135
+ attn = torch.bmm(attn_weights, v)
136
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
137
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
138
+ attn = self.out_proj(attn)
139
+
140
+ # average attention weights over heads
141
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
142
+ attn_weights = attn_weights.sum(dim=1) / self.num_heads
143
+
144
+ return attn, attn_weights
145
+
146
+ def in_proj_qkv(self, query):
147
+ return self._in_proj(query).chunk(3, dim=-1)
148
+
149
+ def in_proj_kv(self, key):
150
+ return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
151
+
152
+ def in_proj_q(self, query):
153
+ return self._in_proj(query, end=self.embed_dim)
154
+
155
+ def in_proj_k(self, key):
156
+ return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)
157
+
158
+ def in_proj_v(self, value):
159
+ return self._in_proj(value, start=2*self.embed_dim)
160
+
161
+ def _in_proj(self, input, start=None, end=None):
162
+ weight = self.in_proj_weight
163
+ bias = self.in_proj_bias
164
+ if end is not None:
165
+ weight = weight[:end, :]
166
+ if bias is not None:
167
+ bias = bias[:end]
168
+ if start is not None:
169
+ weight = weight[start:, :]
170
+ if bias is not None:
171
+ bias = bias[start:]
172
+ return F.linear(input, weight, bias)
173
+
174
+ def buffered_mask(self, tensor):
175
+ dim = tensor.size(-1)
176
+ if self._mask is None:
177
+ self._mask = torch.triu(fill_with_neg_inf(tensor.new(dim, dim)), 1)
178
+ if self._mask.size(0) < dim:
179
+ self._mask = torch.triu(fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
180
+ return self._mask[:dim, :dim]
181
+
182
+ def reorder_incremental_state(self, incremental_state, new_order):
183
+ """Reorder buffered internal state (for incremental generation)."""
184
+ input_buffer = self._get_input_buffer(incremental_state)
185
+ if input_buffer is not None:
186
+ for k in input_buffer.keys():
187
+ input_buffer[k] = input_buffer[k].index_select(1, new_order)
188
+ self._set_input_buffer(incremental_state, input_buffer)
189
+
190
+ def _get_input_buffer(self, incremental_state):
191
+ return get_incremental_state(
192
+ self,
193
+ incremental_state,
194
+ 'attn_state',
195
+ ) or {}
196
+
197
+ def _set_input_buffer(self, incremental_state, buffer):
198
+ set_incremental_state(
199
+ self,
200
+ incremental_state,
201
+ 'attn_state',
202
+ buffer,
203
+ )
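A minimal self-attention sketch following the Time x Batch x Channel convention documented in `forward` (all sizes illustrative):

import torch

attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)
x = torch.randn(10, 2, 512)  # (time, batch, channel)
out, weights = attn(x, x, x, mask_future_timesteps=True)
print(out.shape, weights.shape)  # torch.Size([10, 2, 512]) torch.Size([2, 10, 10])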
src/modules/transformer_decoder.py ADDED
@@ -0,0 +1,502 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ # Code adapted from https://github.com/pytorch/fairseq
4
+ # Copyright (c) 2017-present, Facebook, Inc.
5
+ # All rights reserved.
6
+ #
7
+ # This source code is licensed under the license found in the LICENSE file in
8
+ # https://github.com/pytorch/fairseq. An additional grant of patent rights
9
+ # can be found in the PATENTS file in the same directory.
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.nn.modules.utils import _single
16
+ import src.modules.utils as utils
17
+ from src.modules.multihead_attention import MultiheadAttention
18
+ import numpy as np
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ import copy
21
+
22
+
23
+ def make_positions(tensor, padding_idx, left_pad):
24
+ """Replace non-padding symbols with their position numbers.
25
+ Position numbers begin at padding_idx+1.
26
+ Padding symbols are ignored, but it is necessary to specify whether padding
27
+ is added on the left side (left_pad=True) or right side (left_pad=False).
28
+ """
29
+
30
+ # creates tensor from scratch - to avoid multigpu issues
31
+ max_pos = padding_idx + 1 + tensor.size(1)
32
+ #if not hasattr(make_positions, 'range_buf'):
33
+ range_buf = tensor.new()
34
+ #make_positions.range_buf = make_positions.range_buf.type_as(tensor)
35
+ if range_buf.numel() < max_pos:
36
+ torch.arange(padding_idx + 1, max_pos, out=range_buf)
37
+ mask = tensor.ne(padding_idx)
38
+ positions = range_buf[:tensor.size(1)].expand_as(tensor)
39
+ if left_pad:
40
+ positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
41
+
42
+ out = tensor.clone()
43
+ out = out.masked_scatter_(mask, positions[mask])
44
+ return out
45
+
46
+
47
+ class LearnedPositionalEmbedding(nn.Embedding):
48
+ """This module learns positional embeddings up to a fixed maximum size.
49
+ Padding symbols are ignored, but it is necessary to specify whether padding
50
+ is added on the left side (left_pad=True) or right side (left_pad=False).
51
+ """
52
+
53
+ def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
54
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
55
+ self.left_pad = left_pad
56
+ nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
57
+
58
+ def forward(self, input, incremental_state=None):
59
+ """Input is expected to be of size [bsz x seqlen]."""
60
+ if incremental_state is not None:
61
+ # positions is the same for every token when decoding a single step
62
+
63
+ positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
64
+ else:
65
+
66
+ positions = make_positions(input.data, self.padding_idx, self.left_pad)
67
+ return super().forward(positions)
68
+
69
+ def max_positions(self):
70
+ """Maximum number of supported positions."""
71
+ return self.num_embeddings - self.padding_idx - 1
72
+
73
+ class SinusoidalPositionalEmbedding(nn.Module):
74
+ """This module produces sinusoidal positional embeddings of any length.
75
+ Padding symbols are ignored, but it is necessary to specify whether padding
76
+ is added on the left side (left_pad=True) or right side (left_pad=False).
77
+ """
78
+
79
+ def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
80
+ super().__init__()
81
+ self.embedding_dim = embedding_dim
82
+ self.padding_idx = padding_idx
83
+ self.left_pad = left_pad
84
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
85
+ init_size,
86
+ embedding_dim,
87
+ padding_idx,
88
+ )
89
+ self.register_buffer('_float_tensor', torch.FloatTensor())
90
+
91
+ @staticmethod
92
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
93
+ """Build sinusoidal embeddings.
94
+ This matches the implementation in tensor2tensor, but differs slightly
95
+ from the description in Section 3.5 of "Attention Is All You Need".
96
+ """
97
+ half_dim = embedding_dim // 2
98
+ emb = math.log(10000) / (half_dim - 1)
99
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
100
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
101
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
102
+ if embedding_dim % 2 == 1:
103
+ # zero pad
104
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
105
+ if padding_idx is not None:
106
+ emb[padding_idx, :] = 0
107
+ return emb
108
+
109
+ def forward(self, input, incremental_state=None):
110
+ """Input is expected to be of size [bsz x seqlen]."""
111
+ # recompute/expand embeddings if needed
112
+ bsz, seq_len = input.size()
113
+ max_pos = self.padding_idx + 1 + seq_len
114
+ if self.weights is None or max_pos > self.weights.size(0):
115
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
116
+ max_pos,
117
+ self.embedding_dim,
118
+ self.padding_idx,
119
+ )
120
+ self.weights = self.weights.type_as(self._float_tensor)
121
+
122
+ if incremental_state is not None:
123
+ # positions is the same for every token when decoding a single step
124
+ return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
125
+
126
+ positions = make_positions(input.data, self.padding_idx, self.left_pad)
127
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
128
+
129
+ def max_positions(self):
130
+ """Maximum number of supported positions."""
131
+ return int(1e5) # an arbitrary large number
132
+
133
+ class TransformerDecoderLayer(nn.Module):
134
+ """Decoder layer block."""
135
+
136
+ def __init__(self, embed_dim, n_att, dropout=0.5, normalize_before=True, last_ln=False):
137
+ super().__init__()
138
+
139
+ self.embed_dim = embed_dim
140
+ self.dropout = dropout
141
+ self.relu_dropout = dropout
142
+ self.normalize_before = normalize_before
143
+ num_layer_norm = 3
144
+
145
+ # self-attention on generated recipe
146
+ self.self_attn = MultiheadAttention(
147
+ self.embed_dim, n_att,
148
+ dropout=dropout,
149
+ )
150
+
151
+ self.cond_att = MultiheadAttention(
152
+ self.embed_dim, n_att,
153
+ dropout=dropout,
154
+ )
155
+
156
+ self.fc1 = Linear(self.embed_dim, self.embed_dim)
157
+ self.fc2 = Linear(self.embed_dim, self.embed_dim)
158
+ self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(num_layer_norm)])
159
+ self.use_last_ln = last_ln
160
+ if self.use_last_ln:
161
+ self.last_ln = LayerNorm(self.embed_dim)
162
+
163
+ def forward(self, x, ingr_features, ingr_mask, incremental_state, img_features):
164
+
165
+ # self attention
166
+ residual = x
167
+ x = self.maybe_layer_norm(0, x, before=True)
168
+ x, _ = self.self_attn(
169
+ query=x,
170
+ key=x,
171
+ value=x,
172
+ mask_future_timesteps=True,
173
+ incremental_state=incremental_state,
174
+ need_weights=False,
175
+ )
176
+ x = F.dropout(x, p=self.dropout, training=self.training)
177
+ x = residual + x
178
+ x = self.maybe_layer_norm(0, x, after=True)
179
+
180
+ residual = x
181
+ x = self.maybe_layer_norm(1, x, before=True)
182
+
183
+ # attention
184
+ if ingr_features is None:
185
+
186
+ x, _ = self.cond_att(query=x,
187
+ key=img_features,
188
+ value=img_features,
189
+ key_padding_mask=None,
190
+ incremental_state=incremental_state,
191
+ static_kv=True,
192
+ )
193
+ elif img_features is None:
194
+ x, _ = self.cond_att(query=x,
195
+ key=ingr_features,
196
+ value=ingr_features,
197
+ key_padding_mask=ingr_mask,
198
+ incremental_state=incremental_state,
199
+ static_kv=True,
200
+ )
201
+
202
+
203
+ else:
204
+ # attention on concatenation of encoder_out and encoder_aux, query self attn (x)
205
+ kv = torch.cat((img_features, ingr_features), 0)
206
+ mask = torch.cat((torch.zeros(img_features.shape[1], img_features.shape[0], dtype=torch.uint8).to(device),
207
+ ingr_mask), 1)
208
+ x, _ = self.cond_att(query=x,
209
+ key=kv,
210
+ value=kv,
211
+ key_padding_mask=mask,
212
+ incremental_state=incremental_state,
213
+ static_kv=True,
214
+ )
215
+ x = F.dropout(x, p=self.dropout, training=self.training)
216
+ x = residual + x
217
+ x = self.maybe_layer_norm(1, x, after=True)
218
+
219
+ residual = x
220
+ x = self.maybe_layer_norm(-1, x, before=True)
221
+ x = F.relu(self.fc1(x))
222
+ x = F.dropout(x, p=self.relu_dropout, training=self.training)
223
+ x = self.fc2(x)
224
+ x = F.dropout(x, p=self.dropout, training=self.training)
225
+ x = residual + x
226
+ x = self.maybe_layer_norm(-1, x, after=True)
227
+
228
+ if self.use_last_ln:
229
+ x = self.last_ln(x)
230
+
231
+ return x
232
+
233
+ def maybe_layer_norm(self, i, x, before=False, after=False):
234
+ assert before ^ after
235
+ if after ^ self.normalize_before:
236
+ return self.layer_norms[i](x)
237
+ else:
238
+ return x
239
+
240
+ class DecoderTransformer(nn.Module):
241
+ """Transformer decoder."""
242
+
243
+ def __init__(self, embed_size, vocab_size, dropout=0.5, seq_length=20, num_instrs=15,
244
+ attention_nheads=16, pos_embeddings=True, num_layers=8, learned=True, normalize_before=True,
245
+ normalize_inputs=False, last_ln=False, scale_embed_grad=False):
246
+ super(DecoderTransformer, self).__init__()
247
+ self.dropout = dropout
248
+ self.seq_length = seq_length * num_instrs
249
+ self.embed_tokens = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size-1,
250
+ scale_grad_by_freq=scale_embed_grad)
251
+ nn.init.normal_(self.embed_tokens.weight, mean=0, std=embed_size ** -0.5)
252
+ if pos_embeddings:
253
+ self.embed_positions = PositionalEmbedding(1024, embed_size, 0, left_pad=False, learned=learned)
254
+ else:
255
+ self.embed_positions = None
256
+ self.normalize_inputs = normalize_inputs
257
+ if self.normalize_inputs:
258
+ self.layer_norms_in = nn.ModuleList([LayerNorm(embed_size) for i in range(3)])
259
+
260
+ self.embed_scale = math.sqrt(embed_size)
261
+ self.layers = nn.ModuleList([])
262
+ self.layers.extend([
263
+ TransformerDecoderLayer(embed_size, attention_nheads, dropout=dropout, normalize_before=normalize_before,
264
+ last_ln=last_ln)
265
+ for i in range(num_layers)
266
+ ])
267
+
268
+ self.linear = Linear(embed_size, vocab_size-1)
269
+
270
+ def forward(self, ingr_features, ingr_mask, captions, img_features, incremental_state=None):
271
+
272
+ if ingr_features is not None:
273
+ ingr_features = ingr_features.permute(0, 2, 1)
274
+ ingr_features = ingr_features.transpose(0, 1)
275
+ if self.normalize_inputs:
276
+ self.layer_norms_in[0](ingr_features)
277
+
278
+ if img_features is not None:
279
+ img_features = img_features.permute(0, 2, 1)
280
+ img_features = img_features.transpose(0, 1)
281
+ if self.normalize_inputs:
282
+ self.layer_norms_in[1](img_features)
283
+
284
+ if ingr_mask is not None:
285
+ ingr_mask = (1-ingr_mask.squeeze(1)).byte()
286
+
287
+ # embed positions
288
+ if self.embed_positions is not None:
289
+ positions = self.embed_positions(captions, incremental_state=incremental_state)
290
+ if incremental_state is not None:
291
+ if self.embed_positions is not None:
292
+ positions = positions[:, -1:]
293
+ captions = captions[:, -1:]
294
+
295
+ # embed tokens and positions
296
+ x = self.embed_scale * self.embed_tokens(captions)
297
+
298
+ if self.embed_positions is not None:
299
+ x += positions
300
+
301
+ if self.normalize_inputs:
302
+ x = self.layer_norms_in[2](x)
303
+
304
+ x = F.dropout(x, p=self.dropout, training=self.training)
305
+
306
+ # B x T x C -> T x B x C
307
+ x = x.transpose(0, 1)
308
+
309
+ for p, layer in enumerate(self.layers):
310
+ x = layer(
311
+ x,
312
+ ingr_features,
313
+ ingr_mask,
314
+ incremental_state,
315
+ img_features
316
+ )
317
+
318
+ # T x B x C -> B x T x C
319
+ x = x.transpose(0, 1)
320
+
321
+ x = self.linear(x)
322
+ _, predicted = x.max(dim=-1)
323
+
324
+ return x, predicted
325
+
326
+ def sample(self, ingr_features, ingr_mask, greedy=True, temperature=1.0, beam=-1,
327
+ img_features=None, first_token_value=0,
328
+ replacement=True, last_token_value=0):
329
+
330
+ incremental_state = {}
331
+
332
+ # create dummy previous word
333
+ if ingr_features is not None:
334
+ fs = ingr_features.size(0)
335
+ else:
336
+ fs = img_features.size(0)
337
+
338
+ if beam != -1:
339
+ if fs == 1:
340
+ return self.sample_beam(ingr_features, ingr_mask, beam, img_features, first_token_value,
341
+ replacement, last_token_value)
342
+ else:
343
+ print ("Beam Search can only be used with batch size of 1. Running greedy or temperature sampling...")
344
+
345
+ first_word = torch.ones(fs)*first_token_value
346
+
347
+ first_word = first_word.to(device).long()
348
+ sampled_ids = [first_word]
349
+ logits = []
350
+
351
+ for i in range(self.seq_length):
352
+ # forward
353
+ outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(sampled_ids, 1),
354
+ img_features, incremental_state)
355
+ outputs = outputs.squeeze(1)
356
+ if not replacement:
357
+ # predicted mask
358
+ if i == 0:
359
+ predicted_mask = torch.zeros(outputs.shape).float().to(device)
360
+ else:
361
+ # ensure no repetitions in sampling if replacement==False
362
+ batch_ind = [j for j in range(fs) if sampled_ids[i][j] != 0]
363
+ sampled_ids_new = sampled_ids[i][batch_ind]
364
+ predicted_mask[batch_ind, sampled_ids_new] = float('-inf')
365
+
366
+ # mask previously selected ids
367
+ outputs += predicted_mask
368
+
369
+ logits.append(outputs)
370
+ if greedy:
371
+ outputs_prob = torch.nn.functional.softmax(outputs, dim=-1)
372
+ _, predicted = outputs_prob.max(1)
373
+ predicted = predicted.detach()
374
+ else:
375
+ k = 10
376
+ outputs_prob = torch.div(outputs.squeeze(1), temperature)
377
+ outputs_prob = torch.nn.functional.softmax(outputs_prob, dim=-1).data
378
+
379
+ # top k random sampling
380
+ prob_prev_topk, indices = torch.topk(outputs_prob, k=k, dim=1)
381
+ predicted = torch.multinomial(prob_prev_topk, 1).view(-1)
382
+ predicted = torch.index_select(indices, dim=1, index=predicted)[:, 0].detach()
383
+
384
+ sampled_ids.append(predicted)
385
+
386
+ sampled_ids = torch.stack(sampled_ids[1:], 1)
387
+ logits = torch.stack(logits, 1)
388
+
389
+ return sampled_ids, logits
390
+
391
+ def sample_beam(self, ingr_features, ingr_mask, beam=3, img_features=None, first_token_value=0,
392
+ replacement=True, last_token_value=0):
393
+ k = beam
394
+ alpha = 0.0
395
+ # create dummy previous word
396
+ if ingr_features is not None:
397
+ fs = ingr_features.size(0)
398
+ else:
399
+ fs = img_features.size(0)
400
+ first_word = torch.ones(fs)*first_token_value
401
+
402
+ first_word = first_word.to(device).long()
403
+
404
+ sequences = [[[first_word], 0, {}, False, 1]]
405
+ finished = []
406
+
407
+ for i in range(self.seq_length):
408
+ # forward
409
+ all_candidates = []
410
+ for rem in range(len(sequences)):
411
+ incremental = sequences[rem][2]
412
+ outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(sequences[rem][0], 1),
413
+ img_features, incremental)
414
+ outputs = outputs.squeeze(1)
415
+ if not replacement:
416
+ # predicted mask
417
+ if i == 0:
418
+ predicted_mask = torch.zeros(outputs.shape).float().to(device)
419
+ else:
420
+ # ensure no repetitions in sampling if replacement==False
421
+ batch_ind = [j for j in range(fs) if sequences[rem][0][i][j] != 0]
422
+ sampled_ids_new = sequences[rem][0][i][batch_ind]
423
+ predicted_mask[batch_ind, sampled_ids_new] = float('-inf')
424
+
425
+ # mask previously selected ids
426
+ outputs += predicted_mask
427
+
428
+ outputs_prob = torch.nn.functional.log_softmax(outputs, dim=-1)
429
+ probs, indices = torch.topk(outputs_prob, beam)
430
+ # tokens is [batch x beam ] and every element is a list
431
+ # score is [ batch x beam ] and every element is a scalar
432
+ # incremental is [batch x beam ] and every element is a dict
433
+
434
+
435
+ for bid in range(beam):
436
+ tokens = sequences[rem][0] + [indices[:, bid]]
437
+ score = sequences[rem][1] + probs[:, bid].squeeze().item()
438
+ if indices[:,bid].item() == last_token_value:
439
+ finished.append([tokens, score, None, True, sequences[rem][-1] + 1])
440
+ else:
441
+ all_candidates.append([tokens, score, incremental, False, sequences[rem][-1] + 1])
442
+
443
+ # if all the top-k scoring beams have finished, we can return them
444
+ ordered_all = sorted(all_candidates + finished, key=lambda tup: tup[1]/(np.power(tup[-1],alpha)),
445
+ reverse=True)[:k]
446
+ if all(el[3] for el in ordered_all): # el[3] is the finished flag of each beam
447
+ all_candidates = []
448
+
449
+ # order all candidates by score
450
+ ordered = sorted(all_candidates, key=lambda tup: tup[1]/(np.power(tup[-1],alpha)), reverse=True)
451
+ # select k best
452
+ sequences = ordered[:k]
453
+ finished = sorted(finished, key=lambda tup: tup[1]/(np.power(tup[-1],alpha)), reverse=True)[:k]
454
+
455
+ if len(finished) != 0:
456
+ sampled_ids = torch.stack(finished[0][0][1:], 1)
457
+ logits = finished[0][1]
458
+ else:
459
+ sampled_ids = torch.stack(sequences[0][0][1:], 1)
460
+ logits = sequences[0][1]
461
+ return sampled_ids, logits
462
+
463
+ def max_positions(self):
464
+ """Maximum output length supported by the decoder."""
465
+ return self.embed_positions.max_positions()
466
+
467
+ def upgrade_state_dict(self, state_dict):
468
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
469
+ if 'decoder.embed_positions.weights' in state_dict:
470
+ del state_dict['decoder.embed_positions.weights']
471
+ if 'decoder.embed_positions._float_tensor' not in state_dict:
472
+ state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
473
+ return state_dict
474
+
475
+
476
+
477
+ def Embedding(num_embeddings, embedding_dim, padding_idx):
478
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
479
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
480
+ return m
481
+
482
+
483
+ def LayerNorm(embedding_dim):
484
+ m = nn.LayerNorm(embedding_dim)
485
+ return m
486
+
487
+
488
+ def Linear(in_features, out_features, bias=True):
489
+ m = nn.Linear(in_features, out_features, bias)
490
+ nn.init.xavier_uniform_(m.weight)
491
+ nn.init.constant_(m.bias, 0.)
492
+ return m
493
+
494
+
495
+ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
496
+ if learned:
497
+ m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
498
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
499
+ nn.init.constant_(m.weight[padding_idx], 0)
500
+ else:
501
+ m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)
502
+ return m
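A worked example (illustrative values only) for `make_positions` defined at the top of this file: non-padding tokens get position numbers starting at padding_idx + 1, while padding positions keep the padding index.

import torch

tokens = torch.tensor([[4, 5, 1, 1]])  # here 1 is the padding symbol, sequence is right-padded
print(make_positions(tokens, padding_idx=1, left_pad=False))  # tensor([[2, 3, 1, 1]])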
src/modules/utils.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ # Code adapted from https://github.com/pytorch/fairseq
4
+ # Copyright (c) 2017-present, Facebook, Inc.
5
+ # All rights reserved.
6
+ #
7
+ # This source code is licensed under the license found in the LICENSE file in
8
+ # https://github.com/pytorch/fairseq. An additional grant of patent rights
9
+ # can be found in the PATENTS file in the same directory.
10
+
11
+ from collections import defaultdict, OrderedDict
12
+ import logging
13
+ import os
14
+ import re
15
+ import torch
16
+ import traceback
17
+
18
+ from torch.serialization import default_restore_location
19
+
20
+
21
+ def torch_persistent_save(*args, **kwargs):
22
+ for i in range(3):
23
+ try:
24
+ return torch.save(*args, **kwargs)
25
+ except Exception:
26
+ if i == 2:
27
+ logging.error(traceback.format_exc())
28
+
29
+
30
+ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
31
+ if isinstance(state_dict, dict):
32
+ cpu_dict = OrderedDict()
33
+ for k, v in state_dict.items():
34
+ cpu_dict[k] = convert_state_dict_type(v)
35
+ return cpu_dict
36
+ elif isinstance(state_dict, list):
37
+ return [convert_state_dict_type(v) for v in state_dict]
38
+ elif torch.is_tensor(state_dict):
39
+ return state_dict.type(ttype)
40
+ else:
41
+ return state_dict
42
+
43
+
44
+ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
45
+ num_updates, optim_history=None, extra_state=None):
46
+ if optim_history is None:
47
+ optim_history = []
48
+ if extra_state is None:
49
+ extra_state = {}
50
+ state_dict = {
51
+ 'args': args,
52
+ 'model': convert_state_dict_type(model.state_dict()),
53
+ 'optimizer_history': optim_history + [
54
+ {
55
+ 'criterion_name': criterion.__class__.__name__,
56
+ 'optimizer_name': optimizer.__class__.__name__,
57
+ 'lr_scheduler_state': lr_scheduler.state_dict(),
58
+ 'num_updates': num_updates,
59
+ }
60
+ ],
61
+ 'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
62
+ 'extra_state': extra_state,
63
+ }
64
+ torch_persistent_save(state_dict, filename)
65
+
66
+
67
+ def load_model_state(filename, model):
68
+ if not os.path.exists(filename):
69
+ return None, [], None
70
+ state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
71
+ state = _upgrade_state_dict(state)
72
+ model.upgrade_state_dict(state['model'])
73
+
74
+ # load model parameters
75
+ try:
76
+ model.load_state_dict(state['model'], strict=True)
77
+ except Exception:
78
+ raise Exception('Cannot load model parameters from checkpoint, '
79
+ 'please ensure that the architectures match')
80
+
81
+ return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
82
+
83
+
84
+ def _upgrade_state_dict(state):
85
+ """Helper for upgrading old model checkpoints."""
86
+ # add optimizer_history
87
+ if 'optimizer_history' not in state:
88
+ state['optimizer_history'] = [
89
+ {
90
+ 'criterion_name': 'CrossEntropyCriterion',
91
+ 'best_loss': state['best_loss'],
92
+ },
93
+ ]
94
+ state['last_optimizer_state'] = state['optimizer']
95
+ del state['optimizer']
96
+ del state['best_loss']
97
+ # move extra_state into sub-dictionary
98
+ if 'epoch' in state and 'extra_state' not in state:
99
+ state['extra_state'] = {
100
+ 'epoch': state['epoch'],
101
+ 'batch_offset': state['batch_offset'],
102
+ 'val_loss': state['val_loss'],
103
+ }
104
+ del state['epoch']
105
+ del state['batch_offset']
106
+ del state['val_loss']
107
+ # reduce optimizer history's memory usage (only keep the last state)
108
+ if 'optimizer' in state['optimizer_history'][-1]:
109
+ state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
110
+ for optim_hist in state['optimizer_history']:
111
+ del optim_hist['optimizer']
112
+ # record the optimizer class name
113
+ if 'optimizer_name' not in state['optimizer_history'][-1]:
114
+ state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
115
+ # move best_loss into lr_scheduler_state
116
+ if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
117
+ state['optimizer_history'][-1]['lr_scheduler_state'] = {
118
+ 'best': state['optimizer_history'][-1]['best_loss'],
119
+ }
120
+ del state['optimizer_history'][-1]['best_loss']
121
+ # keep track of number of updates
122
+ if 'num_updates' not in state['optimizer_history'][-1]:
123
+ state['optimizer_history'][-1]['num_updates'] = 0
124
+ # old model checkpoints may not have separate source/target positions
125
+ if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
126
+ state['args'].max_source_positions = state['args'].max_positions
127
+ state['args'].max_target_positions = state['args'].max_positions
128
+ # use stateful training data iterator
129
+ if 'train_iterator' not in state['extra_state']:
130
+ state['extra_state']['train_iterator'] = {
131
+ 'epoch': state['extra_state']['epoch'],
132
+ 'iterations_in_epoch': 0,
133
+ }
134
+ return state
135
+
136
+
137
+ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
138
+ """Load an ensemble of models for inference.
139
+ model_arg_overrides allows you to pass a dictionary model_arg_overrides --
140
+ {'arg_name': arg} -- to override model args that were used during model
141
+ training
142
+ """
143
+ # load model architectures and weights
144
+ states = []
145
+ for filename in filenames:
146
+ if not os.path.exists(filename):
147
+ raise IOError('Model file not found: {}'.format(filename))
148
+ state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
149
+ state = _upgrade_state_dict(state)
150
+ states.append(state)
151
+ args = states[0]['args']
152
+ if model_arg_overrides is not None:
153
+ args = _override_model_args(args, model_arg_overrides)
154
+
155
+ # build ensemble
156
+ ensemble = []
157
+ for state in states:
158
+ model = task.build_model(args)
159
+ model.upgrade_state_dict(state['model'])
160
+ model.load_state_dict(state['model'], strict=True)
161
+ ensemble.append(model)
162
+ return ensemble, args
163
+
164
+
165
+ def _override_model_args(args, model_arg_overrides):
166
+ # Uses model_arg_overrides {'arg_name': arg} to override model args
167
+ for arg_name, arg_val in model_arg_overrides.items():
168
+ setattr(args, arg_name, arg_val)
169
+ return args
170
+
171
+
172
+ def move_to_cuda(sample):
173
+ if len(sample) == 0:
174
+ return {}
175
+
176
+ def _move_to_cuda(maybe_tensor):
177
+ if torch.is_tensor(maybe_tensor):
178
+ return maybe_tensor.cuda()
179
+ elif isinstance(maybe_tensor, dict):
180
+ return {
181
+ key: _move_to_cuda(value)
182
+ for key, value in maybe_tensor.items()
183
+ }
184
+ elif isinstance(maybe_tensor, list):
185
+ return [_move_to_cuda(x) for x in maybe_tensor]
186
+ else:
187
+ return maybe_tensor
188
+
189
+ return _move_to_cuda(sample)
190
+
191
+
192
+ INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
193
+
194
+
195
+ def _get_full_incremental_state_key(module_instance, key):
196
+ module_name = module_instance.__class__.__name__
197
+
198
+ # assign a unique ID to each module instance, so that incremental state is
199
+ # not shared across module instances
200
+ if not hasattr(module_instance, '_fairseq_instance_id'):
201
+ INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
202
+ module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
203
+
204
+ return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)
205
+
206
+
207
+ def get_incremental_state(module, incremental_state, key):
208
+ """Helper for getting incremental state for an nn.Module."""
209
+ full_key = _get_full_incremental_state_key(module, key)
210
+ if incremental_state is None or full_key not in incremental_state:
211
+ return None
212
+ return incremental_state[full_key]
213
+
214
+
215
+ def set_incremental_state(module, incremental_state, key, value):
216
+ """Helper for setting incremental state for an nn.Module."""
217
+ if incremental_state is not None:
218
+ full_key = _get_full_incremental_state_key(module, key)
219
+ incremental_state[full_key] = value
220
+
221
+
222
+ def load_align_dict(replace_unk):
223
+ if replace_unk is None:
224
+ align_dict = None
225
+ elif isinstance(replace_unk, str):
226
+ # Load alignment dictionary for unknown word replacement if it was passed as an argument.
227
+ align_dict = {}
228
+ with open(replace_unk, 'r') as f:
229
+ for line in f:
230
+ cols = line.split()
231
+ align_dict[cols[0]] = cols[1]
232
+ else:
233
+ # No alignment dictionary provided but we still want to perform unknown word replacement by copying the
234
+ # original source word.
235
+ align_dict = {}
236
+ return align_dict
237
+
238
+
239
+ def print_embed_overlap(embed_dict, vocab_dict):
240
+ embed_keys = set(embed_dict.keys())
241
+ vocab_keys = set(vocab_dict.symbols)
242
+ overlap = len(embed_keys & vocab_keys)
243
+ print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))
244
+
245
+
246
+ def parse_embedding(embed_path):
247
+ """Parse embedding text file into a dictionary of word and embedding tensors.
248
+ The first line can have vocabulary size and dimension. The following lines
249
+ should contain word and embedding separated by spaces.
250
+ Example:
251
+ 2 5
252
+ the -0.0230 -0.0264 0.0287 0.0171 0.1403
253
+ at -0.0395 -0.1286 0.0275 0.0254 -0.0932
254
+ """
255
+ embed_dict = {}
256
+ with open(embed_path) as f_embed:
257
+ next(f_embed) # skip header
258
+ for line in f_embed:
259
+ pieces = line.rstrip().split(" ")
260
+ embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
261
+ return embed_dict
262
+
263
+
264
+ def load_embedding(embed_dict, vocab, embedding):
265
+ for idx in range(len(vocab)):
266
+ token = vocab[idx]
267
+ if token in embed_dict:
268
+ embedding.weight.data[idx] = embed_dict[token]
269
+ return embedding
270
+
271
+
272
+ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
273
+ from fairseq import tokenizer
274
+ # Tokens are strings here
275
+ hypo_tokens = tokenizer.tokenize_line(hypo_str)
276
+ # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
277
+ src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
278
+ for i, ht in enumerate(hypo_tokens):
279
+ if ht == unk:
280
+ src_token = src_tokens[alignment[i]]
281
+ # Either take the corresponding value in the aligned dictionary or just copy the original value.
282
+ hypo_tokens[i] = align_dict.get(src_token, src_token)
283
+ return ' '.join(hypo_tokens)
284
+
285
+
286
+ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
287
+ from fairseq import tokenizer
288
+ hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
289
+ if align_dict is not None:
290
+ hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
291
+ if align_dict is not None or remove_bpe is not None:
292
+ # Convert back to tokens for evaluating with unk replacement or without BPE
293
+ # Note that the dictionary can be modified inside the method.
294
+ hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
295
+ return hypo_tokens, hypo_str, alignment
296
+
297
+
298
+ def make_positions(tensor, padding_idx, left_pad):
299
+ """Replace non-padding symbols with their position numbers.
300
+ Position numbers begin at padding_idx+1.
301
+ Padding symbols are ignored, but it is necessary to specify whether padding
302
+ is added on the left side (left_pad=True) or right side (left_pad=False).
303
+ """
304
+ max_pos = padding_idx + 1 + tensor.size(1)
305
+ if not hasattr(make_positions, 'range_buf'):
306
+ make_positions.range_buf = tensor.new()
307
+ make_positions.range_buf = make_positions.range_buf.type_as(tensor)
308
+ if make_positions.range_buf.numel() < max_pos:
309
+ torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
310
+ mask = tensor.ne(padding_idx)
311
+ positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
312
+ if left_pad:
313
+ positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
314
+ return tensor.clone().masked_scatter_(mask, positions[mask])
315
+
316
+
317
+ def strip_pad(tensor, pad):
318
+ return tensor[tensor.ne(pad)]
319
+
320
+
321
+ def buffered_arange(max):
322
+ if not hasattr(buffered_arange, 'buf'):
323
+ buffered_arange.buf = torch.LongTensor()
324
+ if max > buffered_arange.buf.numel():
325
+ torch.arange(max, out=buffered_arange.buf)
326
+ return buffered_arange.buf[:max]
327
+
328
+
329
+ def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
330
+ assert right_to_left ^ left_to_right
331
+ pad_mask = src_tokens.eq(padding_idx)
332
+ if not pad_mask.any():
333
+ # no padding, return early
334
+ return src_tokens
335
+ if left_to_right and not pad_mask[:, 0].any():
336
+ # already right padded
337
+ return src_tokens
338
+ if right_to_left and not pad_mask[:, -1].any():
339
+ # already left padded
340
+ return src_tokens
341
+ max_len = src_tokens.size(1)
342
+ range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
343
+ num_pads = pad_mask.long().sum(dim=1, keepdim=True)
344
+ if right_to_left:
345
+ index = torch.remainder(range - num_pads, max_len)
346
+ else:
347
+ index = torch.remainder(range + num_pads, max_len)
348
+ return src_tokens.gather(1, index)
349
+
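# Worked example (illustrative, not lines from this commit): with padding_idx=1, a left-padded
# row is rotated so that its pad symbols end up on the right.
#
#   >>> convert_padding_direction(torch.tensor([[1, 1, 7, 8]]), 1, left_to_right=True)
#   tensor([[7, 8, 1, 1]])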
350
+
351
+ def item(tensor):
352
+ if hasattr(tensor, 'item'):
353
+ return tensor.item()
354
+ if hasattr(tensor, '__getitem__'):
355
+ return tensor[0]
356
+ return tensor
357
+
358
+
359
+ def clip_grad_norm_(tensor, max_norm):
360
+ grad_norm = item(torch.norm(tensor))
361
+ if grad_norm > max_norm > 0:
362
+ clip_coef = max_norm / (grad_norm + 1e-6)
363
+ tensor.mul_(clip_coef)
364
+ return grad_norm
365
+
366
+
367
+ def fill_with_neg_inf(t):
368
+ """FP16-compatible function that fills a tensor with -inf."""
369
+ return t.float().fill_(float('-inf')).type_as(t)
370
+
371
+
372
+ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
373
+ """Retrieves all checkpoints found in `path` directory.
374
+ Checkpoints are identified by matching filename to the specified pattern. If
375
+ the pattern contains groups, the result will be sorted by the first group in
376
+ descending order.
377
+ """
378
+ pt_regexp = re.compile(pattern)
379
+ files = os.listdir(path)
380
+
381
+ entries = []
382
+ for i, f in enumerate(files):
383
+ m = pt_regexp.fullmatch(f)
384
+ if m is not None:
385
+ idx = int(m.group(1)) if len(m.groups()) > 0 else i
386
+ entries.append((idx, m.group(0)))
387
+ return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
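# Worked example (illustrative, not lines from this commit): with checkpoint1.pt, checkpoint2.pt
# and checkpoint3.pt saved in 'save_dir', the newest checkpoint is returned first.
#
#   >>> checkpoint_paths('save_dir')
#   ['save_dir/checkpoint3.pt', 'save_dir/checkpoint2.pt', 'save_dir/checkpoint1.pt']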
src/read_pkl.py ADDED
@@ -0,0 +1,7 @@
1
+ import pickle
2
+
3
+ with open("../data/predicted_ingr.pkl", "rb") as fp: # Unpickling
4
+ b = pickle.load(fp)
5
+
6
+ print(b)
7
+
src/sample.py ADDED
@@ -0,0 +1,207 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import torch
4
+ import numpy as np
5
+ from args import get_parser
6
+ import pickle
7
+ import os
8
+ from torchvision import transforms
9
+ from build_vocab import Vocabulary
10
+ from model import get_model
11
+ from tqdm import tqdm
12
+ from data_loader import get_loader
13
+ import json
14
+ import sys
15
+ from model import mask_from_eos
16
+ import random
17
+ from utils.metrics import softIoU, update_error_types, compute_metrics
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ map_loc = None if torch.cuda.is_available() else 'cpu'
20
+
21
+
22
+ def compute_score(sampled_ids):
23
+
24
+ if 1 in sampled_ids:
25
+ cut = np.where(sampled_ids == 1)[0][0]
26
+ else:
27
+ cut = -1
28
+ sampled_ids = sampled_ids[0:cut]
29
+ score = float(len(set(sampled_ids))) / float(len(sampled_ids))
30
+
31
+ return score
32
+
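# Worked example (illustrative, not lines from this commit): the score is the fraction of unique
# tokens before the end-of-sequence id (1), so heavy repetition drives it towards 0.
#
#   >>> compute_score(np.array([4, 5, 4, 6, 1, 0]))   # tokens kept: [4, 5, 4, 6]
#   0.75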
33
+
34
+ def label2onehot(labels, pad_value):
35
+
36
+ # input labels to one hot vector
37
+ inp_ = torch.unsqueeze(labels, 2)
38
+ one_hot = torch.FloatTensor(labels.size(0), labels.size(1), pad_value + 1).zero_().to(device)
39
+ one_hot.scatter_(2, inp_, 1)
40
+ one_hot, _ = one_hot.max(dim=1)
41
+ # remove pad and eos position
42
+ one_hot = one_hot[:, 1:-1]
43
+ one_hot[:, 0] = 0
44
+
45
+ return one_hot
46
+
47
+
48
+ def main(args):
49
+
50
+ where_to_save = os.path.join(args.save_dir, args.project_name, args.model_name)
51
+ checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
52
+ logs_dir = os.path.join(where_to_save, 'logs')
53
+
54
+ if not args.log_term:
55
+ print ("Eval logs will be saved to:", os.path.join(logs_dir, 'eval.log'))
56
+ sys.stdout = open(os.path.join(logs_dir, 'eval.log'), 'w')
57
+ sys.stderr = open(os.path.join(logs_dir, 'eval.err'), 'w')
58
+
59
+ vars_to_replace = ['greedy', 'recipe_only', 'ingrs_only', 'temperature', 'batch_size', 'maxseqlen',
60
+ 'get_perplexity', 'use_true_ingrs', 'eval_split', 'save_dir', 'aux_data_dir',
61
+ 'recipe1m_dir', 'project_name', 'use_lmdb', 'beam']
62
+ store_dict = {}
63
+ for var in vars_to_replace:
64
+ store_dict[var] = getattr(args, var)
65
+ args = pickle.load(open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
66
+ for var in vars_to_replace:
67
+ setattr(args, var, store_dict[var])
68
+ print (args)
69
+
70
+ transforms_list = []
71
+ transforms_list.append(transforms.Resize((args.crop_size)))
72
+ transforms_list.append(transforms.CenterCrop(args.crop_size))
73
+ transforms_list.append(transforms.ToTensor())
74
+ transforms_list.append(transforms.Normalize((0.485, 0.456, 0.406),
75
+ (0.229, 0.224, 0.225)))
76
+ # Image preprocessing
77
+ transform = transforms.Compose(transforms_list)
78
+
79
+ # data loader
80
+ data_dir = args.recipe1m_dir
81
+ data_loader, dataset = get_loader(data_dir, args.aux_data_dir, args.eval_split,
82
+ args.maxseqlen, args.maxnuminstrs, args.maxnumlabels,
83
+ args.maxnumims, transform, args.batch_size,
84
+ shuffle=False, num_workers=args.num_workers,
85
+ drop_last=False, max_num_samples=-1,
86
+ use_lmdb=args.use_lmdb, suff=args.suff)
87
+
88
+ ingr_vocab_size = dataset.get_ingrs_vocab_size()
89
+ instrs_vocab_size = dataset.get_instrs_vocab_size()
90
+
91
+ args.numgens = 1
92
+
93
+ # Build the model
94
+ model = get_model(args, ingr_vocab_size, instrs_vocab_size)
95
+ model_path = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints', 'modelbest.ckpt')
96
+
97
+ # overwrite flags for inference
98
+ model.recipe_only = args.recipe_only
99
+ model.ingrs_only = args.ingrs_only
100
+
101
+ # Load the trained model parameters
102
+ model.load_state_dict(torch.load(model_path, map_location=map_loc))
103
+
104
+ model.eval()
105
+ model = model.to(device)
106
+ results_dict = {'recipes': {}, 'ingrs': {}, 'ingr_iou': {}}
107
+ captions = {}
108
+ iou = []
109
+ error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0, 'tp_all': 0, 'fp_all': 0, 'fn_all': 0}
110
+ perplexity_list = []
111
+ n_rep, th = 0, 0.3
112
+
113
+ for i, (img_inputs, true_caps_batch, ingr_gt, imgid, impath) in tqdm(enumerate(data_loader)):
114
+
115
+ ingr_gt = ingr_gt.to(device)
116
+ true_caps_batch = true_caps_batch.to(device)
117
+
118
+ true_caps_shift = true_caps_batch.clone()[:, 1:].contiguous()
119
+ img_inputs = img_inputs.to(device)
120
+
121
+ true_ingrs = ingr_gt if args.use_true_ingrs else None
122
+ for gens in range(args.numgens):
123
+ with torch.no_grad():
124
+
125
+ if args.get_perplexity:
126
+
127
+ losses = model(img_inputs, true_caps_batch, ingr_gt, keep_cnn_gradients=False)
128
+ recipe_loss = losses['recipe_loss']
129
+ recipe_loss = recipe_loss.view(true_caps_shift.size())
130
+ non_pad_mask = true_caps_shift.ne(instrs_vocab_size - 1).float()
131
+ recipe_loss = torch.sum(recipe_loss*non_pad_mask, dim=-1) / torch.sum(non_pad_mask, dim=-1)
132
+ perplexity = torch.exp(recipe_loss)
133
+
134
+ perplexity = perplexity.detach().cpu().numpy().tolist()
135
+ perplexity_list.extend(perplexity)
136
+
137
+ else:
138
+
139
+ outputs = model.sample(img_inputs, args.greedy, args.temperature, args.beam, true_ingrs)
140
+
141
+ if not args.recipe_only:
142
+ fake_ingrs = outputs['ingr_ids']
143
+ pred_one_hot = label2onehot(fake_ingrs, ingr_vocab_size - 1)
144
+ target_one_hot = label2onehot(ingr_gt, ingr_vocab_size - 1)
145
+ iou_item = torch.mean(softIoU(pred_one_hot, target_one_hot)).item()
146
+ iou.append(iou_item)
147
+
148
+ update_error_types(error_types, pred_one_hot, target_one_hot)
149
+
150
+ fake_ingrs = fake_ingrs.detach().cpu().numpy()
151
+
152
+ for ingr_idx, fake_ingr in enumerate(fake_ingrs):
153
+
154
+ iou_item = softIoU(pred_one_hot[ingr_idx].unsqueeze(0),
155
+ target_one_hot[ingr_idx].unsqueeze(0)).item()
156
+ results_dict['ingrs'][imgid[ingr_idx]] = []
157
+ results_dict['ingrs'][imgid[ingr_idx]].append(fake_ingr)
158
+ results_dict['ingr_iou'][imgid[ingr_idx]] = iou_item
159
+
160
+ if not args.ingrs_only:
161
+ sampled_ids_batch = outputs['recipe_ids']
162
+ sampled_ids_batch = sampled_ids_batch.cpu().detach().numpy()
163
+
164
+ for j, sampled_ids in enumerate(sampled_ids_batch):
165
+ score = compute_score(sampled_ids)
166
+ if score < th:
167
+ n_rep += 1
168
+ if imgid[j] not in captions.keys():
169
+ results_dict['recipes'][imgid[j]] = []
170
+ results_dict['recipes'][imgid[j]].append(sampled_ids)
171
+ if args.get_perplexity:
172
+ print (len(perplexity_list))
173
+ print (np.mean(perplexity_list))
174
+ else:
175
+
176
+ if not args.recipe_only:
177
+ ret_metrics = {'accuracy': [], 'f1': [], 'jaccard': [], 'f1_ingredients': []}
178
+ compute_metrics(ret_metrics, error_types, ['accuracy', 'f1', 'jaccard', 'f1_ingredients'],
179
+ eps=1e-10,
180
+ weights=None)
181
+
182
+ for k, v in ret_metrics.items():
183
+ print (k, np.mean(v))
184
+
185
+ if args.greedy:
186
+ suff = 'greedy'
187
+ else:
188
+ if args.beam != -1:
189
+ suff = 'beam_'+str(args.beam)
190
+ else:
191
+ suff = 'temp_' + str(args.temperature)
192
+
193
+ results_file = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints',
194
+ args.eval_split + '_' + suff + '_gencaps.pkl')
195
+ print (results_file)
196
+ pickle.dump(results_dict, open(results_file, 'wb'))
197
+
198
+ print ("Number of samples with excessive repetitions:", n_rep)
199
+
200
+
201
+ if __name__ == '__main__':
202
+ args = get_parser()
203
+ torch.manual_seed(1234)
204
+ torch.cuda.manual_seed(1234)
205
+ random.seed(1234)
206
+ np.random.seed(1234)
207
+ main(args)
src/sim_ingr.py ADDED
@@ -0,0 +1,197 @@
1
+ import nltk
2
+ import pickle
3
+ import argparse
4
+ from collections import Counter
5
+ import json
6
+ import os
7
+ from tqdm import *
8
+ import numpy as np
9
+ import re
10
+
11
+
12
+ def get_ingredient(det_ingr, replace_dict):
13
+ det_ingr_undrs = det_ingr['text'].lower()
14
+ det_ingr_undrs = ''.join(i for i in det_ingr_undrs if not i.isdigit())
15
+
16
+ for rep, char_list in replace_dict.items():
17
+ for c_ in char_list:
18
+ if c_ in det_ingr_undrs:
19
+ det_ingr_undrs = det_ingr_undrs.replace(c_, rep)
20
+ det_ingr_undrs = det_ingr_undrs.strip()
21
+ det_ingr_undrs = det_ingr_undrs.replace(' ', '_')
22
+
23
+ return det_ingr_undrs
24
+
25
+
26
+ def remove_plurals(counter_ingrs, ingr_clusters):
27
+ del_ingrs = []
28
+
29
+ for k, v in counter_ingrs.items():
30
+
31
+ if len(k) == 0:
32
+ del_ingrs.append(k)
33
+ continue
34
+
35
+ gotit = 0
36
+ if k[-2:] == 'es':
37
+ if k[:-2] in counter_ingrs.keys():
38
+ counter_ingrs[k[:-2]] += v
39
+ ingr_clusters[k[:-2]].extend(ingr_clusters[k])
40
+ del_ingrs.append(k)
41
+ gotit = 1
42
+
43
+ if k[-1] == 's' and gotit == 0:
44
+ if k[:-1] in counter_ingrs.keys():
45
+ counter_ingrs[k[:-1]] += v
46
+ ingr_clusters[k[:-1]].extend(ingr_clusters[k])
47
+ del_ingrs.append(k)
48
+ for item in del_ingrs:
49
+ del counter_ingrs[item]
50
+ del ingr_clusters[item]
51
+ return counter_ingrs, ingr_clusters
52
+
53
+
54
+ def cluster_ingredients(counter_ingrs):
55
+ mydict = dict()
56
+ mydict_ingrs = dict()
57
+
58
+ for k, v in counter_ingrs.items():
59
+
60
+ w1 = k.split('_')[-1]
61
+ w2 = k.split('_')[0]
62
+ lw = [w1, w2]
63
+ if len(k.split('_')) > 1:
64
+ w3 = k.split('_')[0] + '_' + k.split('_')[1]
65
+ w4 = k.split('_')[-2] + '_' + k.split('_')[-1]
66
+
67
+ lw = [w1, w2, w4, w3]
68
+
69
+ gotit = 0
70
+ for w in lw:
71
+ if w in counter_ingrs.keys():
72
+ # check if its parts are
73
+ parts = w.split('_')
74
+ if len(parts) > 0:
75
+ if parts[0] in counter_ingrs.keys():
76
+ w = parts[0]
77
+ elif parts[1] in counter_ingrs.keys():
78
+ w = parts[1]
79
+ if w in mydict.keys():
80
+ mydict[w] += v
81
+ mydict_ingrs[w].append(k)
82
+ else:
83
+ mydict[w] = v
84
+ mydict_ingrs[w] = [k]
85
+ gotit = 1
86
+ break
87
+ if gotit == 0:
88
+ mydict[k] = v
89
+ mydict_ingrs[k] = [k]
90
+
91
+ return mydict, mydict_ingrs
92
+
93
+
94
+ def update_counter(list_, counter_toks, istrain=False):
95
+ for sentence in list_:
96
+ tokens = nltk.tokenize.word_tokenize(sentence)
97
+ if istrain:
98
+ counter_toks.update(tokens)
99
+
100
+
101
+ def build_vocab_recipe1m(args):
102
+ print ("Loading data...")
103
+ dets = json.load(open(os.path.join(args.recipe1m_path, 'det_ingrs.json'), 'r'))
104
+
105
+ replace_dict_ingrs = {'and': ['&', "'n"], '': ['%', ',', '.', '#', '[', ']', '!', '?']}
106
+ replace_dict_instrs = {'and': ['&', "'n"], '': ['#', '[', ']']}
107
+
108
+ idx2ind = {}
109
+ for i, entry in enumerate(dets):
110
+ idx2ind[entry['id']] = i
111
+
112
+ ingrs_file = args.save_path + 'allingrs_count.pkl'
113
+ instrs_file = args.save_path + 'allwords_count.pkl'
114
+
+ # NOTE: assumed behaviour -- this trimmed script never builds counter_ingrs / counter_toks
+ # itself, so load the pre-computed counters if they are available (as in the full build_vocab pass).
+ if os.path.exists(ingrs_file) and os.path.exists(instrs_file) and not args.forcegen:
+ counter_ingrs = pickle.load(open(ingrs_file, 'rb'))
+ counter_toks = pickle.load(open(instrs_file, 'rb'))
115
+ # manually add missing entries for better clustering
116
+ base_words = ['peppers', 'tomato', 'spinach_leaves', 'turkey_breast', 'lettuce_leaf',
117
+ 'chicken_thighs', 'milk_powder', 'bread_crumbs', 'onion_flakes',
118
+ 'red_pepper', 'pepper_flakes', 'juice_concentrate', 'cracker_crumbs', 'hot_chili',
119
+ 'seasoning_mix', 'dill_weed', 'pepper_sauce', 'sprouts', 'cooking_spray', 'cheese_blend',
120
+ 'basil_leaves', 'pineapple_chunks', 'marshmallow', 'chile_powder',
121
+ 'cheese_blend', 'corn_kernels', 'tomato_sauce', 'chickens', 'cracker_crust',
122
+ 'lemonade_concentrate', 'red_chili', 'mushroom_caps', 'mushroom_cap', 'breaded_chicken',
123
+ 'frozen_pineapple', 'pineapple_chunks', 'seasoning_mix', 'seaweed', 'onion_flakes',
124
+ 'bouillon_granules', 'lettuce_leaf', 'stuffing_mix', 'parsley_flakes', 'chicken_breast',
125
+ 'basil_leaves', 'baguettes', 'green_tea', 'peanut_butter', 'green_onion', 'fresh_cilantro',
126
+ 'breaded_chicken', 'hot_pepper', 'dried_lavender', 'white_chocolate',
127
+ 'dill_weed', 'cake_mix', 'cheese_spread', 'turkey_breast', 'chucken_thighs', 'basil_leaves',
128
+ 'mandarin_orange', 'laurel', 'cabbage_head', 'pistachio', 'cheese_dip',
129
+ 'thyme_leave', 'boneless_pork', 'red_pepper', 'onion_dip', 'skinless_chicken', 'dark_chocolate',
130
+ 'canned_corn', 'muffin', 'cracker_crust', 'bread_crumbs', 'frozen_broccoli',
131
+ 'philadelphia', 'cracker_crust', 'chicken_breast']
132
+
133
+ for base_word in base_words:
134
+
135
+ if base_word not in counter_ingrs.keys():
136
+ counter_ingrs[base_word] = 1
137
+
138
+ counter_ingrs, cluster_ingrs = cluster_ingredients(counter_ingrs)
139
+ counter_ingrs, cluster_ingrs = remove_plurals(counter_ingrs, cluster_ingrs)
140
+
141
+ # If the word frequency is less than 'threshold', then the word is discarded.
142
+ words = [word for word, cnt in counter_toks.items() if cnt >= args.threshold_words]
143
+ ingrs = {word: cnt for word, cnt in counter_ingrs.items() if cnt >= args.threshold_ingrs}
144
+
145
+
146
+ def main(args):
147
+
148
+ vocab_ingrs, vocab_toks, dataset = build_vocab_recipe1m(args)
149
+
150
+ with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_ingrs.pkl'), 'wb') as f:
151
+ pickle.dump(vocab_ingrs, f)
152
+ with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_toks.pkl'), 'wb') as f:
153
+ pickle.dump(vocab_toks, f)
154
+
155
+ for split in dataset.keys():
156
+ with open(os.path.join(args.save_path, args.suff+'recipe1m_' + split + '.pkl'), 'wb') as f:
157
+ pickle.dump(dataset[split], f)
158
+
159
+
160
+ if __name__ == '__main__':
161
+
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument('--recipe1m_path', type=str,
164
+ default='path/to/recipe1m',
165
+ help='recipe1m path')
166
+
167
+ parser.add_argument('--save_path', type=str, default='../data/',
168
+ help='path for saving vocabulary wrapper')
169
+
170
+ parser.add_argument('--suff', type=str, default='')
171
+
172
+ parser.add_argument('--threshold_ingrs', type=int, default=10,
173
+ help='minimum ingr count threshold')
174
+
175
+ parser.add_argument('--threshold_words', type=int, default=10,
176
+ help='minimum word count threshold')
177
+
178
+ parser.add_argument('--maxnuminstrs', type=int, default=20,
179
+ help='max number of instructions (sentences)')
180
+
181
+ parser.add_argument('--maxnumingrs', type=int, default=20,
182
+ help='max number of ingredients')
183
+
184
+ parser.add_argument('--minnuminstrs', type=int, default=2,
185
+ help='max number of instructions (sentences)')
186
+
187
+ parser.add_argument('--minnumingrs', type=int, default=2,
188
+ help='max number of ingredients')
189
+
190
+ parser.add_argument('--minnumwords', type=int, default=20,
191
+ help='minimum number of characters in recipe')
192
+
193
+ parser.add_argument('--forcegen', dest='forcegen', action='store_true')
194
+ parser.set_defaults(forcegen=False)
195
+
196
+ args = parser.parse_args()
197
+ main(args)
src/train.py ADDED
@@ -0,0 +1,398 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ from args import get_parser
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.autograd as autograd
7
+ import numpy as np
8
+ import os
9
+ import random
10
+ import pickle
11
+ from data_loader import get_loader
12
+ from build_vocab import Vocabulary
13
+ from model import get_model
14
+ from torchvision import transforms
15
+ import sys
16
+ import json
17
+ import time
18
+ import torch.backends.cudnn as cudnn
19
+ from utils.tb_visualizer import Visualizer
20
+ from model import mask_from_eos, label2onehot
21
+ from utils.metrics import softIoU, compute_metrics, update_error_types
22
+ import random
23
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
+ map_loc = None if torch.cuda.is_available() else 'cpu'
25
+
26
+
27
+ def merge_models(args, model, ingr_vocab_size, instrs_vocab_size):
28
+ load_args = pickle.load(open(os.path.join(args.save_dir, args.project_name,
29
+ args.transfer_from, 'checkpoints/args.pkl'), 'rb'))
30
+
31
+ model_ingrs = get_model(load_args, ingr_vocab_size, instrs_vocab_size)
32
+ model_path = os.path.join(args.save_dir, args.project_name, args.transfer_from, 'checkpoints', 'modelbest.ckpt')
33
+
34
+ # Load the trained model parameters
35
+ model_ingrs.load_state_dict(torch.load(model_path, map_location=map_loc))
36
+ model.ingredient_decoder = model_ingrs.ingredient_decoder
37
+ args.transf_layers_ingrs = load_args.transf_layers_ingrs
38
+ args.n_att_ingrs = load_args.n_att_ingrs
39
+
40
+ return args, model
41
+
42
+
43
+ def save_model(model, optimizer, checkpoints_dir, suff=''):
44
+ if torch.cuda.device_count() > 1:
45
+ torch.save(model.module.state_dict(), os.path.join(
46
+ checkpoints_dir, 'model' + suff + '.ckpt'))
47
+
48
+ else:
49
+ torch.save(model.state_dict(), os.path.join(
50
+ checkpoints_dir, 'model' + suff + '.ckpt'))
51
+
52
+ torch.save(optimizer.state_dict(), os.path.join(
53
+ checkpoints_dir, 'optim' + suff + '.ckpt'))
54
+
55
+
56
+ def count_parameters(model):
57
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
58
+
59
+
60
+ def set_lr(optimizer, decay_factor):
61
+ for group in optimizer.param_groups:
62
+ group['lr'] = group['lr']*decay_factor
63
+
64
+
65
+ def make_dir(d):
66
+ if not os.path.exists(d):
67
+ os.makedirs(d)
68
+
69
+
70
+ def main(args):
71
+
72
+ # Create model directory & other aux folders for logging
73
+ where_to_save = os.path.join(args.save_dir, args.project_name, args.model_name)
74
+ checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
75
+ logs_dir = os.path.join(where_to_save, 'logs')
76
+ tb_logs = os.path.join(args.save_dir, args.project_name, 'tb_logs', args.model_name)
77
+ make_dir(where_to_save)
78
+ make_dir(logs_dir)
79
+ make_dir(checkpoints_dir)
80
+ make_dir(tb_logs)
81
+ if args.tensorboard:
82
+ logger = Visualizer(tb_logs, name='visual_results')
83
+
84
+ # check if we want to resume from last checkpoint of current model
85
+ if args.resume:
86
+ args = pickle.load(open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
87
+ args.resume = True
88
+
89
+ # logs to disk
90
+ if not args.log_term:
91
+ print ("Training logs will be saved to:", os.path.join(logs_dir, 'train.log'))
92
+ sys.stdout = open(os.path.join(logs_dir, 'train.log'), 'w')
93
+ sys.stderr = open(os.path.join(logs_dir, 'train.err'), 'w')
94
+
95
+ print(args)
96
+ pickle.dump(args, open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))
97
+
98
+ # patience init
99
+ curr_pat = 0
100
+
101
+ # Build data loader
102
+ data_loaders = {}
103
+ datasets = {}
104
+
105
+ data_dir = args.recipe1m_dir
106
+ for split in ['train', 'val']:
107
+
108
+ transforms_list = [transforms.Resize((args.image_size))]
109
+
110
+ if split == 'train':
111
+ # Image preprocessing, normalization for the pretrained resnet
112
+ transforms_list.append(transforms.RandomHorizontalFlip())
113
+ transforms_list.append(transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)))
114
+ transforms_list.append(transforms.RandomCrop(args.crop_size))
115
+
116
+ else:
117
+ transforms_list.append(transforms.CenterCrop(args.crop_size))
118
+ transforms_list.append(transforms.ToTensor())
119
+ transforms_list.append(transforms.Normalize((0.485, 0.456, 0.406),
120
+ (0.229, 0.224, 0.225)))
121
+
122
+ transform = transforms.Compose(transforms_list)
123
+ max_num_samples = max(args.max_eval, args.batch_size) if split == 'val' else -1
124
+ data_loaders[split], datasets[split] = get_loader(data_dir, args.aux_data_dir, split,
125
+ args.maxseqlen,
126
+ args.maxnuminstrs,
127
+ args.maxnumlabels,
128
+ args.maxnumims,
129
+ transform, args.batch_size,
130
+ shuffle=split == 'train', num_workers=args.num_workers,
131
+ drop_last=True,
132
+ max_num_samples=max_num_samples,
133
+ use_lmdb=args.use_lmdb,
134
+ suff=args.suff)
135
+
136
+ ingr_vocab_size = datasets[split].get_ingrs_vocab_size()
137
+ instrs_vocab_size = datasets[split].get_instrs_vocab_size()
138
+
139
+ # Build the model
140
+ model = get_model(args, ingr_vocab_size, instrs_vocab_size)
141
+ keep_cnn_gradients = False
142
+
143
+ decay_factor = 1.0
144
+
145
+ # add model parameters
146
+ if args.ingrs_only:
147
+ params = list(model.ingredient_decoder.parameters())
148
+ elif args.recipe_only:
149
+ params = list(model.recipe_decoder.parameters()) + list(model.ingredient_encoder.parameters())
150
+ else:
151
+ params = list(model.recipe_decoder.parameters()) + list(model.ingredient_decoder.parameters()) \
152
+ + list(model.ingredient_encoder.parameters())
153
+
154
+ # only train the linear layer in the encoder if we are not transfering from another model
155
+ if args.transfer_from == '':
156
+ params += list(model.image_encoder.linear.parameters())
157
+ params_cnn = list(model.image_encoder.resnet.parameters())
158
+
159
+ print ("CNN params:", sum(p.numel() for p in params_cnn if p.requires_grad))
160
+ print ("decoder params:", sum(p.numel() for p in params if p.requires_grad))
161
+ # start optimizing cnn from the beginning
162
+ if params_cnn is not None and args.finetune_after == 0:
163
+ optimizer = torch.optim.Adam([{'params': params}, {'params': params_cnn,
164
+ 'lr': args.learning_rate*args.scale_learning_rate_cnn}],
165
+ lr=args.learning_rate, weight_decay=args.weight_decay)
166
+ keep_cnn_gradients = True
167
+ print ("Fine tuning resnet")
168
+ else:
169
+ optimizer = torch.optim.Adam(params, lr=args.learning_rate)
170
+
171
+ if args.resume:
172
+ model_path = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints', 'model.ckpt')
173
+ optim_path = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints', 'optim.ckpt')
174
+ optimizer.load_state_dict(torch.load(optim_path, map_location=map_loc))
175
+ for state in optimizer.state.values():
176
+ for k, v in state.items():
177
+ if isinstance(v, torch.Tensor):
178
+ state[k] = v.to(device)
179
+ model.load_state_dict(torch.load(model_path, map_location=map_loc))
180
+
181
+ if args.transfer_from != '':
182
+ # loads CNN encoder from transfer_from model
183
+ model_path = os.path.join(args.save_dir, args.project_name, args.transfer_from, 'checkpoints', 'modelbest.ckpt')
184
+ pretrained_dict = torch.load(model_path, map_location=map_loc)
185
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if 'encoder' in k}
186
+ model.load_state_dict(pretrained_dict, strict=False)
187
+ args, model = merge_models(args, model, ingr_vocab_size, instrs_vocab_size)
188
+
189
+ if device != 'cpu' and torch.cuda.device_count() > 1:
190
+ model = nn.DataParallel(model)
191
+
192
+ model = model.to(device)
193
+ cudnn.benchmark = True
194
+
195
+ if not hasattr(args, 'current_epoch'):
196
+ args.current_epoch = 0
197
+
198
+ es_best = 10000 if args.es_metric == 'loss' else 0
199
+ # Train the model
200
+ start = args.current_epoch
201
+ for epoch in range(start, args.num_epochs):
202
+
203
+ # save current epoch for resuming
204
+ if args.tensorboard:
205
+ logger.reset()
206
+
207
+ args.current_epoch = epoch
208
+ # increase / decrase values for moving params
209
+ if args.decay_lr:
210
+ frac = epoch // args.lr_decay_every
211
+ decay_factor = args.lr_decay_rate ** frac
212
+ new_lr = args.learning_rate*decay_factor
213
+ print ('Epoch %d. lr: %.5f'%(epoch, new_lr))
214
+ set_lr(optimizer, decay_factor)
215
+
216
+ if args.finetune_after != -1 and args.finetune_after < epoch \
217
+ and not keep_cnn_gradients and params_cnn is not None:
218
+
219
+ print("Starting to fine tune CNN")
220
+ # start with learning rates as they were (if decayed during training)
221
+ optimizer = torch.optim.Adam([{'params': params},
222
+ {'params': params_cnn,
223
+ 'lr': decay_factor*args.learning_rate*args.scale_learning_rate_cnn}],
224
+ lr=decay_factor*args.learning_rate)
225
+ keep_cnn_gradients = True
226
+
227
+ for split in ['train', 'val']:
228
+
229
+ if split == 'train':
230
+ model.train()
231
+ else:
232
+ model.eval()
233
+ total_step = len(data_loaders[split])
234
+ loader = iter(data_loaders[split])
235
+
236
+ total_loss_dict = {'recipe_loss': [], 'ingr_loss': [],
237
+ 'eos_loss': [], 'loss': [],
238
+ 'iou': [], 'perplexity': [], 'iou_sample': [],
239
+ 'f1': [],
240
+ 'card_penalty': []}
241
+
242
+ error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0,
243
+ 'tp_all': 0, 'fp_all': 0, 'fn_all': 0}
244
+
245
+ torch.cuda.synchronize()
246
+ start = time.time()
247
+
248
+ for i in range(total_step):
249
+
250
+ img_inputs, captions, ingr_gt, img_ids, paths = next(loader)  # iterator.next() only exists in Python 2 / older PyTorch
251
+
252
+ ingr_gt = ingr_gt.to(device)
253
+ img_inputs = img_inputs.to(device)
254
+ captions = captions.to(device)
255
+ true_caps_batch = captions.clone()[:, 1:].contiguous()
256
+ loss_dict = {}
257
+
258
+ if split == 'val':
259
+ with torch.no_grad():
260
+ losses = model(img_inputs, captions, ingr_gt)
261
+
262
+ if not args.recipe_only:
263
+ outputs = model(img_inputs, captions, ingr_gt, sample=True)
264
+
265
+ ingr_ids_greedy = outputs['ingr_ids']
266
+
267
+ mask = mask_from_eos(ingr_ids_greedy, eos_value=0, mult_before=False)
268
+ ingr_ids_greedy[mask == 0] = ingr_vocab_size-1
269
+ pred_one_hot = label2onehot(ingr_ids_greedy, ingr_vocab_size-1)
270
+ target_one_hot = label2onehot(ingr_gt, ingr_vocab_size-1)
271
+ iou_sample = softIoU(pred_one_hot, target_one_hot)
272
+ iou_sample = iou_sample.sum() / (torch.nonzero(iou_sample.data).size(0) + 1e-6)
273
+ loss_dict['iou_sample'] = iou_sample.item()
274
+
275
+ update_error_types(error_types, pred_one_hot, target_one_hot)
276
+
277
+ del outputs, pred_one_hot, target_one_hot, iou_sample
278
+
279
+ else:
280
+ losses = model(img_inputs, captions, ingr_gt,
281
+ keep_cnn_gradients=keep_cnn_gradients)
282
+
283
+ if not args.ingrs_only:
284
+ recipe_loss = losses['recipe_loss']
285
+
286
+ recipe_loss = recipe_loss.view(true_caps_batch.size())
287
+ non_pad_mask = true_caps_batch.ne(instrs_vocab_size - 1).float()
288
+
289
+ recipe_loss = torch.sum(recipe_loss*non_pad_mask, dim=-1) / torch.sum(non_pad_mask, dim=-1)
290
+ perplexity = torch.exp(recipe_loss)
291
+
292
+ recipe_loss = recipe_loss.mean()
293
+ perplexity = perplexity.mean()
294
+
295
+ loss_dict['recipe_loss'] = recipe_loss.item()
296
+ loss_dict['perplexity'] = perplexity.item()
297
+ else:
298
+ recipe_loss = 0
299
+
300
+ if not args.recipe_only:
301
+
302
+ ingr_loss = losses['ingr_loss']
303
+ ingr_loss = ingr_loss.mean()
304
+ loss_dict['ingr_loss'] = ingr_loss.item()
305
+
306
+ eos_loss = losses['eos_loss']
307
+ eos_loss = eos_loss.mean()
308
+ loss_dict['eos_loss'] = eos_loss.item()
309
+
310
+ iou_seq = losses['iou']
311
+ iou_seq = iou_seq.mean()
312
+ loss_dict['iou'] = iou_seq.item()
313
+
314
+ card_penalty = losses['card_penalty'].mean()
315
+ loss_dict['card_penalty'] = card_penalty.item()
316
+ else:
317
+ ingr_loss, eos_loss, card_penalty = 0, 0, 0
318
+
319
+ loss = args.loss_weight[0] * recipe_loss + args.loss_weight[1] * ingr_loss \
320
+ + args.loss_weight[2]*eos_loss + args.loss_weight[3]*card_penalty
321
+
322
+ loss_dict['loss'] = loss.item()
323
+
324
+ for key in loss_dict.keys():
325
+ total_loss_dict[key].append(loss_dict[key])
326
+
327
+ if split == 'train':
328
+ model.zero_grad()
329
+ loss.backward()
330
+ optimizer.step()
331
+
332
+ # Print log info
333
+ if args.log_step != -1 and i % args.log_step == 0:
334
+ elapsed_time = time.time()-start
335
+ lossesstr = ""
336
+ for k in total_loss_dict.keys():
337
+ if len(total_loss_dict[k]) == 0:
338
+ continue
339
+ this_one = "%s: %.4f" % (k, np.mean(total_loss_dict[k][-args.log_step:]))
340
+ lossesstr += this_one + ', '
341
+ # this only displays nll loss on captions, the rest of losses will be in tensorboard logs
342
+ strtoprint = 'Split: %s, Epoch [%d/%d], Step [%d/%d], Losses: %sTime: %.4f' % (split, epoch,
343
+ args.num_epochs, i,
344
+ total_step,
345
+ lossesstr,
346
+ elapsed_time)
347
+ print(strtoprint)
348
+
349
+ if args.tensorboard:
350
+ # logger.histo_summary(model=model, step=total_step * epoch + i)
351
+ logger.scalar_summary(mode=split+'_iter', epoch=total_step*epoch+i,
352
+ **{k: np.mean(v[-args.log_step:]) for k, v in total_loss_dict.items() if v})
353
+
354
+ torch.cuda.synchronize()
355
+ start = time.time()
356
+ del loss, losses, captions, img_inputs
357
+
358
+ if split == 'val' and not args.recipe_only:
359
+ ret_metrics = {'accuracy': [], 'f1': [], 'jaccard': [], 'f1_ingredients': [], 'dice': []}
360
+ compute_metrics(ret_metrics, error_types,
361
+ ['accuracy', 'f1', 'jaccard', 'f1_ingredients', 'dice'], eps=1e-10,
362
+ weights=None)
363
+
364
+ total_loss_dict['f1'] = ret_metrics['f1']
365
+ if args.tensorboard:
366
+ # 1. Log scalar values (scalar summary)
367
+ logger.scalar_summary(mode=split,
368
+ epoch=epoch,
369
+ **{k: np.mean(v) for k, v in total_loss_dict.items() if v})
370
+
371
+ # Save the model's best checkpoint if performance was improved
372
+ es_value = np.mean(total_loss_dict[args.es_metric])
373
+
374
+ # save current model as well
375
+ save_model(model, optimizer, checkpoints_dir, suff='')
376
+ if (args.es_metric == 'loss' and es_value < es_best) or (args.es_metric == 'iou_sample' and es_value > es_best):
377
+ es_best = es_value
378
+ save_model(model, optimizer, checkpoints_dir, suff='best')
379
+ pickle.dump(args, open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))
380
+ curr_pat = 0
381
+ print('Saved checkpoint.')
382
+ else:
383
+ curr_pat += 1
384
+
385
+ if curr_pat > args.patience:
386
+ break
387
+
388
+ if args.tensorboard:
389
+ logger.close()
390
+
391
+
392
+ if __name__ == '__main__':
393
+ args = get_parser()
394
+ torch.manual_seed(1234)
395
+ torch.cuda.manual_seed(1234)
396
+ random.seed(1234)
397
+ np.random.seed(1234)
398
+ main(args)
src/utils/ims2file.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import pickle
4
+ from tqdm import tqdm
5
+ import os
6
+ import numpy as np
7
+ from PIL import Image
8
+ import argparse
9
+ import lmdb
10
+ from torchvision import transforms
11
+
12
+
13
+ MAX_SIZE = 1e12
14
+
15
+
16
+ def load_and_resize(root, path, imscale):
17
+
18
+ transf_list = []
19
+ transf_list.append(transforms.Resize(imscale))
20
+ transf_list.append(transforms.CenterCrop(imscale))
21
+ transform = transforms.Compose(transf_list)
22
+
23
+ img = Image.open(os.path.join(root, path[0], path[1], path[2], path[3], path)).convert('RGB')
24
+ img = transform(img)
25
+
26
+ return img
27
+
28
+
29
+ def main(args):
30
+
31
+ parts = {}
32
+ datasets = {}
33
+ imname2pos = {'train': {}, 'val': {}, 'test': {}}
34
+ for split in ['train', 'val', 'test']:
35
+ datasets[split] = pickle.load(open(os.path.join(args.save_dir, args.suff + 'recipe1m_' + split + '.pkl'), 'rb'))
36
+
37
+ parts[split] = lmdb.open(os.path.join(args.save_dir, 'lmdb_'+split), map_size=int(MAX_SIZE))
38
+ with parts[split].begin() as txn:
39
+ present_entries = [key for key, _ in txn.cursor()]
40
+ j = 0
41
+ for i, entry in tqdm(enumerate(datasets[split])):
42
+ impaths = entry['images'][0:5]
43
+
44
+ for n, p in enumerate(impaths):
45
+ if n == args.maxnumims:
46
+ break
47
+ if p.encode() not in present_entries:
48
+ im = load_and_resize(os.path.join(args.root, 'images', split), p, args.imscale)
49
+ im = np.array(im).astype(np.uint8)
50
+ with parts[split].begin(write=True) as txn:
51
+ txn.put(p.encode(), im)
52
+ imname2pos[split][p] = j
53
+ j += 1
54
+ pickle.dump(imname2pos, open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'wb'))
55
+
56
+
57
+ def test(args):
58
+
59
+ imname2pos = pickle.load(open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'rb'))
60
+ paths = imname2pos['val']
61
+
62
+ for k, v in paths.items():
63
+ path = k
64
+ break
65
+ image_file = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + 'val'), max_readers=1, readonly=True,
66
+ lock=False, readahead=False, meminit=False)
67
+ with image_file.begin(write=False) as txn:
68
+ image = txn.get(path.encode())
69
+ image = np.frombuffer(image, dtype=np.uint8)  # np.fromstring is deprecated for binary buffers
70
+ image = np.reshape(image, (args.imscale, args.imscale, 3))
71
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
72
+ print (np.shape(image))
73
+
74
+
75
+ if __name__ == "__main__":
76
+
77
+ parser = argparse.ArgumentParser()
78
+ parser.add_argument('--root', type=str, default='path/to/recipe1m',
79
+ help='path to the recipe1m dataset')
80
+ parser.add_argument('--save_dir', type=str, default='../data',
81
+ help='path where the lmdbs will be saved')
82
+ parser.add_argument('--imscale', type=int, default=256,
83
+ help='size of images (will be rescaled and center cropped)')
84
+ parser.add_argument('--maxnumims', type=int, default=5,
85
+ help='maximum number of images to allow for each sample')
86
+ parser.add_argument('--suff', type=str, default='',
87
+ help='id of the vocabulary to use')
88
+ parser.add_argument('--test_only', dest='test_only', action='store_true')
89
+ parser.set_defaults(test_only=False)
90
+ args = parser.parse_args()
91
+
92
+ if not args.test_only:
93
+ main(args)
94
+ test(args)
src/utils/metrics.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import sys
4
+ import time
5
+ import math
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.modules.loss import _WeightedLoss
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ map_loc = None if torch.cuda.is_available() else 'cpu'
13
+
14
+
15
+ class MaskedCrossEntropyCriterion(_WeightedLoss):
16
+
17
+ def __init__(self, ignore_index=[-100], reduce=None):
18
+ super(MaskedCrossEntropyCriterion, self).__init__()
19
+ self.padding_idx = ignore_index
20
+ self.reduce = reduce
21
+
22
+ def forward(self, outputs, targets):
23
+ lprobs = nn.functional.log_softmax(outputs, dim=-1)
24
+ lprobs = lprobs.view(-1, lprobs.size(-1))
25
+
26
+ for idx in self.padding_idx:
27
+ # remove padding idx from targets to allow gathering without error (padded entries will be suppressed later)
28
+ targets[targets == idx] = 0
29
+
30
+ nll_loss = -lprobs.gather(dim=-1, index=targets.unsqueeze(1))
31
+ if self.reduce:
32
+ nll_loss = nll_loss.sum()
33
+
34
+ return nll_loss.squeeze()
35
+
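# Usage sketch (illustrative, not lines from this commit): the criterion returns one negative
# log-likelihood per target token; the pad index and tensor shapes below are assumptions, and
# padded positions are masked out afterwards, as done in train.py.
#
#   criterion = MaskedCrossEntropyCriterion(ignore_index=[pad_idx], reduce=False)
#   nll = criterion(logits.view(-1, vocab_size), targets.view(-1))   # shape (batch * seq_len,)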
36
+
37
+ def softIoU(out, target, e=1e-6, sum_axis=1):
38
+
39
+ num = (out*target).sum(sum_axis, True)
40
+ den = (out+target-out*target).sum(sum_axis, True) + e
41
+ iou = num / den
42
+
43
+ return iou
44
+
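# Worked example (illustrative, not lines from this commit): on binary one-hot vectors softIoU
# reduces to intersection over union per sample.
#
#   >>> softIoU(torch.tensor([[1., 1., 0., 0.]]), torch.tensor([[1., 0., 1., 0.]]))
#   tensor([[0.3333]])    # intersection 1 / union 3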
45
+
46
+ def update_error_types(error_types, y_pred, y_true):
47
+
48
+ error_types['tp_i'] += (y_pred * y_true).sum(0).cpu().data.numpy()
49
+ error_types['fp_i'] += (y_pred * (1-y_true)).sum(0).cpu().data.numpy()
50
+ error_types['fn_i'] += ((1-y_pred) * y_true).sum(0).cpu().data.numpy()
51
+ error_types['tn_i'] += ((1-y_pred) * (1-y_true)).sum(0).cpu().data.numpy()
52
+
53
+ error_types['tp_all'] += (y_pred * y_true).sum().item()
54
+ error_types['fp_all'] += (y_pred * (1-y_true)).sum().item()
55
+ error_types['fn_all'] += ((1-y_pred) * y_true).sum().item()
56
+
57
+
58
+ def compute_metrics(ret_metrics, error_types, metric_names, eps=1e-10, weights=None):
59
+
60
+ if 'accuracy' in metric_names:
61
+ ret_metrics['accuracy'].append(np.mean((error_types['tp_i'] + error_types['tn_i']) / (error_types['tp_i'] + error_types['fp_i'] + error_types['fn_i'] + error_types['tn_i'])))
62
+ if 'jaccard' in metric_names:
63
+ ret_metrics['jaccard'].append(error_types['tp_all'] / (error_types['tp_all'] + error_types['fp_all'] + error_types['fn_all'] + eps))
64
+ if 'dice' in metric_names:
65
+ ret_metrics['dice'].append(2*error_types['tp_all'] / (2*(error_types['tp_all'] + error_types['fp_all'] + error_types['fn_all']) + eps))
66
+ if 'f1' in metric_names:
67
+ pre = error_types['tp_i'] / (error_types['tp_i'] + error_types['fp_i'] + eps)
68
+ rec = error_types['tp_i'] / (error_types['tp_i'] + error_types['fn_i'] + eps)
69
+ f1_perclass = 2*(pre * rec) / (pre + rec + eps)
70
+ if 'f1_ingredients' not in ret_metrics.keys():
71
+ ret_metrics['f1_ingredients'] = [np.average(f1_perclass, weights=weights)]
72
+ else:
73
+ ret_metrics['f1_ingredients'].append(np.average(f1_perclass, weights=weights))
74
+
75
+ pre = error_types['tp_all'] / (error_types['tp_all'] + error_types['fp_all'] + eps)
76
+ rec = error_types['tp_all'] / (error_types['tp_all'] + error_types['fn_all'] + eps)
77
+ f1 = 2*(pre * rec) / (pre + rec + eps)
78
+ ret_metrics['f1'].append(f1)
src/utils/output_ing.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ replace_dict = {' .': '.',
4
+ ' ,': ',',
5
+ ' ;': ';',
6
+ ' :': ':',
7
+ '( ': '(',
8
+ ' )': ')',
9
+ " '": "'"}
10
+
11
+ def get_ingrs(ids, ingr_vocab_list):
12
+ gen_ingrs = []
13
+ for ingr_idx in ids:
14
+ ingr_name = ingr_vocab_list[ingr_idx]
15
+ if ingr_name == '<pad>':
16
+ break
17
+ gen_ingrs.append(ingr_name)
18
+ return gen_ingrs
19
+
20
+
21
+ def prepare_output(gen_ingrs, ingr_vocab_list):
22
+
23
+ if gen_ingrs is not None:
24
+ gen_ingrs = get_ingrs(gen_ingrs, ingr_vocab_list)
25
+
26
+ outs = {'ingrs': gen_ingrs}
27
+
28
+ return outs
src/utils/output_utils.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ replace_dict = {' .': '.',
4
+ ' ,': ',',
5
+ ' ;': ';',
6
+ ' :': ':',
7
+ '( ': '(',
8
+ ' )': ')',
9
+ " '": "'"}
10
+
11
+
12
+ def get_recipe(ids, vocab):
13
+ toks = []
14
+ for id_ in ids:
15
+ toks.append(vocab[id_])
16
+ return toks
17
+
18
+
19
+ def get_ingrs(ids, ingr_vocab_list):
20
+ gen_ingrs = []
21
+ for ingr_idx in ids:
22
+ ingr_name = ingr_vocab_list[ingr_idx]
23
+ if ingr_name == '<pad>':
24
+ break
25
+ gen_ingrs.append(ingr_name)
26
+ return gen_ingrs
27
+
28
+
29
+ def prettify(toks, replace_dict):
30
+ toks = ' '.join(toks)
31
+ toks = toks.split('<end>')[0]
32
+ sentences = toks.split('<eoi>')
33
+
34
+ pretty_sentences = []
35
+ for sentence in sentences:
36
+ sentence = sentence.strip()
37
+ sentence = sentence.capitalize()
38
+ for k, v in replace_dict.items():
39
+ sentence = sentence.replace(k, v)
40
+ if sentence != '':
41
+ pretty_sentences.append(sentence)
42
+ return pretty_sentences
43
+
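# Worked example (illustrative, not lines from this commit): instruction tokens are cut at
# '<end>', split into sentences at '<eoi>' and cleaned up with replace_dict.
#
#   >>> prettify(['mix', 'the', 'flour', '.', '<eoi>', 'bake', 'it', '.', '<end>'], replace_dict)
#   ['Mix the flour.', 'Bake it.']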
44
+
45
+ def colorized_list(ingrs, ingrs_gt, colorize=False):
46
+ if colorize:
47
+ colorized_list = []
48
+ for word in ingrs:
49
+ if word in ingrs_gt:
50
+ word = '\033[1;30;42m ' + word + ' \x1b[0m'
51
+ else:
52
+ word = '\033[1;30;41m ' + word + ' \x1b[0m'
53
+ colorized_list.append(word)
54
+ return colorized_list
55
+ else:
56
+ return ingrs
57
+
58
+
59
+ def prepare_output(ids, gen_ingrs, ingr_vocab_list, vocab):
60
+
61
+ toks = get_recipe(ids, vocab)
62
+ is_valid = True
63
+ reason = 'All ok.'
64
+ try:
65
+ cut = toks.index('<end>')
66
+ toks_trunc = toks[0:cut]
67
+ except ValueError:  # list.index raises ValueError when '<end>' is missing
68
+ toks_trunc = toks
69
+ is_valid = False
70
+ reason = 'no eos found'
71
+
72
+ # repetition score
73
+ score = float(len(set(toks_trunc))) / float(len(toks_trunc))
74
+
75
+ prev_word = ''
76
+ found_repeat = False
77
+ for word in toks_trunc:
78
+ if prev_word == word and prev_word != '<eoi>':
79
+ found_repeat = True
80
+ break
81
+ prev_word = word
82
+
83
+ toks = prettify(toks, replace_dict)
84
+ title = toks[0]
85
+ toks = toks[1:]
86
+
87
+ if gen_ingrs is not None:
88
+ gen_ingrs = get_ingrs(gen_ingrs, ingr_vocab_list)
89
+
90
+ if score <= 0.3:
91
+ reason = 'Diversity score.'
92
+ is_valid = False
93
+ elif len(toks) != len(set(toks)):
94
+ reason = 'Repeated instructions.'
95
+ is_valid = False
96
+ elif found_repeat:
97
+ reason = 'Found word repeat.'
98
+ is_valid = False
99
+
100
+ valid = {'is_valid': is_valid, 'reason': reason, 'score': score}
101
+ outs = {'title': title, 'recipe': toks, 'ingrs': gen_ingrs}
102
+
103
+ return outs, valid
src/utils/tb_visualizer.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+
3
+ import numpy as np
4
+ import os
5
+ import ntpath
6
+ import time
7
+ import glob
8
+ # from scipy.misc import imresize  # unused here; scipy.misc.imresize was removed in SciPy >= 1.3
9
+ import torchvision.utils as vutils
10
+ from operator import itemgetter
11
+ from tensorboardX import SummaryWriter
12
+
13
+
14
+ class Visualizer():
15
+ def __init__(self, checkpoints_dir, name):
16
+ self.win_size = 256
17
+ self.name = name
18
+ self.saved = False
19
+ self.checkpoints_dir = checkpoints_dir
20
+ self.ncols = 4
21
+
22
+ # remove existing
23
+ for filename in glob.glob(self.checkpoints_dir+"/events*"):
24
+ os.remove(filename)
25
+ self.writer = SummaryWriter(checkpoints_dir)
26
+
27
+ def reset(self):
28
+ self.saved = False
29
+
30
+ # images: (b, c, 0, 1) array of images
31
+ def image_summary(self, mode, epoch, images):
32
+ images = vutils.make_grid(images, normalize=True, scale_each=True)
33
+ self.writer.add_image('{}/Image'.format(mode), images, epoch)
34
+
35
+ # text: type: ingredients/recipe
36
+ def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
37
+ for i, el in enumerate(text): # text_list
38
+ if not gt: # we are printing a sample
39
+ idx = el.nonzero().squeeze() + 1
40
+ else:
41
+ idx = el # we are printing the ground truth
42
+
43
+ words_list = itemgetter(*idx)(vocabulary)
44
+
45
+ if len(words_list) <= max_length:
46
+ self.writer.add_text('{}/{}_{}_{}'.format(mode, type, i, 'gt' if gt else 'prediction'),
47
+ ', '.join(filter(lambda x: x != '<pad>', words_list)), epoch)
48
+ else:
49
+ self.writer.add_text('{}/{}_{}_{}'.format(mode, type, i, 'gt' if gt else 'prediction'),
50
+ 'Number of sampled ingredients is too big: {}'.format(len(words_list)), epoch)
51
+
52
+ # losses: dictionary of error labels and values
53
+ def scalar_summary(self, mode, epoch, **args):
54
+ for k, v in args.items():
55
+ self.writer.add_scalar('{}/{}'.format(mode, k), v, epoch)
56
+
57
+ self.writer.export_scalars_to_json("{}/tensorboard_all_scalars.json".format(self.checkpoints_dir))
58
+
59
+ def histo_summary(self, model, step):
60
+ """Log a histogram of the tensor of values."""
61
+
62
+ for name, param in model.named_parameters():
63
+ self.writer.add_histogram(name, param, step)
64
+
65
+ def close(self):
66
+ self.writer.close()